graph.hpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. // Copyright (c) 2019, QuantStack and Mamba Contributors
  2. //
  3. // Distributed under the terms of the BSD 3-Clause License.
  4. //
  5. // The full license is in the file LICENSE, distributed with this software.
  6. #ifndef MAMBA_UTIL_GRAPH_HPP
  7. #define MAMBA_UTIL_GRAPH_HPP
  8. #include <algorithm>
  9. #include <functional>
  10. #include <iterator>
  11. #include <map>
  12. #include <utility>
  13. #include <vector>
  14. #include "flat_set.hpp"
  15. namespace mamba::util
  16. {
  17. // Simplified implementation of a directed graph
  18. template <typename Node, typename Derived>
  19. class DiGraphBase
  20. {
  21. public:
  22. using node_t = Node;
  23. using node_id = std::size_t;
  24. using node_map = std::map<node_id, node_t>;
  25. using node_id_list = flat_set<node_id>;
  26. using adjacency_list = std::vector<node_id_list>;
  27. node_id add_node(const node_t& value);
  28. node_id add_node(node_t&& value);
  29. bool add_edge(node_id from, node_id to);
  30. bool remove_edge(node_id from, node_id to);
  31. bool remove_node(node_id id);
  32. bool empty() const;
  33. std::size_t number_of_nodes() const noexcept;
  34. std::size_t number_of_edges() const noexcept;
  35. std::size_t in_degree(node_id id) const noexcept;
  36. std::size_t out_degree(node_id id) const noexcept;
  37. const node_map& nodes() const;
  38. const node_t& node(node_id id) const;
  39. node_t& node(node_id id);
  40. const node_id_list& successors(node_id id) const;
  41. const adjacency_list& successors() const;
  42. const node_id_list& predecessors(node_id id) const;
  43. const adjacency_list& predecessors() const;
  44. bool has_node(node_id id) const;
  45. bool has_edge(node_id from, node_id to) const;
  46. // TODO C++20 better to return a range since this search cannot be interupted from the
  47. // visitor
  48. template <typename UnaryFunc>
  49. UnaryFunc for_each_node_id(UnaryFunc func) const;
  50. template <typename BinaryFunc>
  51. BinaryFunc for_each_edge_id(BinaryFunc func) const;
  52. template <typename UnaryFunc>
  53. UnaryFunc for_each_leaf_id(UnaryFunc func) const;
  54. template <typename UnaryFunc>
  55. UnaryFunc for_each_leaf_id_from(node_id source, UnaryFunc func) const;
  56. template <typename UnaryFunc>
  57. UnaryFunc for_each_root_id(UnaryFunc func) const;
  58. template <typename UnaryFunc>
  59. UnaryFunc for_each_root_id_from(node_id source, UnaryFunc func) const;
  60. // TODO C++20 better to return a range since this search cannot be interupted from the
  61. // visitor
  62. template <class V>
  63. void depth_first_search(V& visitor, node_id start = node_id(0), bool reverse = false) const;
  64. protected:
  65. using derived_t = Derived;
  66. DiGraphBase() = default;
  67. DiGraphBase(const DiGraphBase&) = default;
  68. DiGraphBase(DiGraphBase&&) = default;
  69. DiGraphBase& operator=(const DiGraphBase&) = default;
  70. DiGraphBase& operator=(DiGraphBase&&) = default;
  71. ~DiGraphBase() = default;
  72. node_id number_of_node_id() const noexcept;
  73. Derived& derived_cast();
  74. const Derived& derived_cast() const;
  75. private:
  76. enum class visited
  77. {
  78. no,
  79. ongoing,
  80. yes
  81. };
  82. using visited_list = std::vector<visited>;
  83. template <class V>
  84. node_id add_node_impl(V&& value);
  85. template <class V>
  86. void depth_first_search_impl(
  87. V& visitor,
  88. node_id node,
  89. visited_list& status,
  90. const adjacency_list& successors
  91. ) const;
  92. // Source of truth for exsising nodes
  93. node_map m_node_map;
  94. // May contains empty slots after `remove_node`
  95. adjacency_list m_predecessors;
  96. // May contains empty slots after `remove_node`
  97. adjacency_list m_successors;
  98. std::size_t m_number_of_edges = 0;
  99. };
  /**
   * Return whether ``target`` can be reached from ``source`` by following edges of ``graph``.
   *
   * Declared with exactly the template parameters and signature of its definition so that
   * both name the same function template. A separate ``DiGraphBase<Node, Derived>``
   * overload would be a distinct, never-defined template that partial ordering prefers for
   * arguments of type ``DiGraphBase<...>``, producing undefined references at link time.
   */
  template <typename Graph>
  auto is_reachable(
      const Graph& graph,
      typename Graph::node_id source,
      typename Graph::node_id target
  ) -> bool;
  /**
   * Visitor whose hooks all do nothing, meant as a base for ``depth_first_search`` visitors.
   *
   * Derive from it and redefine only the hooks of interest; every hook receives the graph
   * being traversed as its last argument.
   */
  template <class G>
  class default_visitor
  {
  public:

      using graph_t = G;
      using node_id = typename graph_t::node_id;

      // Node events: entering and leaving a node.
      void start_node(node_id, const graph_t&) {}
      void finish_node(node_id, const graph_t&) {}

      // Edge events: opening an edge, its classification, and closing it.
      void start_edge(node_id, node_id, const graph_t&) {}
      void tree_edge(node_id, node_id, const graph_t&) {}
      void back_edge(node_id, node_id, const graph_t&) {}
      void forward_or_cross_edge(node_id, node_id, const graph_t&) {}
      void finish_edge(node_id, node_id, const graph_t&) {}
  };
  134. template <typename Node, typename Edge = void>
  135. class DiGraph : private DiGraphBase<Node, DiGraph<Node, Edge>>
  136. {
  137. public:
  138. using Base = DiGraphBase<Node, DiGraph<Node, Edge>>;
  139. using typename Base::adjacency_list;
  140. using typename Base::node_id;
  141. using typename Base::node_id_list;
  142. using typename Base::node_map;
  143. using typename Base::node_t;
  144. using edge_t = Edge;
  145. using edge_id = std::pair<node_id, node_id>;
  146. using edge_map = std::map<edge_id, edge_t>;
  147. using Base::empty;
  148. using Base::has_edge;
  149. using Base::has_node;
  150. using Base::in_degree;
  151. using Base::node;
  152. using Base::nodes;
  153. using Base::number_of_edges;
  154. using Base::number_of_nodes;
  155. using Base::out_degree;
  156. using Base::predecessors;
  157. using Base::successors;
  158. using Base::for_each_edge_id;
  159. using Base::for_each_leaf_id;
  160. using Base::for_each_leaf_id_from;
  161. using Base::for_each_node_id;
  162. using Base::for_each_root_id;
  163. using Base::for_each_root_id_from;
  164. using Base::depth_first_search;
  165. using Base::add_node;
  166. bool add_edge(node_id from, node_id to, const edge_t& data);
  167. bool add_edge(node_id from, node_id to, edge_t&& data);
  168. bool remove_edge(node_id from, node_id to);
  169. bool remove_node(node_id id);
  170. const edge_map& edges() const;
  171. const edge_t& edge(node_id from, node_id to) const;
  172. const edge_t& edge(edge_id edge) const;
  173. edge_t& edge(node_id from, node_id to);
  174. edge_t& edge(edge_id edge);
  175. private:
  176. friend class DiGraphBase<Node, DiGraph<Node, Edge>>; // required for private CRTP
  177. template <typename T>
  178. bool add_edge_impl(node_id from, node_id to, T&& data);
  179. edge_map m_edges;
  180. };
  181. template <typename Node>
  182. class DiGraph<Node, void> : public DiGraphBase<Node, DiGraph<Node, void>>
  183. {
  184. };
  185. /********************************
  186. * DiGraphBase Implementation *
  187. ********************************/
  188. template <typename N, typename G>
  189. bool DiGraphBase<N, G>::empty() const
  190. {
  191. return number_of_nodes() == 0;
  192. }
  193. template <typename N, typename G>
  194. auto DiGraphBase<N, G>::number_of_nodes() const noexcept -> std::size_t
  195. {
  196. return m_node_map.size();
  197. }
  198. template <typename N, typename G>
  199. auto DiGraphBase<N, G>::number_of_edges() const noexcept -> std::size_t
  200. {
  201. return m_number_of_edges;
  202. }
  203. template <typename N, typename G>
  204. auto DiGraphBase<N, G>::in_degree(node_id id) const noexcept -> std::size_t
  205. {
  206. return m_predecessors[id].size();
  207. }
  208. template <typename N, typename G>
  209. auto DiGraphBase<N, G>::out_degree(node_id id) const noexcept -> std::size_t
  210. {
  211. return m_successors[id].size();
  212. }
  213. template <typename N, typename G>
  214. auto DiGraphBase<N, G>::nodes() const -> const node_map&
  215. {
  216. return m_node_map;
  217. }
  218. template <typename N, typename G>
  219. auto DiGraphBase<N, G>::node(node_id id) const -> const node_t&
  220. {
  221. return m_node_map.at(id);
  222. }
  223. template <typename N, typename G>
  224. auto DiGraphBase<N, G>::node(node_id id) -> node_t&
  225. {
  226. return m_node_map.at(id);
  227. }
  228. template <typename N, typename G>
  229. auto DiGraphBase<N, G>::successors(node_id id) const -> const node_id_list&
  230. {
  231. return m_successors[id];
  232. }
  233. template <typename N, typename G>
  234. auto DiGraphBase<N, G>::successors() const -> const adjacency_list&
  235. {
  236. return m_successors;
  237. }
  238. template <typename N, typename G>
  239. auto DiGraphBase<N, G>::predecessors(node_id id) const -> const node_id_list&
  240. {
  241. return m_predecessors[id];
  242. }
  243. template <typename N, typename G>
  244. auto DiGraphBase<N, G>::predecessors() const -> const adjacency_list&
  245. {
  246. return m_predecessors;
  247. }
  248. template <typename N, typename G>
  249. auto DiGraphBase<N, G>::has_node(node_id id) const -> bool
  250. {
  251. return nodes().count(id) > 0;
  252. }
  253. template <typename N, typename G>
  254. auto DiGraphBase<N, G>::has_edge(node_id from, node_id to) const -> bool
  255. {
  256. return has_node(from) && successors(from).contains(to);
  257. }
  258. template <typename N, typename G>
  259. auto DiGraphBase<N, G>::add_node(const node_t& value) -> node_id
  260. {
  261. return add_node_impl(value);
  262. }
  263. template <typename N, typename G>
  264. auto DiGraphBase<N, G>::add_node(node_t&& value) -> node_id
  265. {
  266. return add_node_impl(std::move(value));
  267. }
  268. template <typename N, typename G>
  269. template <class V>
  270. auto DiGraphBase<N, G>::add_node_impl(V&& value) -> node_id
  271. {
  272. const node_id id = number_of_node_id();
  273. m_node_map.emplace(id, std::forward<V>(value));
  274. m_successors.push_back(node_id_list());
  275. m_predecessors.push_back(node_id_list());
  276. return id;
  277. }
  278. template <typename N, typename G>
  279. bool DiGraphBase<N, G>::remove_node(node_id id)
  280. {
  281. if (!has_node(id))
  282. {
  283. return false;
  284. }
  285. const auto succs = successors(id); // Cannot iterate on object being modified
  286. for (const auto& to : succs)
  287. {
  288. remove_edge(id, to);
  289. }
  290. const auto preds = predecessors(id); // Cannot iterate on object being modified
  291. for (const auto& from : preds)
  292. {
  293. remove_edge(from, id);
  294. }
  295. m_node_map.erase(id);
  296. return true;
  297. }
  298. template <typename N, typename G>
  299. bool DiGraphBase<N, G>::add_edge(node_id from, node_id to)
  300. {
  301. if (has_edge(from, to))
  302. {
  303. return false;
  304. }
  305. m_successors[from].insert(to);
  306. m_predecessors[to].insert(from);
  307. ++m_number_of_edges;
  308. return true;
  309. }
  310. template <typename N, typename G>
  311. bool DiGraphBase<N, G>::remove_edge(node_id from, node_id to)
  312. {
  313. if (!has_edge(from, to))
  314. {
  315. return false;
  316. }
  317. m_successors[from].erase(to);
  318. m_predecessors[to].erase(from);
  319. --m_number_of_edges;
  320. return true;
  321. }
  322. template <typename N, typename G>
  323. template <typename UnaryFunc>
  324. UnaryFunc DiGraphBase<N, G>::for_each_node_id(UnaryFunc func) const
  325. {
  326. for (const auto& [i, _] : m_node_map)
  327. {
  328. func(i);
  329. }
  330. return func;
  331. }
  332. template <typename N, typename G>
  333. template <typename BinaryFunc>
  334. BinaryFunc DiGraphBase<N, G>::for_each_edge_id(BinaryFunc func) const
  335. {
  336. for_each_node_id(
  337. [&](node_id i)
  338. {
  339. for (node_id j : successors(i))
  340. {
  341. func(i, j);
  342. }
  343. }
  344. );
  345. return func;
  346. }
  347. template <typename N, typename G>
  348. template <typename UnaryFunc>
  349. UnaryFunc DiGraphBase<N, G>::for_each_leaf_id(UnaryFunc func) const
  350. {
  351. for_each_node_id(
  352. [&](node_id i)
  353. {
  354. if (out_degree(i) == 0)
  355. {
  356. func(i);
  357. }
  358. }
  359. );
  360. return func;
  361. }
  362. template <typename N, typename G>
  363. template <typename UnaryFunc>
  364. UnaryFunc DiGraphBase<N, G>::for_each_root_id(UnaryFunc func) const
  365. {
  366. for_each_node_id(
  367. [&](node_id i)
  368. {
  369. if (in_degree(i) == 0)
  370. {
  371. func(i);
  372. }
  373. }
  374. );
  375. return func;
  376. }
  377. template <typename N, typename G>
  378. template <typename UnaryFunc>
  379. UnaryFunc DiGraphBase<N, G>::for_each_leaf_id_from(node_id source, UnaryFunc func) const
  380. {
  381. struct LeafVisitor : default_visitor<derived_t>
  382. {
  383. UnaryFunc& m_func;
  384. LeafVisitor(UnaryFunc& func)
  385. : m_func{ func }
  386. {
  387. }
  388. void start_node(node_id n, const derived_t& g)
  389. {
  390. if (g.out_degree(n) == 0)
  391. {
  392. m_func(n);
  393. }
  394. }
  395. };
  396. auto visitor = LeafVisitor(func);
  397. depth_first_search(visitor, source);
  398. return func;
  399. }
  400. template <typename N, typename G>
  401. template <typename UnaryFunc>
  402. UnaryFunc DiGraphBase<N, G>::for_each_root_id_from(node_id source, UnaryFunc func) const
  403. {
  404. struct RootVisitor : default_visitor<derived_t>
  405. {
  406. UnaryFunc& m_func;
  407. RootVisitor(UnaryFunc& func)
  408. : m_func{ func }
  409. {
  410. }
  411. void start_node(node_id n, const derived_t& g)
  412. {
  413. if (g.in_degree(n) == 0)
  414. {
  415. m_func(n);
  416. }
  417. }
  418. };
  419. auto visitor = RootVisitor(func);
  420. depth_first_search(visitor, source, /*reverse*/ true);
  421. return func;
  422. }
  423. template <typename N, typename G>
  424. template <class V>
  425. void DiGraphBase<N, G>::depth_first_search(V& visitor, node_id node, bool reverse) const
  426. {
  427. if (!empty())
  428. {
  429. visited_list status(number_of_node_id(), visited::no);
  430. depth_first_search_impl(visitor, node, status, reverse ? m_predecessors : m_successors);
  431. }
  432. }
  433. template <typename N, typename G>
  434. template <class V>
  435. void DiGraphBase<N, G>::depth_first_search_impl(
  436. V& visitor,
  437. node_id node,
  438. visited_list& status,
  439. const adjacency_list& successors
  440. ) const
  441. {
  442. status[node] = visited::ongoing;
  443. visitor.start_node(node, derived_cast());
  444. for (auto child : successors[node])
  445. {
  446. visitor.start_edge(node, child, derived_cast());
  447. if (status[child] == visited::no)
  448. {
  449. visitor.tree_edge(node, child, derived_cast());
  450. depth_first_search_impl(visitor, child, status, successors);
  451. }
  452. else if (status[child] == visited::ongoing)
  453. {
  454. visitor.back_edge(node, child, derived_cast());
  455. }
  456. else
  457. {
  458. visitor.forward_or_cross_edge(node, child, derived_cast());
  459. }
  460. visitor.finish_edge(node, child, derived_cast());
  461. }
  462. status[node] = visited::yes;
  463. visitor.finish_node(node, derived_cast());
  464. }
  465. template <typename N, typename G>
  466. auto DiGraphBase<N, G>::number_of_node_id() const noexcept -> node_id
  467. {
  468. // Not number_of_nodes because due to remove nodes it may be larger
  469. return m_successors.size();
  470. }
  471. template <typename N, typename G>
  472. auto DiGraphBase<N, G>::derived_cast() -> derived_t&
  473. {
  474. return static_cast<derived_t&>(*this);
  475. }
  476. template <typename N, typename G>
  477. auto DiGraphBase<N, G>::derived_cast() const -> const derived_t&
  478. {
  479. return static_cast<const derived_t&>(*this);
  480. }
  481. /*******************************
  482. * Algorithms implementation *
  483. *******************************/
  484. template <typename Graph>
  485. auto
  486. is_reachable(const Graph& graph, typename Graph::node_id source, typename Graph::node_id target)
  487. -> bool
  488. {
  489. struct : default_visitor<Graph>
  490. {
  491. using node_id = typename Graph::node_id;
  492. node_id target;
  493. bool target_visited = false;
  494. void start_node(node_id node, const Graph&)
  495. {
  496. target_visited = target_visited || (node == target);
  497. }
  498. } visitor{ {}, target };
  499. graph.depth_first_search(visitor, source);
  500. return visitor.target_visited;
  501. }
  502. /*********************************
  503. * DiGraph Edge Implementation *
  504. *********************************/
  505. template <typename N, typename E>
  506. bool DiGraph<N, E>::add_edge(node_id from, node_id to, const edge_t& data)
  507. {
  508. return add_edge_impl(from, to, data);
  509. }
  510. template <typename N, typename E>
  511. bool DiGraph<N, E>::add_edge(node_id from, node_id to, edge_t&& data)
  512. {
  513. return add_edge_impl(from, to, std::move(data));
  514. }
  515. template <typename N, typename E>
  516. template <typename T>
  517. bool DiGraph<N, E>::add_edge_impl(node_id from, node_id to, T&& data)
  518. {
  519. if (const bool added = Base::add_edge(from, to); added)
  520. {
  521. auto l_edge_id = std::pair(from, to);
  522. m_edges.insert(std::pair(l_edge_id, std::forward<T>(data)));
  523. return true;
  524. }
  525. return false;
  526. }
  527. template <typename N, typename E>
  528. bool DiGraph<N, E>::remove_edge(node_id from, node_id to)
  529. {
  530. m_edges.erase({ from, to }); // No-op if edge does not exists
  531. return Base::remove_edge(from, to);
  532. }
  533. template <typename N, typename E>
  534. bool DiGraph<N, E>::remove_node(node_id id)
  535. {
  536. // No-op if edge does not exists
  537. for (const auto& to : successors(id))
  538. {
  539. m_edges.erase({ id, to });
  540. }
  541. for (const auto& from : predecessors(id))
  542. {
  543. m_edges.erase({ from, id });
  544. }
  545. return Base::remove_node(id);
  546. }
  547. template <typename N, typename E>
  548. auto DiGraph<N, E>::edges() const -> const edge_map&
  549. {
  550. return m_edges;
  551. }
  552. template <typename N, typename E>
  553. auto DiGraph<N, E>::edge(edge_id edge) const -> const edge_t&
  554. {
  555. return m_edges.at(edge);
  556. }
  557. template <typename N, typename E>
  558. auto DiGraph<N, E>::edge(node_id from, node_id to) const -> const edge_t&
  559. {
  560. return edge({ from, to });
  561. }
  562. template <typename N, typename E>
  563. auto DiGraph<N, E>::edge(edge_id edge) -> edge_t&
  564. {
  565. return m_edges[edge];
  566. }
  567. template <typename N, typename E>
  568. auto DiGraph<N, E>::edge(node_id from, node_id to) -> edge_t&
  569. {
  570. return edge({ from, to });
  571. }
  572. }
  573. #endif