@@ -3,6 +3,7 @@ use petgraph::graph::{EdgeIndex, Graph as PetGraph, NodeIndex};
3
3
use petgraph:: prelude:: EdgeRef ;
4
4
use petgraph:: visit:: IntoNodeReferences ;
5
5
use petgraph:: { Directed , Undirected } ;
6
+ use sprs:: { CsMat , TriMat } ;
6
7
7
8
/// Trait for constructing graphs with specific edge types.
8
9
pub trait GraphConstructor < A , W > : petgraph:: EdgeType + Sized {
@@ -67,7 +68,7 @@ impl EdgeId {
67
68
}
68
69
}
69
70
70
- /// Base graph structure that wraps around `PetGraph` .
71
+ /// Base graph structure that wraps around a petgraph instance .
71
72
#[ derive( Debug , Clone ) ]
72
73
pub struct BaseGraph < A , W , Ty : GraphConstructor < A , W > > {
73
74
inner : PetGraph < A , W , Ty > ,
@@ -117,6 +118,7 @@ impl<A, W, Ty: GraphConstructor<A, W>> BaseGraph<A, W, Ty> {
117
118
self . inner . neighbors ( index) . map ( NodeId :: new)
118
119
}
119
120
121
+ /// Returns a reference to the attribute of a node.
120
122
pub fn node_attr ( & self , node : NodeId ) -> Option < & A > {
121
123
let index = NodeIndex :: new ( node. index ( ) ) ;
122
124
self . inner . node_weight ( index)
@@ -146,17 +148,124 @@ impl<A, W, Ty: GraphConstructor<A, W>> BaseGraph<A, W, Ty> {
146
148
} )
147
149
}
148
150
149
- /// Returns a reference to the inner `PetGraph`.
150
- pub fn inner ( & self ) -> & PetGraph < A , W , Ty > {
151
+ /// Returns a reference to the inner petgraph instance.
152
+ /// (Not exposed to the user as part of the public API.)
153
+ fn inner ( & self ) -> & PetGraph < A , W , Ty > {
151
154
& self . inner
152
155
}
153
156
154
- /// Returns a mutable reference to the inner `PetGraph`.
155
- pub fn inner_mut ( & mut self ) -> & mut PetGraph < A , W , Ty > {
157
+ /// Returns a mutable reference to the inner petgraph instance.
158
+ /// (Not exposed to the user as part of the public API.)
159
+ fn inner_mut ( & mut self ) -> & mut PetGraph < A , W , Ty > {
156
160
& mut self . inner
157
161
}
158
162
}
159
163
164
+ /// Dense matrix API using owned values.
165
+ impl < A , W , Ty : GraphConstructor < A , W > > BaseGraph < A , W , Ty >
166
+ where
167
+ W : Clone ,
168
+ {
169
+ /// Returns the adjacency matrix of the graph as a 2D vector.
170
+ ///
171
+ /// Each entry at `[i][j]` is an `Option<W>` which is `Some(w)` if an edge exists
172
+ /// from node `i` to node `j`, or `None` otherwise.
173
+ /// For undirected graphs, the matrix is symmetric.
174
+ pub fn to_adjacency_matrix ( & self ) -> Vec < Vec < Option < W > > > {
175
+ let n = self . node_count ( ) ;
176
+ let mut matrix = vec ! [ vec![ None ; n] ; n] ;
177
+ for edge in self . inner ( ) . edge_references ( ) {
178
+ let i = edge. source ( ) . index ( ) ;
179
+ let j = edge. target ( ) . index ( ) ;
180
+ matrix[ i] [ j] = Some ( edge. weight ( ) . clone ( ) ) ;
181
+ if !<Ty as GraphConstructor < A , W > >:: is_directed ( ) {
182
+ matrix[ j] [ i] = Some ( edge. weight ( ) . clone ( ) ) ;
183
+ }
184
+ }
185
+ matrix
186
+ }
187
+
188
+ /// Constructs a new graph from an adjacency matrix.
189
+ ///
190
+ /// The input is a slice of vectors, where each inner vector represents the outgoing edges
191
+ /// from a node. A value of `Some(w)` at position `[i][j]` indicates an edge from node `i` to node `j`
192
+ /// with weight `w`. For undirected graphs, only the upper triangle (including the diagonal)
193
+ /// of the matrix is considered.
194
+ ///
195
+ /// Node attributes are initialized using `A::default()`, so `A` must implement `Default`.
196
+ pub fn from_adjacency_matrix ( matrix : & [ Vec < Option < W > > ] ) -> Self
197
+ where
198
+ A : Default ,
199
+ {
200
+ let n = matrix. len ( ) ;
201
+ let mut graph = Self :: new ( ) ;
202
+ // Add n nodes with default attributes.
203
+ let nodes: Vec < NodeId > = ( 0 ..n) . map ( |_| graph. add_node ( A :: default ( ) ) ) . collect ( ) ;
204
+ // Insert edges based on the matrix.
205
+ for i in 0 ..n {
206
+ for j in 0 ..matrix[ i] . len ( ) {
207
+ if let Some ( weight) = & matrix[ i] [ j] {
208
+ if <Ty as GraphConstructor < A , W > >:: is_directed ( ) || i <= j {
209
+ graph. add_edge ( nodes[ i] , nodes[ j] , weight. clone ( ) ) ;
210
+ }
211
+ }
212
+ }
213
+ }
214
+ graph
215
+ }
216
+ }
217
+
218
+ /// Sparse matrix API using sprs for efficiency on large graphs.
219
+ /// The trait bound now includes Add so that duplicate entries can be combined.
220
+ impl < A , W , Ty : GraphConstructor < A , W > > BaseGraph < A , W , Ty >
221
+ where
222
+ W : Clone + std:: ops:: Add < Output = W > ,
223
+ {
224
+ /// Returns the sparse adjacency matrix of the graph as a CsMat in CSR format.
225
+ ///
226
+ /// Only existing edges are stored. For undirected graphs, both (i,j) and (j,i) are inserted,
227
+ /// except for self-loops.
228
+ pub fn to_sparse_adjacency_matrix ( & self ) -> CsMat < W > {
229
+ let n = self . node_count ( ) ;
230
+ let mut triplet = TriMat :: new ( ( n, n) ) ;
231
+ for edge in self . inner ( ) . edge_references ( ) {
232
+ let i = edge. source ( ) . index ( ) ;
233
+ let j = edge. target ( ) . index ( ) ;
234
+ triplet. add_triplet ( i, j, edge. weight ( ) . clone ( ) ) ;
235
+ if !<Ty as GraphConstructor < A , W > >:: is_directed ( ) && i != j {
236
+ triplet. add_triplet ( j, i, edge. weight ( ) . clone ( ) ) ;
237
+ }
238
+ }
239
+ // Convert the triplet matrix into CSR format.
240
+ triplet. to_csr ( )
241
+ }
242
+
243
+ /// Constructs a new graph from a sparse adjacency matrix.
244
+ ///
245
+ /// The input is a CsMat (typically in CSR format) where nonzero entries represent edges with their weights.
246
+ /// For undirected graphs, only one of the symmetric entries is used (edges with i <= j).
247
+ ///
248
+ /// Node attributes are initialized using `A::default()`, so `A` must implement `Default`.
249
+ pub fn from_sparse_adjacency_matrix ( sparse : & CsMat < W > ) -> Self
250
+ where
251
+ A : Default ,
252
+ {
253
+ let n = sparse. rows ( ) ;
254
+ let mut graph = Self :: new ( ) ;
255
+ let nodes: Vec < NodeId > = ( 0 ..n) . map ( |_| graph. add_node ( A :: default ( ) ) ) . collect ( ) ;
256
+ // Iterate over the outer (row) indices.
257
+ for ( i, row) in sparse. outer_iterator ( ) . enumerate ( ) {
258
+ // row.indices() gives column indices and row.data() gives the corresponding weights.
259
+ for ( & j, weight) in row. indices ( ) . iter ( ) . zip ( row. data ( ) . iter ( ) ) {
260
+ if <Ty as GraphConstructor < A , W > >:: is_directed ( ) || i <= j {
261
+ graph. add_edge ( nodes[ i] , nodes[ j] , weight. clone ( ) ) ;
262
+ }
263
+ }
264
+ }
265
+ graph
266
+ }
267
+ }
268
+
160
269
/// Type alias for a directed graph.
161
270
pub type Digraph < A , W > = BaseGraph < A , W , Directed > ;
162
271
/// Type alias for an undirected graph.
0 commit comments