Skip to content

Commit ab11dc6

Browse files
authored
Merge branch 'main' into ruff-format
2 parents 8f086f7 + 745e87b commit ab11dc6

File tree

6 files changed

+103
-28
lines changed

6 files changed

+103
-28
lines changed

rustworkx-core/src/traversal/dfs_edges.rs

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
use std::hash::Hash;
1414

15-
use hashbrown::{HashMap, HashSet};
15+
use hashbrown::HashSet;
1616

1717
use petgraph::visit::{
1818
EdgeCount, IntoNeighbors, IntoNodeIdentifiers, NodeCount, NodeIndexable, Visitable,
@@ -75,7 +75,10 @@ where
7575
let mut out_vec: Vec<(usize, usize)> = if source.is_some() {
7676
Vec::new()
7777
} else {
78-
Vec::with_capacity(core::cmp::min(graph.node_count() - 1, graph.edge_count()))
78+
Vec::with_capacity(core::cmp::min(
79+
graph.node_count().saturating_sub(1),
80+
graph.edge_count(),
81+
))
7982
};
8083
for start in nodes {
8184
if visited.contains(&start) {
@@ -84,32 +87,25 @@ where
8487
visited.insert(start);
8588
let mut children: Vec<G::NodeId> = graph.neighbors(start).collect();
8689
children.reverse();
87-
let mut stack: Vec<(G::NodeId, Vec<G::NodeId>)> = vec![(start, children)];
88-
// Used to track the last position in children vec across iterations
89-
let mut index_map: HashMap<G::NodeId, usize> = HashMap::with_capacity(node_count);
90-
index_map.insert(start, 0);
91-
while !stack.is_empty() {
92-
let temp_parent = stack.last().unwrap();
93-
let parent = temp_parent.0;
94-
let children = temp_parent.1.clone();
95-
let count = *index_map.get(&parent).unwrap();
90+
// Stack stores (node, children, next_child_index)
91+
let mut stack: Vec<(G::NodeId, Vec<G::NodeId>, usize)> = vec![(start, children, 0)];
92+
while let Some((parent, children, idx)) = stack.last_mut() {
9693
let mut found = false;
97-
let mut index = count;
98-
for child in &children[index..] {
99-
index += 1;
100-
if !visited.contains(child) {
101-
out_vec.push((graph.to_index(parent), graph.to_index(*child)));
102-
visited.insert(*child);
103-
let mut grandchildren: Vec<G::NodeId> = graph.neighbors(*child).collect();
94+
while *idx < children.len() {
95+
let child = children[*idx];
96+
*idx += 1;
97+
if !visited.contains(&child) {
98+
let parent_id = *parent;
99+
out_vec.push((graph.to_index(parent_id), graph.to_index(child)));
100+
visited.insert(child);
101+
let mut grandchildren: Vec<G::NodeId> = graph.neighbors(child).collect();
104102
grandchildren.reverse();
105-
stack.push((*child, grandchildren));
106-
index_map.insert(*child, 0);
107-
*index_map.get_mut(&parent).unwrap() = index;
103+
stack.push((child, grandchildren, 0));
108104
found = true;
109105
break;
110106
}
111107
}
112-
if !found || children.is_empty() {
108+
if !found {
113109
stack.pop();
114110
}
115111
}

src/connectivity/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,10 +618,10 @@ pub fn is_semi_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
618618
}
619619

620620
let mut temp_graph = DiGraph::new();
621-
let mut node_map = Vec::new();
621+
let mut node_map = vec![NodeIndex::end(); graph.graph.node_bound()];
622622

623-
for _node in graph.graph.node_indices() {
624-
node_map.push(temp_graph.add_node(()));
623+
for node in graph.graph.node_indices() {
624+
node_map[node.index()] = temp_graph.add_node(());
625625
}
626626

627627
for edge in graph.graph.edge_indices() {

tests/digraph/test_dfs_edges.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,37 @@ def test_digraph_dfs_edges(self):
2929
edges = rustworkx.digraph_dfs_edges(graph, 0)
3030
expected = [(0, 1), (1, 2), (2, 4), (1, 3)]
3131
self.assertEqual(expected, edges)
32+
33+
def test_digraph_dfs_edges_empty(self):
34+
graph = rustworkx.PyDiGraph()
35+
edges = rustworkx.digraph_dfs_edges(graph)
36+
self.assertEqual([], edges)
37+
38+
def test_digraph_dfs_edges_single_node(self):
39+
graph = rustworkx.generators.directed_empty_graph(1)
40+
edges = rustworkx.digraph_dfs_edges(graph, 0)
41+
self.assertEqual([], edges)
42+
43+
def test_digraph_dfs_edges_node_gaps(self):
44+
graph = rustworkx.PyDiGraph()
45+
graph.add_nodes_from(range(5))
46+
graph.add_edge(0, 2, None)
47+
graph.add_edge(2, 4, None)
48+
graph.remove_node(1)
49+
graph.remove_node(3)
50+
edges = rustworkx.digraph_dfs_edges(graph, 0)
51+
self.assertEqual([(0, 2), (2, 4)], edges)
52+
53+
def test_digraph_dfs_edges_star(self):
54+
graph = rustworkx.generators.directed_star_graph(101)
55+
hub = 0
56+
spokes = list(range(1, 101))
57+
edges = rustworkx.digraph_dfs_edges(graph, hub)
58+
# Should visit all spokes exactly once
59+
self.assertEqual(len(edges), 100)
60+
# All edges should originate from hub
61+
for src, _ in edges:
62+
self.assertEqual(src, hub)
63+
# All spokes should be visited
64+
visited = {tgt for _, tgt in edges}
65+
self.assertEqual(visited, set(spokes))

tests/digraph/test_is_semi_connected.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,14 @@ def test_is_semi_connected_directed_star_graph(self):
7474
def test_is_semi_connected_directed_grid_graph(self):
7575
graph = rustworkx.generators.directed_grid_graph(10, 10)
7676
self.assertEqual(rustworkx.is_semi_connected(graph), naive_semi_connected(graph))
77+
78+
def test_is_semi_connected_with_node_gaps(self):
79+
graph = rustworkx.PyDiGraph()
80+
graph.add_nodes_from(list(range(5)))
81+
graph.remove_node(1)
82+
# Remaining nodes 0, 2, 3, 4 form a path
83+
graph.add_edge(0, 2, None)
84+
graph.add_edge(2, 3, None)
85+
graph.add_edge(3, 4, None)
86+
87+
self.assertEqual(rustworkx.is_semi_connected(graph), naive_semi_connected(graph))

tests/graph/test_dfs_edges.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,37 @@ def test_graph_disconnected_dfs_edges(self):
2929
edges = rustworkx.graph_dfs_edges(graph)
3030
expected = [(0, 1), (2, 3)]
3131
self.assertEqual(expected, edges)
32+
33+
def test_graph_dfs_edges_empty(self):
34+
graph = rustworkx.PyGraph()
35+
edges = rustworkx.graph_dfs_edges(graph)
36+
self.assertEqual([], edges)
37+
38+
def test_graph_dfs_edges_single_node(self):
39+
graph = rustworkx.generators.empty_graph(1)
40+
edges = rustworkx.graph_dfs_edges(graph, 0)
41+
self.assertEqual([], edges)
42+
43+
def test_graph_dfs_edges_node_gaps(self):
44+
graph = rustworkx.PyGraph()
45+
graph.add_nodes_from(range(5))
46+
graph.add_edge(0, 2, None)
47+
graph.add_edge(2, 4, None)
48+
graph.remove_node(1)
49+
graph.remove_node(3)
50+
edges = rustworkx.graph_dfs_edges(graph, 0)
51+
self.assertEqual([(0, 2), (2, 4)], edges)
52+
53+
def test_graph_dfs_edges_star(self):
54+
graph = rustworkx.generators.star_graph(101)
55+
hub = 0
56+
spokes = list(range(1, 101))
57+
edges = rustworkx.graph_dfs_edges(graph, hub)
58+
# Should visit all spokes exactly once
59+
self.assertEqual(len(edges), 100)
60+
# All edges should originate from hub
61+
for src, tgt in edges:
62+
self.assertEqual(src, hub)
63+
# All spokes should be visited
64+
visited = {tgt for _, tgt in edges}
65+
self.assertEqual(visited, set(spokes))

uv.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)