15 #ifndef PAAL_DREYFUS_WAGNER_HPP
16 #define PAAL_DREYFUS_WAGNER_HPP
21 #include <unordered_map>
22 #include <unordered_set>
31 template <
typename Metric,
typename Terminals,
typename NonTerminals,
32 unsigned int TerminalsLimit = 32>
// Vertex type of the metric. MT is declared outside this view
// (presumably the metric traits of Metric — TODO confirm).
36 using Vertex =
typename MT::VertexType;
// An edge of the solution as an ordered (u, w) vertex pair; instances are
// built and appended to m_edges by add_edge_to_graph().
38 using Edge =
typename std::pair<Vertex, Vertex>;
// Subset of terminals: bit i is set when the terminal with index i (its
// value in m_elements_map) belongs to the subset. The constructor asserts
// that the number of terminals does not exceed TerminalsLimit.
39 using TerminalsBitSet =
typename std::bitset<TerminalsLimit>;
// Memoization key built by code_state(): a vertex together with the
// `remaining` terminal subset passed to connect_vertex()/split_vertex().
40 using State = std::pair<Vertex, TerminalsBitSet>;
// Set of vertices recorded via add_vertex_to_graph() while the solution is
// reconstructed.
41 using steiner_elements = std::unordered_set<Vertex, boost::hash<Vertex>>;
47 const NonTerminals &non_terminals)
48 : m_cost_map(cost_map), m_terminals(term),
49 m_non_terminals(non_terminals) {
51 assert(m_terminals.size() <= TerminalsLimit);
52 for (
int i = 0; i < (int)m_terminals.size(); i++) {
53 m_elements_map[m_terminals[i]] = i;
62 int n = m_elements_map.size();
63 assert(start >= 0 && start < n);
64 TerminalsBitSet remaining;
66 for (
int i = 0; i < n; i++) {
69 remaining.reset(start);
71 m_cost = connect_vertex(m_terminals[start], remaining);
72 retrieve_solution_connect(m_terminals[start], remaining);
83 const std::vector<Edge> &
get_edges()
const {
return m_edges; }
89 return m_steiner_elements;
99 Dist connect_vertex(Vertex v, TerminalsBitSet remaining) {
100 if (remaining.none()) {
103 if (remaining.count() == 1) {
104 int k = smallest_bit(remaining);
105 Dist cost = m_cost_map(v, m_terminals[k]);
106 m_best_cand[code_state(v, remaining)] =
107 std::make_pair(cost, m_terminals[k]);
111 auto iter = m_best_cand.find(code_state(v, remaining));
112 if (iter != m_best_cand.end()) {
113 return iter->second.first;
115 Dist
best = split_vertex(v, remaining);
118 auto try_vertex = [&](Vertex w) {
119 Dist val = split_vertex(w, remaining);
120 val += m_cost_map(v, w);
121 if (best < 0 || val < best) {
126 for (Vertex w : m_non_terminals) {
129 for (
auto w_with_id : m_elements_map) {
130 if (!remaining.test(w_with_id.second)) {
131 try_vertex(w_with_id.first);
134 for (
auto vertex_and_terminal_id : m_elements_map) {
135 Vertex w = vertex_and_terminal_id.first;
136 int terminal_id = vertex_and_terminal_id.second;
137 if (!remaining.test(terminal_id))
continue;
138 remaining.reset(terminal_id);
139 Dist val = connect_vertex(w, remaining);
140 val += m_cost_map(v, w);
141 remaining.set(terminal_id);
143 if (best < 0 || val < best) {
148 m_best_cand[code_state(v, remaining)] = std::make_pair(best, cand);
155 Dist split_vertex(Vertex v, TerminalsBitSet remaining) {
156 if (remaining.count() < 2) {
160 auto iter = m_best_split.find(code_state(v, remaining));
161 if (iter != m_best_split.end()) {
162 return iter->second.first;
164 int k = smallest_bit(remaining) +
166 std::pair<Dist, TerminalsBitSet> best =
167 best_split(v, remaining, remaining, k);
168 m_best_split[code_state(v, remaining)] =
best;
175 std::pair<Dist, TerminalsBitSet> best_split(
const Vertex v,
176 const TerminalsBitSet remaining,
177 TerminalsBitSet subset,
int k) {
178 if (k == (
int)m_terminals.size()) {
179 TerminalsBitSet complement = remaining ^ subset;
180 if (!subset.none() && !complement.none()) {
182 connect_vertex(v, subset) + connect_vertex(v, complement);
183 return make_pair(val, subset);
185 return std::make_pair(-1, NULL);
188 std::pair<Dist, TerminalsBitSet> ret1, ret2;
189 ret1 = best_split(v, remaining, subset, k + 1);
190 if (remaining.test(k)) {
192 ret2 = best_split(v, remaining, subset, k + 1);
193 if (ret1.first < 0 || ret1.first > ret2.first) {
205 void retrieve_solution_connect(Vertex v, TerminalsBitSet remaining) {
206 if (remaining.none())
return;
207 Vertex next = m_best_cand.at(code_state(v, remaining)).second;
209 auto terminal_id_iter = m_elements_map.find(next);
211 retrieve_solution_split(next, remaining);
212 }
else if (terminal_id_iter == m_elements_map.end()
213 || !remaining.test(terminal_id_iter->second)) {
214 add_vertex_to_graph(next);
215 add_edge_to_graph(v, next);
216 retrieve_solution_split(next, remaining);
218 add_edge_to_graph(v, next);
219 remaining.flip(terminal_id_iter->second);
220 retrieve_solution_connect(next, remaining);
228 void retrieve_solution_split(Vertex v, TerminalsBitSet remaining) {
229 if (remaining.none())
return;
230 TerminalsBitSet split =
231 m_best_split.at(code_state(v, remaining)).second;
232 retrieve_solution_connect(v, split);
233 retrieve_solution_connect(v, remaining ^ split);
239 State code_state(Vertex v, TerminalsBitSet remaining) {
241 return std::make_pair(v, remaining);
248 std::size_t operator()(
const State &k)
const {
249 return boost::hash<Vertex>()(k.first) ^
250 (std::hash<TerminalsBitSet>()(k.second) << 1);
257 void add_edge_to_graph(Vertex u, Vertex w) {
258 Edge e = std::make_pair(u, w);
259 m_edges.push_back(e);
265 void add_vertex_to_graph(Vertex v) { m_steiner_elements.insert(v); }
270 int smallest_bit(TerminalsBitSet mask) {
272 while (!mask.test(k)) ++k;
// Metric queried as m_cost_map(u, w) for the cost between two vertices.
// Stored by reference: the referenced object must outlive this instance.
276 const Metric &m_cost_map;
// Non-terminal vertices tried as intermediate points in connect_vertex().
// Stored by reference: the referenced container must outlive this instance.
278 const NonTerminals &m_non_terminals;
// Vertices recorded by add_vertex_to_graph() during solution retrieval.
281 steiner_elements m_steiner_elements;
// Edges recorded by add_edge_to_graph() during solution retrieval.
283 std::vector<Edge> m_edges;
// Maps each terminal vertex to its index, i.e. its bit position in a
// TerminalsBitSet (filled in the constructor).
285 std::unordered_map<Vertex, int, boost::hash<Vertex>> m_elements_map;
// (cost, vertex to continue from) memoized per State by connect_vertex().
287 using StateV = std::pair<Dist, Vertex>;
// (cost, chosen terminal subset) memoized per State by split_vertex().
288 using StateBM = std::pair<Dist, TerminalsBitSet>;
// Memo for connect_vertex(); read back by retrieve_solution_connect().
289 std::unordered_map<State, StateV, state_hash> m_best_cand;
// Memo for split_vertex(); read back by retrieve_solution_split().
293 std::unordered_map<State, StateBM, state_hash> m_best_split;
307 template <
unsigned int TerminalsLimit = 32,
typename Metric,
typename Terminals,
typename NonTerminals>
308 dreyfus_wagner<Metric, Terminals, NonTerminals, TerminalsLimit>
310 const NonTerminals &non_terminals) {
312 metric, terminals, non_terminals);
317 #endif // PAAL_DREYFUS_WAGNER_HPP
dreyfus_wagner< Metric, Terminals, NonTerminals, TerminalsLimit > make_dreyfus_wagner(const Metric &metric, const Terminals &terminals, const NonTerminals &non_terminals)
Creates a dreyfus_wagner object.
bool best(Solution &solution, ContinueOnSuccess on_success, components...comps)
This local search chooses the best possible move and applies it to the solution. Note that this strat...
Terminals
Enum indicating whether a given color represents a terminal or a NONTERMINAL.
const steiner_elements & get_steiner_elements() const
dreyfus_wagner(const Metric &cost_map, const Terminals &term, const NonTerminals &non_terminals)
const std::vector< Edge > & get_edges() const
puretype(std::declval< Metric >()(std::declval< VertexType >(), std::declval< VertexType >())) DistanceType
Distance type.