diff --git a/docs/source/api/algorithm_functions/dag_algorithms.rst b/docs/source/api/algorithm_functions/dag_algorithms.rst index e3830ac90a..da2277380c 100644 --- a/docs/source/api/algorithm_functions/dag_algorithms.rst +++ b/docs/source/api/algorithm_functions/dag_algorithms.rst @@ -14,3 +14,4 @@ DAG Algorithms rustworkx.layers rustworkx.transitive_reduction rustworkx.topological_generations + rustworkx.transitive_closure_dag diff --git a/docs/source/api/algorithm_functions/traversal.rst b/docs/source/api/algorithm_functions/traversal.rst index c3a210c5e2..c1e4d107fa 100644 --- a/docs/source/api/algorithm_functions/traversal.rst +++ b/docs/source/api/algorithm_functions/traversal.rst @@ -22,3 +22,4 @@ Traversal rustworkx.visit.BFSVisitor rustworkx.visit.DijkstraVisitor rustworkx.TopologicalSorter + rustworkx.descendants_at_distance diff --git a/docs/source/api/pydigraph_api_functions.rst b/docs/source/api/pydigraph_api_functions.rst index 78153e9a31..cf20ad6d77 100644 --- a/docs/source/api/pydigraph_api_functions.rst +++ b/docs/source/api/pydigraph_api_functions.rst @@ -58,3 +58,4 @@ the functions from the explicitly typed based on the data type. rustworkx.digraph_dijkstra_search rustworkx.digraph_node_link_json rustworkx.digraph_longest_simple_path + rustworkx.graph_descendants_at_distance diff --git a/docs/source/api/pygraph_api_functions.rst b/docs/source/api/pygraph_api_functions.rst index 44bfdf2835..f7c0a328c1 100644 --- a/docs/source/api/pygraph_api_functions.rst +++ b/docs/source/api/pygraph_api_functions.rst @@ -58,3 +58,4 @@ typed API based on the data type. rustworkx.graph_dijkstra_search rustworkx.graph_node_link_json rustworkx.graph_longest_simple_path + rustworkx.digraph_descendants_at_distance diff --git a/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml b/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml new file mode 100644 index 0000000000..230527e24a --- /dev/null +++ b/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml @@ -0,0 +1,16 @@ +--- +features: + - | + Added a new function ``descendants_at_distance`` to the rustworkx-core + crate under the ``traversal`` module + - | + Added a new function ``build_transitive_closure_dag`` to the rustworkx-core + crate under the ``traversal`` module. + - | + Added a new function, :func:`~.transitive_closure_dag`, which provides + an optimize method for computing the transitive closure of an input + DAG. + - | + Added a new function :func:`~.descendants_at_distance` which provides + a method to find the nodes at a fixed distance from a source in + a graph object. diff --git a/rustworkx-core/src/traversal/descendants.rs b/rustworkx-core/src/traversal/descendants.rs new file mode 100644 index 0000000000..67bf6bf606 --- /dev/null +++ b/rustworkx-core/src/traversal/descendants.rs @@ -0,0 +1,44 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use hashbrown::HashSet; +use petgraph::visit::{IntoNeighborsDirected, NodeCount, Visitable}; + +/// Returns all nodes at a fixed `distance` from `source` in `G`. +/// Args: +/// `graph`: +/// `source`: +/// `distance`: +pub fn descendants_at_distance(graph: G, source: G::NodeId, distance: usize) -> Vec +where + G: Visitable + IntoNeighborsDirected + NodeCount, + G::NodeId: std::cmp::Eq + std::hash::Hash, +{ + let mut current_layer: Vec = vec![source]; + let mut layers: usize = 0; + let mut visited: HashSet = HashSet::with_capacity(graph.node_count()); + visited.insert(source); + while !current_layer.is_empty() && layers < distance { + let mut next_layer: Vec = Vec::new(); + for node in current_layer { + for child in graph.neighbors_directed(node, petgraph::Outgoing) { + if !visited.contains(&child) { + visited.insert(child); + next_layer.push(child); + } + } + } + current_layer = next_layer; + layers += 1; + } + current_layer +} diff --git a/rustworkx-core/src/traversal/mod.rs b/rustworkx-core/src/traversal/mod.rs index c8213dfb36..2a7790bb1d 100644 --- a/rustworkx-core/src/traversal/mod.rs +++ b/rustworkx-core/src/traversal/mod.rs @@ -13,9 +13,11 @@ //! Module for graph traversal algorithms. mod bfs_visit; +mod descendants; mod dfs_edges; mod dfs_visit; mod dijkstra_visit; +mod transitive_closure; use petgraph::prelude::*; use petgraph::visit::GraphRef; @@ -25,9 +27,11 @@ use petgraph::visit::VisitMap; use petgraph::visit::Visitable; pub use bfs_visit::{breadth_first_search, BfsEvent}; +pub use descendants::descendants_at_distance; pub use dfs_edges::dfs_edges; pub use dfs_visit::{depth_first_search, DfsEvent}; pub use dijkstra_visit::{dijkstra_search, DijkstraEvent}; +pub use transitive_closure::build_transitive_closure_dag; /// Return if the expression is a break value, execute the provided statement /// if it is a prune value. diff --git a/rustworkx-core/src/traversal/transitive_closure.rs b/rustworkx-core/src/traversal/transitive_closure.rs new file mode 100644 index 0000000000..6d6fd196f2 --- /dev/null +++ b/rustworkx-core/src/traversal/transitive_closure.rs @@ -0,0 +1,83 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use petgraph::algo::{toposort, Cycle}; +use petgraph::data::Build; +use petgraph::visit::{ + GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable, +}; + +use crate::traversal::descendants_at_distance; + +/// Build a transitive closure out of a given DAG +/// +/// This function will mutate a given DAG object (which is typically moved to +/// this function) into a transitive closure of the graph and then returned. +/// If you'd like to preserve the input graph pass a clone of the original graph. +/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)` +/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge +/// :math:`(v, w) in :math:`E+` if and only if there is a non-null path from +/// :math:`v` to :math:`w` in :math:`G`. This funciton provides an optimized +/// path for computing the the transitive closure of a DAG, if the input graph +/// contains cycles it will error. +/// +/// Arguments: +/// +/// - `graph`: A mutable graph object representing the DAG +/// - `topological_order`: An optional `Vec` of node identifiers representing +/// the topological order to traverse the DAG with. If not specified the +/// `petgraph::algo::toposort` function will be called to generate this +/// - `default_edge_weight`: A callable function that takes no arguments and +/// returns the `EdgeWeight` type object to use for each edge added to +/// `graph +/// +/// # Example +/// +/// ```rust +/// use rustworkx_core::traversal::build_transitive_closure_dag; +/// +/// let g = petgraph::graph::DiGraph::::from_edges(&[(0, 1, 0), (1, 2, 0), (2, 3, 0)]); +/// +/// let res = build_transitive_closure_dag(g, None, || -> i32 {0}); +/// let out_graph = res.unwrap(); +/// let out_edges: Vec<(usize, usize)> = out_graph +/// .edge_indices() +/// .map(|e| { +/// let endpoints = out_graph.edge_endpoints(e).unwrap(); +/// (endpoints.0.index(), endpoints.1.index()) +/// }) +/// .collect(); +/// assert_eq!(vec![(0, 1), (1, 2), (2, 3), (1, 3), (0, 3), (0, 2)], out_edges) +/// ``` +pub fn build_transitive_closure_dag<'a, G, F>( + mut graph: G, + topological_order: Option>, + default_edge_weight: F, +) -> Result> +where + G: NodeCount + Build + Clone, + for<'b> &'b G: + GraphBase + Visitable + IntoNeighborsDirected + IntoNodeIdentifiers, + G::NodeId: std::cmp::Eq + std::hash::Hash, + F: Fn() -> G::EdgeWeight, +{ + let node_order: Vec = match topological_order { + Some(topo_order) => topo_order, + None => toposort(&graph, None)?, + }; + for node in node_order.into_iter().rev() { + for descendant in descendants_at_distance(&graph, node, 2) { + graph.add_edge(node, descendant, default_edge_weight()); + } + } + Ok(graph) +} diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 2943017fcc..be8e99cace 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -1875,6 +1875,30 @@ def longest_simple_path(graph): """ +@_rustworkx_dispatch +def descendants_at_distance(graph, source, distance): + """Returns all nodes at a fixed distance from ``source`` in ``graph`` + + :param graph: The graph to find the descendants in + :param int source: The node index to find the descendants from + :param int distance: The distance from ``source`` + + :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. + :rtype: NodeIndices + + For example:: + + import rustworkx as rx + + graph = rx.generators.path_graph(5) + res = rx.descendants_at_distance(graph, 2, 2) + print(res) + + will return: ``[0, 4]`` + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + @_rustworkx_dispatch def isolates(graph): """Return a list of isolates in a graph object @@ -2006,3 +2030,4 @@ def all_shortest_paths( """ raise TypeError("Invalid Input Type %s for graph" % type(graph)) +>>>>>>> origin/main diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index 65df6afcc7..ed279f78ac 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -92,6 +92,8 @@ pub fn traversal_directions(reverse: bool) -> (petgraph::Direction, petgraph::Di } } +use rustworkx_core::traversal; + /// Find the longest path in a DAG /// /// :param PyDiGraph graph: The graph to find the longest path on. The input @@ -597,6 +599,48 @@ pub fn collect_bicolor_runs( Ok(block_list) } +/// Return the transitive closure of a directed acyclic graph +/// +/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)` +/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge +/// :math:`(v, w) in :math:`E+` if and only if there is a non-null path from +/// :math:`v` to :math:`w` in :math:`G`. +/// +/// :param PyDiGraph graph: The input DAG to compute the transitive closure of +/// :param list topological_order: An optional topological order for ``graph`` +/// which represents the order the graph will be traversed in computing +/// the transitive closure. If one is not provided (or it is explicitly +/// set to ``None``) a topological order will be computed by this function. +/// +/// :returns: The transitive closure of ``graph`` +/// :rtype: PyDiGraph +/// +/// :raises DAGHasCycle: If the input ``graph`` is not acyclic +#[pyfunction] +#[pyo3(text_signature = "(graph, / topological_order=None)")] +pub fn transitive_closure_dag( + py: Python, + graph: &digraph::PyDiGraph, + topological_order: Option>, +) -> PyResult { + let default_weight = || -> PyObject { py.None() }; + match traversal::build_transitive_closure_dag( + graph.graph.clone(), + topological_order.map(|order| order.into_iter().map(NodeIndex::new).collect()), + default_weight, + ) { + Ok(out_graph) => Ok(digraph::PyDiGraph { + graph: out_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: false, + multigraph: true, + attrs: py.None(), + }), + Err(_err) => Err(DAGHasCycle::new_err("Topological Sort encountered a cycle")), + } +} + /// Returns the transitive reduction of a directed acyclic graph /// /// The transitive reduction of :math:`G = (V,E)` is a graph :math:`G\prime = (V,E\prime)` @@ -612,7 +656,6 @@ pub fn collect_bicolor_runs( /// :rtype: Tuple[PyGraph, dict] /// /// :raises PyValueError: if ``graph`` is not a DAG - #[pyfunction] #[pyo3(text_signature = "(graph, /)")] pub fn transitive_reduction( diff --git a/src/lib.rs b/src/lib.rs index 79f183462f..fe7c29a960 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -614,6 +614,9 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(read_graphml))?; m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; + m.add_wrapped(wrap_pyfunction!(transitive_closure_dag))?; + m.add_wrapped(wrap_pyfunction!(graph_descendants_at_distance))?; + m.add_wrapped(wrap_pyfunction!(digraph_descendants_at_distance))?; m.add_wrapped(wrap_pyfunction!(from_node_link_json_file))?; m.add_wrapped(wrap_pyfunction!(parse_node_link_json))?; m.add_wrapped(wrap_pyfunction!(pagerank))?; diff --git a/src/traversal/mod.rs b/src/traversal/mod.rs index f6ce66a767..7900de054d 100644 --- a/src/traversal/mod.rs +++ b/src/traversal/mod.rs @@ -21,7 +21,7 @@ use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor}; use rustworkx_core::traversal::{ ancestors as core_ancestors, bfs_predecessors as core_bfs_predecessors, bfs_successors as core_bfs_successors, breadth_first_search, depth_first_search, - descendants as core_descendants, dfs_edges, dijkstra_search, + descendants as core_descendants, descendants_at_distance, dfs_edges, dijkstra_search, }; use super::{digraph, graph, iterators, CostFn}; @@ -773,3 +773,67 @@ pub fn graph_dijkstra_search( Ok(()) } + +/// Returns all nodes at a fixed distance from ``source`` in ``graph`` +/// +/// :param PyGraph graph: The graph to find the descendants in +/// :param int source: The node index to find the descendants from +/// :param int distance: The distance from ``source`` +/// +/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. +/// :rtype: NodeIndices +/// For example:: +/// +/// import rustworkx as rx +/// +/// graph = rx.generators.path_graph(5) +/// res = rx.descendants_at_distance(graph, 2, 2) +/// print(res) +/// +/// will return: ``[0, 4]`` +#[pyfunction] +pub fn graph_descendants_at_distance( + graph: graph::PyGraph, + source: usize, + distance: usize, +) -> iterators::NodeIndices { + let source = NodeIndex::new(source); + iterators::NodeIndices { + nodes: descendants_at_distance(&graph.graph, source, distance) + .into_iter() + .map(|x| x.index()) + .collect(), + } +} + +/// Returns all nodes at a fixed distance from ``source`` in ``graph`` +/// +/// :param PyDiGraph graph: The graph to find the descendants in +/// :param int source: The node index to find the descendants from +/// :param int distance: The distance from ``source`` +/// +/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. +/// :rtype: NodeIndices +/// For example:: +/// +/// import rustworkx as rx +/// +/// graph = rx.generators.directed_path_graph(5) +/// res = rx.descendants_at_distance(graph, 2, 2) +/// print(res) +/// +/// will return: ``[4]`` +#[pyfunction] +pub fn digraph_descendants_at_distance( + graph: digraph::PyDiGraph, + source: usize, + distance: usize, +) -> iterators::NodeIndices { + let source = NodeIndex::new(source); + iterators::NodeIndices { + nodes: descendants_at_distance(&graph.graph, source, distance) + .into_iter() + .map(|x| x.index()) + .collect(), + } +} diff --git a/tests/digraph/test_transitive_closure.py b/tests/digraph/test_transitive_closure.py new file mode 100644 index 0000000000..cf707bf555 --- /dev/null +++ b/tests/digraph/test_transitive_closure.py @@ -0,0 +1,33 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import rustworkx as rx + + +class TestTransitivity(unittest.TestCase): + def test_path_graph(self): + graph = rx.generators.directed_path_graph(4) + transitive_closure = rx.transitive_closure_dag(graph) + expected_edge_list = [(0, 1), (1, 2), (2, 3), (1, 3), (0, 3), (0, 2)] + self.assertEqual(transitive_closure.edge_list(), expected_edge_list) + + def test_invalid_type(self): + with self.assertRaises(TypeError): + rx.transitive_closure_dag(rx.PyGraph()) + + def test_cycle_error(self): + graph = rx.PyDiGraph() + graph.extend_from_edge_list([(0, 1), (1, 0)]) + with self.assertRaises(rx.DAGHasCycle): + rx.transitive_closure_dag(graph)