All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends Macros Pages
dreyfus_wagner.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c)
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See
5 // accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt)
7 //=======================================================================
15 #ifndef PAAL_DREYFUS_WAGNER_HPP
16 #define PAAL_DREYFUS_WAGNER_HPP
17 
20 
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <bitset>
24 
25 namespace paal {
26 
/**
 * @brief Dreyfus-Wagner dynamic-programming solver for the Steiner tree
 * problem: connects all terminals at minimal total cost, optionally routing
 * through non-terminal (Steiner) vertices.
 *
 * @tparam Metric         callable; metric(u, w) yields the distance between
 *                        two vertices
 * @tparam Terminals      container of terminal vertices accessed with
 *                        operator[] and size()
 * @tparam NonTerminals   iterable container of candidate Steiner vertices
 * @tparam TerminalsLimit maximum number of terminals; also the width of the
 *                        bit set used to encode subsets of terminals
 */
31 template <typename Metric, typename Terminals, typename NonTerminals,
32  unsigned int TerminalsLimit = 32>
34  public:
36  using Vertex = typename MT::VertexType; // MT is declared outside this view (presumably the metric traits of Metric) - confirm
37  using Dist = typename MT::DistanceType; // type of the values returned by the metric
38  using Edge = typename std::pair<Vertex, Vertex>; // one edge of the resulting tree
39  using TerminalsBitSet = typename std::bitset<TerminalsLimit>; // bit i <=> terminal with index i
40  using State = std::pair<Vertex, TerminalsBitSet>; // memoization key: (current vertex, terminals still to connect)
41  using steiner_elements = std::unordered_set<Vertex, boost::hash<Vertex>>; // set of vertices chosen as Steiner points
42 
46  dreyfus_wagner(const Metric &cost_map, const Terminals &term,
47  const NonTerminals &non_terminals)
48  : m_cost_map(cost_map), m_terminals(term),
49  m_non_terminals(non_terminals) {
50 
51  assert(m_terminals.size() <= TerminalsLimit);
52  for (int i = 0; i < (int)m_terminals.size(); i++) {
53  m_elements_map[m_terminals[i]] = i;
54  }
55  }
56 
61  void solve(int start = 0) {
62  int n = m_elements_map.size();
63  assert(start >= 0 && start < n);
64  TerminalsBitSet remaining;
65  // set all terminals except 'start' to 1
66  for (int i = 0; i < n; i++) {
67  remaining.set(i);
68  }
69  remaining.reset(start);
70 
71  m_cost = connect_vertex(m_terminals[start], remaining);
72  retrieve_solution_connect(m_terminals[start], remaining);
73  }
74 
78  Dist get_cost() const { return m_cost; }
79 
83  const std::vector<Edge> &get_edges() const { return m_edges; }
84 
88  const steiner_elements &get_steiner_elements() const {
89  return m_steiner_elements;
90  }
91 
92  private:
93  /*
94  * @brief Computes minimal cost of connecting given vertex and a set of
95  * other vertices.
96  * @param v vertex currently processed
97  * @param mask vertices not yet processed has corresponding bits set to 1
98  */
99  Dist connect_vertex(Vertex v, TerminalsBitSet remaining) {
100  if (remaining.none()) {
101  return 0;
102  }
103  if (remaining.count() == 1) {
104  int k = smallest_bit(remaining);
105  Dist cost = m_cost_map(v, m_terminals[k]);
106  m_best_cand[code_state(v, remaining)] =
107  std::make_pair(cost, m_terminals[k]);
108  return cost;
109  }
110  // Check in the map if already computed
111  auto iter = m_best_cand.find(code_state(v, remaining));
112  if (iter != m_best_cand.end()) {
113  return iter->second.first;
114  }
115  Dist best = split_vertex(v, remaining);
116  Vertex cand = v;
117 
118  auto try_vertex = [&](Vertex w) {
119  Dist val = split_vertex(w, remaining);
120  val += m_cost_map(v, w);
121  if (best < 0 || val < best) {
122  best = val;
123  cand = w;
124  }
125  };
126  for (Vertex w : m_non_terminals) {
127  try_vertex(w);
128  }
129  for (auto w_with_id : m_elements_map) {
130  if (!remaining.test(w_with_id.second)) {
131  try_vertex(w_with_id.first);
132  }
133  }
134  for (auto vertex_and_terminal_id : m_elements_map) {
135  Vertex w = vertex_and_terminal_id.first;
136  int terminal_id = vertex_and_terminal_id.second;
137  if (!remaining.test(terminal_id)) continue;
138  remaining.reset(terminal_id);
139  Dist val = connect_vertex(w, remaining);
140  val += m_cost_map(v, w);
141  remaining.set(terminal_id);
142 
143  if (best < 0 || val < best) {
144  best = val;
145  cand = w;
146  }
147  }
148  m_best_cand[code_state(v, remaining)] = std::make_pair(best, cand);
149  return best;
150  }
151 
155  Dist split_vertex(Vertex v, TerminalsBitSet remaining) {
156  if (remaining.count() < 2) {
157  return 0;
158  }
159  // Check in the map if already computed
160  auto iter = m_best_split.find(code_state(v, remaining));
161  if (iter != m_best_split.end()) {
162  return iter->second.first;
163  }
164  int k = smallest_bit(remaining) +
165  1; // optimalization, to avoid checking subset twice
166  std::pair<Dist, TerminalsBitSet> best =
167  best_split(v, remaining, remaining, k);
168  m_best_split[code_state(v, remaining)] = best;
169  return best.first;
170  }
171 
175  std::pair<Dist, TerminalsBitSet> best_split(const Vertex v,
176  const TerminalsBitSet remaining,
177  TerminalsBitSet subset, int k) {
178  if (k == (int)m_terminals.size()) {
179  TerminalsBitSet complement = remaining ^ subset;
180  if (!subset.none() && !complement.none()) {
181  Dist val =
182  connect_vertex(v, subset) + connect_vertex(v, complement);
183  return make_pair(val, subset);
184  } else {
185  return std::make_pair(-1, NULL);
186  }
187  } else {
188  std::pair<Dist, TerminalsBitSet> ret1, ret2;
189  ret1 = best_split(v, remaining, subset, k + 1);
190  if (remaining.test(k)) {
191  subset.flip(k);
192  ret2 = best_split(v, remaining, subset, k + 1);
193  if (ret1.first < 0 || ret1.first > ret2.first) {
194  ret1 = ret2;
195  }
196  }
197  return ret1;
198  }
199  }
200 
205  void retrieve_solution_connect(Vertex v, TerminalsBitSet remaining) {
206  if (remaining.none()) return;
207  Vertex next = m_best_cand.at(code_state(v, remaining)).second;
208 
209  auto terminal_id_iter = m_elements_map.find(next);
210  if (v == next) { // called wagner directly from dreyfus
211  retrieve_solution_split(next, remaining);
212  } else if (terminal_id_iter == m_elements_map.end() // nonterminal
213  || !remaining.test(terminal_id_iter->second)) { // terminal not in remaining
214  add_vertex_to_graph(next);
215  add_edge_to_graph(v, next);
216  retrieve_solution_split(next, remaining);
217  } else { // terminal
218  add_edge_to_graph(v, next);
219  remaining.flip(terminal_id_iter->second);
220  retrieve_solution_connect(next, remaining);
221  }
222  }
223 
228  void retrieve_solution_split(Vertex v, TerminalsBitSet remaining) {
229  if (remaining.none()) return;
230  TerminalsBitSet split =
231  m_best_split.at(code_state(v, remaining)).second;
232  retrieve_solution_connect(v, split);
233  retrieve_solution_connect(v, remaining ^ split);
234  }
235 
239  State code_state(Vertex v, TerminalsBitSet remaining) {
240  // TODO: can be optimized
241  return std::make_pair(v, remaining);
242  }
243 
247  struct state_hash {
248  std::size_t operator()(const State &k) const {
249  return boost::hash<Vertex>()(k.first) ^
250  (std::hash<TerminalsBitSet>()(k.second) << 1);
251  }
252  };
253 
257  void add_edge_to_graph(Vertex u, Vertex w) {
258  Edge e = std::make_pair(u, w);
259  m_edges.push_back(e);
260  }
261 
265  void add_vertex_to_graph(Vertex v) { m_steiner_elements.insert(v); }
266 
270  int smallest_bit(TerminalsBitSet mask) {
271  int k = 0;
272  while (!mask.test(k)) ++k;
273  return k;
274  }
275 
276  const Metric &m_cost_map; // stores the cost for each edge
277  const Terminals &m_terminals; // terminals to be connected
278  const NonTerminals &m_non_terminals; // list of all non-terminals
279 
280  Dist m_cost; // cost of optimal Steiner Tree
281  steiner_elements m_steiner_elements; // non-terminals selected for spanning
282  // tree
283  std::vector<Edge> m_edges; // edges spanning the component
284 
285  std::unordered_map<Vertex, int, boost::hash<Vertex>> m_elements_map; // maps Vertex to position
286  // in m_terminals vector
287  using StateV = std::pair<Dist, Vertex>;
288  using StateBM = std::pair<Dist, TerminalsBitSet>;
289  std::unordered_map<State, StateV, state_hash> m_best_cand; // stores result
290  // of dreyfus
291  // method for
292  // given state
293  std::unordered_map<State, StateBM, state_hash> m_best_split; // stores
294  // result of
295  // wagner
296  // method for
297  // given state
298 };
299 
307 template <unsigned int TerminalsLimit = 32, typename Metric, typename Terminals, typename NonTerminals>
308 dreyfus_wagner<Metric, Terminals, NonTerminals, TerminalsLimit>
309 make_dreyfus_wagner(const Metric &metric, const Terminals &terminals,
310  const NonTerminals &non_terminals) {
312  metric, terminals, non_terminals);
313 }
314 
315 } // paal
316 
317 #endif // PAAL_DREYFUS_WAGNER_HPP
dreyfus_wagner< Metric, Terminals, NonTerminals, TerminalsLimit > make_dreyfus_wagner(const Metric &metric, const Terminals &terminals, const NonTerminals &non_terminals)
Creates a dreyfus_wagner object.
Dist get_cost() const
bool best(Solution &solution, ContinueOnSuccess on_success, components...comps)
This local search chooses the best possible move and applies it to the solution. Note that this strat...
Terminals
Enum indicating whether a given color represents a terminal or a non-terminal (NONTERMINAL).
const steiner_elements & get_steiner_elements() const
dreyfus_wagner(const Metric &cost_map, const Terminals &term, const NonTerminals &non_terminals)
const std::vector< Edge > & get_edges() const
puretype(std::declval< Metric >()(std::declval< VertexType >(), std::declval< VertexType >())) DistanceType
Distance type.
void solve(int start=0)