All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends Macros Pages
steiner_tree_oracle.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c)
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See
5 // accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt)
7 //=======================================================================
15 #ifndef PAAL_STEINER_TREE_ORACLE_HPP
16 #define PAAL_STEINER_TREE_ORACLE_HPP
17 
#include "paal/utils/irange.hpp"

#include <boost/optional.hpp>
#include <boost/range/iterator_range.hpp>

#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>
25 
26 namespace paal {
27 namespace ir {
28 
35  using AuxEdge = min_cut_finder::Edge;
36  using AuxVertex = min_cut_finder::Vertex;
37  using AuxEdgeList = std::vector<AuxEdge>;
38  using Violation = boost::optional<double>;
39 
40  public:
41  using Candidate = AuxVertex;
42  steiner_tree_violation_checker() : m_current_graph_size(-1) {}
43 
47  template <typename Problem, typename LP>
48  auto get_violation_candidates(const Problem &problem, const LP &lp)
49  ->decltype(irange(problem.get_terminals().size())) {
50 
51  int graph_size = problem.get_terminals().size();
52  if (graph_size != m_current_graph_size) {
53  // Graph has changed, construct new oracle
54  m_current_graph_size = graph_size;
55  m_root = select_root(problem.get_terminals());
56  create_auxiliary_digraph(problem, lp);
57  } else {
58  update_auxiliary_digraph(problem, lp);
59  }
60  //TODO - rethink - why do we return the whole collection instead of returning m_root see check_violation implementation
61  return irange(problem.get_terminals().size());
62  }
63 
68  template <typename Problem>
69  Violation check_violation(Candidate candidate, const Problem &problem) {
70  if (candidate == m_root) {
71  return Violation{};
72  }
73 
74  double violation = find_violation(candidate);
75  if (problem.get_compare().g(violation, 0)) {
76  return violation;
77  } else {
78  return Violation{};
79  }
80  }
81 
87  template <typename Problem, typename LP>
88  void add_violated_constraint(Candidate violating_terminal,
89  const Problem &problem, LP &lp) {
90  if (std::make_pair(violating_terminal, m_root) !=
91  m_min_cut.get_last_cut()) {
92  find_violation(violating_terminal);
93  }
94 
95  auto const &components = problem.get_components();
97  for (int i = 0; i < components.size(); ++i) {
98  auto u = m_artif_vertices[i];
99  int ver = components.find_version(i);
100  auto v = m_terminals_to_aux[problem.get_terminal_idx(components.find(i).get_sink(ver))];
101  if (m_min_cut.is_in_source_set(u) &&
102  !m_min_cut.is_in_source_set(v)) {
103  expr += problem.find_column_lp(i);
104  }
105  }
106  lp.add_row(std::move(expr) >= 1);
107  }
108 
109  private:
110 
118  template <typename Problem, typename LP>
119  void create_auxiliary_digraph(Problem &problem, const LP &lp) {
120  m_min_cut.init(0);
121  m_artif_vertices.clear();
122  m_terminals_to_aux.clear();
123  for (auto term : irange(problem.get_terminals().size())) {
124  m_terminals_to_aux[term] = m_min_cut.add_vertex_to_graph();
125  }
126  auto const &components = problem.get_components();
127 
128  for (int i = 0; i < components.size(); ++i) {
129  AuxVertex new_v = m_min_cut.add_vertex_to_graph();
130  m_artif_vertices[i] = new_v;
131  int ver = components.find_version(i);
132  auto sink = components.find(i).get_sink(ver);
133  for (auto w : boost::make_iterator_range(components.find(i)
134  .get_terminals())) {
135  if (w != sink) {
136  double INF = std::numeric_limits<double>::max();
137  m_min_cut.add_edge_to_graph(m_terminals_to_aux[problem.get_terminal_idx(w)], new_v,
138  INF);
139  } else {
140  lp::col_id x = problem.find_column_lp(i);
141  double col_val = lp.get_col_value(x);
142  m_min_cut.add_edge_to_graph(new_v,
143  m_terminals_to_aux[problem.get_terminal_idx(sink)],
144  col_val);
145  }
146  }
147  }
148  }
149 
154  template <typename Problem, typename LP>
155  void update_auxiliary_digraph(Problem &problem, const LP &lp) {
156  auto const &components = problem.get_components();
157  for (int i = 0; i < components.size(); ++i) {
158  auto component_v = m_artif_vertices[i];
159  int ver = components.find_version(i);
160  auto sink = components.find(i).get_sink(ver);
161  double col_val = lp.get_col_value(problem.find_column_lp(i));
162  m_min_cut.set_capacity(component_v,
163  m_terminals_to_aux[problem.get_terminal_idx(sink)], col_val);
164  }
165  }
166 
171  template <typename Terminals>
172  AuxVertex select_root(const Terminals &terminals) {
173  // TODO: Maybe it's better to select random vertex rather than first
174  return 0;
175  }
176 
181  double find_violation(AuxVertex src) {
182  double min_cut_weight = m_min_cut.find_min_cut(
183  m_terminals_to_aux[src], m_terminals_to_aux[m_root]);
184  return 1 - min_cut_weight;
185  }
186 
187  AuxVertex m_root; // root vertex, sink of all max-flows
188  int m_current_graph_size; // size of current graph
189 
190  // maps component_id to aux_graph vertex
191  std::unordered_map<int, AuxVertex> m_artif_vertices;
192 
193  // maps terminals to aux_graph vertices
194  std::unordered_map<AuxVertex, AuxVertex> m_terminals_to_aux;
195 
196  min_cut_finder m_min_cut;
197 };
198 
199 }
200 }
201 #endif // PAAL_STEINER_TREE_ORACLE_HPP
std::pair< Edge, Edge > add_edge_to_graph(Vertex src, Vertex trg, double cap, double rev_cap=0.)
Definition: min_cut.hpp:70
bool is_in_source_set(Vertex v) const
Definition: min_cut.hpp:109
void set_capacity(Edge e, double cap)
Definition: min_cut.hpp:138
Violation check_violation(Candidate candidate, const Problem &problem)
The common LP solvers base class. Responsible for:
Definition: lp_base.hpp:55
auto get_violation_candidates(const Problem &problem, const LP &lp) -> decltype(irange(problem.get_terminals().size()))
auto irange(T begin, T end)
irange
Definition: irange.hpp:22
Violation checker for the separation oracle of the Steiner tree problem.
double find_min_cut(Vertex src, Vertex trg)
Definition: min_cut.hpp:95
Terminals
Enum indicating whether a given color represents a terminal or a NONTERMINAL.
double get_col_value(col_id col) const
Definition: lp_base.hpp:228
row_id add_row(const double_bounded_expression &constraint=double_bounded_expression{}, const std::string &name="")
Definition: lp_base.hpp:109
void add_violated_constraint(Candidate violating_terminal, const Problem &problem, LP &lp)
Vertex add_vertex_to_graph()
Definition: min_cut.hpp:58
void init(int vertices_num)
Definition: min_cut.hpp:46
std::pair< Vertex, Vertex > get_last_cut() const
Definition: min_cut.hpp:128