All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends Macros Pages
steiner_tree.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2013 Maciej Andrejczuk
3 // 2014 Piotr Godlewski, Piotr Wygocki
4 //
5 // Distributed under the Boost Software License, Version 1.0. (See
6 // accompanying file LICENSE_1_0.txt or copy at
7 // http://www.boost.org/LICENSE_1_0.txt)
8 //=======================================================================
16 #ifndef PAAL_STEINER_TREE_HPP
17 #define PAAL_STEINER_TREE_HPP
18 
19 #define BOOST_RESULT_OF_USE_DECLTYPE
20 
21 
31 
#include <boost/random/discrete_distribution.hpp>
#include <boost/range/join.hpp>
#include <boost/range/algorithm/unique.hpp>
#include <boost/range/algorithm/sort.hpp>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/algorithm/find.hpp>

#include <random>
#include <utility>
#include <vector>
41 
42 namespace paal {
43 namespace ir {
44 
/**
 * @brief Holds the tolerance passed to utils::compare by steiner_tree.
 *
 * The constant lives in a class template instead of an anonymous namespace:
 * an anonymous namespace in a header gives every translation unit its own
 * copy, and the externally-linked steiner_tree template would then refer to
 * a different entity in each TU. The static data member of a class template
 * may be defined in a header without breaking the one-definition rule, and
 * this works in C++14 (no inline variables needed).
 */
template <typename Dummy = void>
struct steiner_tree_compare_traits_impl {
    static const double EPSILON;
};

template <typename Dummy>
const double steiner_tree_compare_traits_impl<Dummy>::EPSILON = 1e-10;

/// Keeps the historical spelling `steiner_tree_compare_traits::EPSILON`.
using steiner_tree_compare_traits = steiner_tree_compare_traits_impl<>;
52 
53 
64 template<typename OrigMetric, typename Terminals, typename Result,
65  typename Strategy = steiner_tree_all_generator,
66  typename Oracle = lp::random_violated_separation_oracle>
67 class steiner_tree {
68 public:
70  using Vertex = typename MT::VertexType;
71  using Vertices = std::vector<Vertex>;
72  using Dist = typename MT::DistanceType;
73  using Edge = typename std::pair<Vertex, Vertex>;
78  MetricIdx &,
79  const VertexIndex &,
81 
82 private:
83  Terminals m_terminals; // terminals in current state
84  Terminals m_steiner_vertices; // vertices that are not terminals
85  VertexIndex m_terminals_index; // mapping terminals to numbers for 0 to n.
86  VertexIndex m_vertex_index; // mapping vertices to numbers for 0 to n.
87  MetricIdx m_cost_map_idx; // metric in current state (operates on indexes)
88  Metric m_cost_map; // metric in current state
89  steiner_components<Vertex, Dist> m_components; // components in current
90  // state
91  Strategy m_strategy; // strategy to generate the components
92  Result m_result_iterator; // list of selected Steiner Vertices
93  Vertices m_selected_elements; // list of selected Steiner Vertices
94  Compare m_compare; // comparison method
95 
96  std::unordered_map<int, lp::col_id> m_elements_map; // maps component_id ->
97  // col_id in LP
98  steiner_tree_violation_checker m_violation_checker;
99 
100  Oracle m_oracle;
101 
102 public:
106  steiner_tree(const OrigMetric& metric, const Terminals& terminals,
107  const Terminals& steiner_vertices, Result result,
108  const Strategy& strategy = Strategy{}, Oracle oracle = Oracle{}) :
109  m_terminals(terminals), m_steiner_vertices(steiner_vertices),
110  m_terminals_index(m_terminals),
111  m_vertex_index(boost::range::join(m_terminals, m_steiner_vertices)),
112  m_cost_map_idx(metric, boost::range::join(m_terminals, m_steiner_vertices)),
113  m_cost_map(m_cost_map_idx, m_vertex_index),
114  m_strategy(strategy), m_result_iterator(result),
115  m_compare(steiner_tree_compare_traits::EPSILON), m_oracle(oracle) {
116  }
117 
126  template <typename LP>
127  auto get_find_violation(LP & lp) {
128  using candidate = steiner_tree_violation_checker::Candidate;
129  return m_oracle([&](){return m_violation_checker.get_violation_candidates(*this, lp);},
130  [&](candidate c){return m_violation_checker.check_violation(c, *this);},
131  [&](candidate c){return m_violation_checker.add_violated_constraint(c, *this, lp);});
132  }
133 
137  void gen_components() {
138  m_strategy.gen_components(m_cost_map, m_terminals, m_steiner_vertices,
139  m_components);
140  }
141 
146  return m_components;
147  }
148 
152  const Terminals &get_terminals() const { return m_terminals; }
153 
157  auto get_terminal_idx(Vertex v) const -> decltype(m_terminals_index.get_idx(v)) {
158  return m_terminals_index.get_idx(v);
159  }
160 
164  void add_column_lp(int id, lp::col_id col) {
165  bool b = m_elements_map.insert(std::make_pair(id, col)).second;
166  assert(b);
167  }
168 
172  lp::col_id find_column_lp(int id) const { return m_elements_map.at(id); }
173 
177  void add_to_solution(const std::vector<Vertex>& steiner_elements) {
178  boost::copy(steiner_elements, std::back_inserter(m_selected_elements));
179  }
180 
184  void set_solution() {
185  boost::sort(m_selected_elements);
186  boost::copy(boost::unique(m_selected_elements), m_result_iterator);
187  }
188 
193  auto const & all_terminals = selected.get_terminals();
194  auto all_terminals_except_first = boost::make_iterator_range(++all_terminals.begin(), all_terminals.end());
195  assert(!boost::empty(all_terminals));
196  auto const & sink = all_terminals.front();
197  for (auto t : all_terminals_except_first) {
198  merge_vertices(sink, t);
199  auto ii = boost::range::find(m_terminals, t);
200  assert(ii != m_terminals.end());
201  m_terminals.erase(ii);
202  }
203  // Clean components, they will be generated once again
204  m_components.clear();
205  m_elements_map.clear();
206  m_terminals_index = VertexIndex(m_terminals);
207  }
208 
213  return m_compare;
214  }
215 
216 private:
220  auto get_idx(Vertex v) const -> decltype(m_vertex_index.get_idx(v)) {
221  return m_vertex_index.get_idx(v);
222  }
223 
227  void merge_vertices(Vertex u_vertex, Vertex w_vertex) {
228  auto all_elements = boost::range::join(m_terminals, m_steiner_vertices);
229  auto u = get_idx(u_vertex);
230  auto w = get_idx(w_vertex);
231  for (auto i_vertex: all_elements) {
232  for (auto j_vertex: all_elements) {
233  auto i = get_idx(i_vertex);
234  auto j = get_idx(j_vertex);
235  assign_min(m_cost_map_idx(i, j),
236  m_cost_map_idx(i, u) + m_cost_map_idx(w, j));
237  }
238  }
239  }
240 
241 };
242 
243 
251  template <typename Problem, typename LP>
252  void operator()(Problem &problem, LP &lp) {
253  lp.clear();
254  lp.set_lp_name("steiner tree");
255  problem.gen_components();
256  lp.set_optimization_type(lp::MINIMIZE);
257  add_variables(problem, lp);
258  }
259 
260  private:
264  template <typename Problem, typename LP>
265  void add_variables(Problem &problem, LP &lp) {
266  for (int i = 0; i < problem.get_components().size(); ++i) {
267  lp::col_id col = lp.add_column(
268  problem.get_components().find(i).get_cost(), 0, 1);
269  problem.add_column_lp(i, col);
270  }
271  }
272 };
273 
278  std::default_random_engine m_rng;
279  public:
280  steiner_tree_round_condition(std::default_random_engine = std::default_random_engine{}) {}
281 
286  template <typename Problem, typename LP>
287  void operator()(Problem &problem, LP &lp) {
288  auto size = problem.get_components().size();
289  std::vector<double> weights;
290  weights.reserve(size);
291  for (auto i : paal::irange(size)) {
292  lp::col_id cId = problem.find_column_lp(i);
293  weights.push_back(lp.get_col_value(cId));
294  }
295 
296  auto selected = boost::random::discrete_distribution<>(weights)(m_rng);
297  auto const &comp = problem.get_components().find(selected);
298  problem.add_to_solution(comp.get_steiner_elements());
299  problem.update_graph(comp);
300  steiner_tree_init()(problem, lp);
301  }
302 };
303 
309  template<typename Problem, typename LP>
310  bool operator()(Problem& problem, LP &) {
311  return problem.get_terminals().size() < 2;
312  }
313 };
314 
322  template <typename Problem, typename GetSolution>
323  void operator()(Problem & problem, const GetSolution &) {
324  problem.set_solution();
325  }
326 };
327 
331 template<typename Oracle = lp::random_violated_separation_oracle,
332  typename OrigMetric, typename Terminals, typename Result, typename Strategy>
334  const OrigMetric& metric, const Terminals& terminals,
335  const Terminals& steiner_vertices, Result result, const Strategy& strategy,
336  Oracle oracle = Oracle()) {
338  terminals, steiner_vertices, result, strategy, oracle);
339 }
340 
341 template <typename Init = steiner_tree_init,
342  typename RoundCondition = steiner_tree_round_condition,
343  typename RelaxCondition = utils::always_false,
344  typename SetSolution = steiner_tree_set_solution,
345  typename SolveLPToExtremePoint = row_generation_solve_lp<>,
346  typename ResolveLPToExtremePoint = row_generation_solve_lp<>,
347  typename StopCondition = steiner_tree_stop_condition>
348 using steiner_tree_ir_components =
349  IRcomponents<Init, RoundCondition, RelaxCondition, SetSolution,
350  SolveLPToExtremePoint, ResolveLPToExtremePoint, StopCondition>;
351 
355 template <typename Oracle = lp::random_violated_separation_oracle,
356  typename Strategy = steiner_tree_all_generator,
357  typename OrigMetric,
358  typename Terminals,
359  typename Result,
360  typename IRcomponents = steiner_tree_ir_components<>,
361  typename Visitor = trivial_visitor>
362 lp::problem_type steiner_tree_iterative_rounding(const OrigMetric& metric, const Terminals& terminals,
363  const Terminals& steiner_vertices, Result result, Strategy strategy = Strategy{},
364  IRcomponents comps = IRcomponents{}, Oracle oracle = Oracle{},
365  Visitor visitor = Visitor{}) {
366 
367  auto steiner = paal::ir::make_steiner_tree(metric, terminals, steiner_vertices, result, strategy, oracle);
368  auto res = paal::ir::solve_dependent_iterative_rounding(steiner, std::move(comps), std::move(visitor));
369  return res.first;
370 }
371 
372 }
373 }
374 #endif // PAAL_STEINER_TREE_HPP
void operator()(Problem &problem, const GetSolution &)
Violation check_violation(Candidate candidate, const Problem &problem)
Common base class for the LP solvers. Responsible for:
Definition: lp_base.hpp:55
col_id add_column(double cost_coef=0, double lb=0., double ub=lp_traits::PLUS_INF, const std::string &name="")
Definition: lp_base.hpp:92
void operator()(Problem &problem, LP &lp)
Functor that always returns false.
Definition: functors.hpp:222
steiner_tree(const OrigMetric &metric, const Terminals &terminals, const Terminals &steiner_vertices, Result result, const Strategy &strategy=Strategy{}, Oracle oracle=Oracle{})
utils::compare< double > get_compare() const
auto get_violation_candidates(const Problem &problem, const LP &lp) -> decltype(irange(problem.get_terminals().size()))
Idx get_idx(const T &t) const
Gets the index of element t.
Definition: bimap.hpp:153
void add_to_solution(const std::vector< Vertex > &steiner_elements)
problem_type
LP problem type.
const steiner_components< Vertex, Dist > & get_components() const
auto get_find_violation(LP &lp)
Represents the k-components of a Steiner tree. A component is a subtree whose terminals coincide with l...
void set_lp_name(const std::string problem_name)
Definition: lp_base.hpp:78
bool operator()(Problem &problem, LP &)
Checks if the IR algorithm should terminate.
The class for solving the Steiner Tree problem using Iterative Rounding.
const Vertices & get_terminals() const
auto irange(T begin, T end)
irange
Definition: irange.hpp:22
lp::problem_type steiner_tree_iterative_rounding(const OrigMetric &metric, const Terminals &terminals, const Terminals &steiner_vertices, Result result, Strategy strategy=Strategy{}, IRcomponents comps=IRcomponents{}, Oracle oracle=Oracle{}, Visitor visitor=Visitor{})
Solves the Steiner Tree problem using Iterative Rounding.
Violation checker for the separation oracle in the Steiner tree problem.
lp::col_id find_column_lp(int id) const
void assign_min(T &t, const T &u)
steiner_tree< OrigMetric, Terminals, Result, Strategy, Oracle > make_steiner_tree(const OrigMetric &metric, const Terminals &terminals, const Terminals &steiner_vertices, Result result, const Strategy &strategy, Oracle oracle=Oracle())
Terminals
Enum indicating whether a given color represents a terminal or a non-terminal (NONTERMINAL).
auto get_terminal_idx(Vertex v) const -> decltype(m_terminals_index.get_idx(v))
typename components::type< Args...> IRcomponents
Iterative rounding components.
const Terminals & get_terminals() const
double get_col_value(col_id col) const
Definition: lp_base.hpp:228
void add_violated_constraint(Candidate violating_terminal, const Problem &problem, LP &lp)
puretype(std::declval< Metric >()(std::declval< VertexType >(), std::declval< VertexType >())) DistanceType
Distance type.
void operator()(Problem &problem, LP &lp)
IRResult solve_dependent_iterative_rounding(Problem &problem, IRcomponents components, Visitor visitor=Visitor())
Solves an Iterative Rounding problem with dependent rounding.
void add_column_lp(int id, lp::col_id col)
void update_graph(const steiner_component< Vertex, Dist > &selected)