
Commit f8d2de6

add node/edge/graph reference & change node/edge to atom/bond
1 parent 20f8417 commit f8d2de6


16 files changed: +887, -278 lines


README.md

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ Molecules are also supported in TorchDrug. You can get the desired molecule
properties without any domain knowledge.

```python
-mol = data.Molecule.from_smiles("CCOC(=O)N", node_feature="default", edge_feature="default")
+mol = data.Molecule.from_smiles("CCOC(=O)N", atom_feature="default", bond_feature="default")
print(mol.node_feature)
print(mol.atom_type)
print(mol.to_scaffold())
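
For context, a minimal sketch of the renamed constructor arguments (hypothetical usage based on this diff; exact feature shapes depend on the installed TorchDrug version):

```python
from torchdrug import data

# after this commit, atom- and bond-level features are requested with
# atom_feature / bond_feature instead of node_feature / edge_feature
mol = data.Molecule.from_smiles("CCOC(=O)N", atom_feature="default", bond_feature="default")
print(mol.node_feature.shape)   # per-atom features are still exposed as node_feature
print(mol.edge_feature.shape)   # per-bond features as edge_feature
```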

doc/source/notes/graph.rst

Lines changed: 16 additions & 15 deletions
@@ -9,7 +9,7 @@ Create a Graph

To begin with, let's create a graph.

-.. code-block:: python
+.. code:: python

    import torch
    from torchdrug import data
@@ -25,15 +25,15 @@ This will plot a ring graph like the following.
    :width: 33%

Internally, the graph is stored as a sparse edge list to save memory footprint. For
-an intuitive comparison, a `scale-free graph`_ mayr have 1 million nodes and 10 million
+an intuitive comparison, a `scale-free graph`_ may have 1 million nodes and 10 million
edges. The dense version takes about 4TB, while the sparse version only requires 120MB.

.. _scale-free graph:
    https://en.wikipedia.org/wiki/Scale-free_network

Here are some commonly used properties of the graph.

-.. code-block:: python
+.. code:: python

    print(graph.num_node)
    print(graph.num_edge)
@@ -45,7 +45,7 @@ molecules have bond types like ``single bound``, while knowledge graphs have rel
like ``consists of``. To construct such a relational graph, we can pass the edge type
as a third variable in the edge list.

-.. code-block:: python
+.. code:: python

    triplet_list = [[0, 1, 0], [1, 2, 1], [2, 3, 0], [3, 4, 1], [4, 5, 0], [5, 0, 1]]
    graph = data.Graph(triplet_list, num_node=6, num_relation=2)
@@ -62,7 +62,7 @@ corresponds to an edge from node :math:`i` to node :math:`j`. The relational gra
uses a 3D adjacency matrix :math:`A`, where non-zero :math:`A_{i,j,k}` denotes an
edge from node :math:`i` to node :math:`j` with edge type :math:`k`.

-.. code-block:: python
+.. code:: python

    adjacency = torch.zeros(6, 6)
    adjacency[edge_list] = 1
@@ -78,7 +78,7 @@ For example, the following code creates a benzene molecule.
.. _SMILES:
    https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system

-.. code-block:: python
+.. code:: python

    mol = data.Molecule.from_smiles("C1=CC=CC=C1")
    mol.visualize()
@@ -90,7 +90,7 @@ For example, the following code creates a benzene molecule.
Once the graph is created, we can transfer it between CPU and GPUs, just like
:class:`torch.Tensor`.

-.. code-block:: python
+.. code:: python

    graph = graph.cuda()
    print(graph.device)
@@ -109,7 +109,7 @@ during any graph operation.

Here we specify some features during the construction of the molecule graph.

-.. code-block:: python
+.. code:: python

    mol = data.Molecule.from_smiles("C1=CC=CC=C1", node_feature="default",
                                    edge_feature="default", graph_feature="ecfp")
@@ -122,14 +122,15 @@ We may also want to define our own attributes. This only requires to wrap the
assignment lines with a context manager. The following example defines edge importance
as the reciprocal of node degrees.

-.. code-block:: python
+.. code:: python

    node_in, node_out = mol.edge_list.t()[:2]
    with mol.edge():
        mol.edge_importance = 1 / graph.degree_in[node_in] + 1 / graph.degree_out[node_out]

We can use ``mol.node()`` and ``mol.graph()`` for node- and graph-level attributes
-respectively.
+respectively. Attributes may also be a reference to node/edge/graph indexes. See
+:doc:`reference` for more details.

Note in order to support batching and masking, attributes should always have the same
length as their corresponding components. This means the size of the first dimension of
@@ -142,7 +143,7 @@ Modern deep learning frameworks employs batched operations to accelerate computa
In TorchDrug, we can easily batch same kind of graphs with **arbitary sizes**. Here
is an example of creating a batch of 4 graphs.

-.. code-block:: python
+.. code:: python

    graphs = [graph, graph, graph, graph]
    batch = data.Graph.pack(graphs)
@@ -170,7 +171,7 @@ where :math:`A_i` is the adjacency of :math:`i`-th graph.
To get a single graph from the batch, use the conventional index or
:meth:`PackedGraph.unpack <torchdrug.data.PackedGraph.unpack>`.

-.. code-block:: python
+.. code:: python

    graph = batch[1]
    graphs = batch.unpack()
@@ -186,7 +187,7 @@ Subgraph and Masking
The graph data structure also provides a bunch of slicing operations to create subgraphs
or masked graphs in a sparse manner. Some typical operations include

-.. code-block:: python
+.. code:: python

    g1 = graph.subgraph([1, 2, 3, 4])
    g1.visualize()
@@ -220,7 +221,7 @@ isolated nodes.
The same operations can also be applied to batches. In this case, we need to convert
the index of a single graph into the index in a batch.

-.. code-block:: python
+.. code:: python

    graph_ids = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    node_ids = torch.tensor([1, 2, 3, 4, 0, 1, 2, 3, 4, 5])
@@ -232,7 +233,7 @@ the index of a single graph into the index in a batch.

We can also pick a subset of graphs in a batch.

-.. code-block:: python
+.. code:: python

    batch = batch[[0, 1]]
    batch.visualize()
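
To complement the attribute hunks above, here is a minimal sketch (assuming the API shown in this file) of how a custom edge-level attribute is carried through masking and batching; the attribute name `edge_importance` follows the example in the doc:

```python
import torch
from torchdrug import data

graph = data.Graph([[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]], num_node=6)

# register a custom edge-level attribute; its first dimension must equal num_edge
with graph.edge():
    graph.edge_importance = torch.rand(graph.num_edge)

# the attribute is sliced together with the edges by masking and batching
g1 = graph.edge_mask([0, 2, 3])
print(g1.edge_importance.shape)      # 3 edges remain
batch = data.Graph.pack([graph, graph])
print(batch.edge_importance.shape)   # 12 edges in the batch
```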

doc/source/notes/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ Notes
    variadic
    layer
    model
+    reference

doc/source/notes/layer.rst

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ representations as a graph representation. First, we readout the mean of node
representations. Second, we broadcast the mean representation to each node to compute
the difference. Finally, we readout the mean of the squared difference as the variance.

-.. code-block:: python
+.. code:: python

    from torch import nn
    from torch_scatter import scatter_mean
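
The hunk above only swaps the code directive, but since the surrounding paragraph describes the variance readout in words, here is a minimal sketch of that computation with hypothetical toy inputs (readout the mean, broadcast it back, then readout the mean squared difference):

```python
import torch
from torch_scatter import scatter_mean

# toy inputs: 5 node representations assigned to 2 graphs
node_feature = torch.randn(5, 16)
node2graph = torch.tensor([0, 0, 0, 1, 1])

mean = scatter_mean(node_feature, node2graph, dim=0, dim_size=2)   # mean readout
diff = (node_feature - mean[node2graph]) ** 2                      # broadcast and square
variance = scatter_mean(diff, node2graph, dim=0, dim_size=2)       # variance readout
```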

doc/source/notes/model.rst

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ distributed, module-centric manner.
We compute the variational regularization loss, and add it to the global loss and the
global metric.

-.. code-block::
+.. code::

    def reparameterize(self, mu, log_sigma):
        if self.training:
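
For context, the fragment above is the start of a standard reparameterization step; a self-contained sketch (not the tutorial's exact module) looks like this:

```python
import torch

def reparameterize(mu, log_sigma, training=True):
    # during training, sample z = mu + sigma * eps (reparameterization trick);
    # at evaluation time, fall back to the mean
    if training:
        return mu + torch.randn_like(mu) * log_sigma.exp()
    return mu
```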

doc/source/notes/reference.rst

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+Deal with References
+====================
+
+As we show in :doc:`graph`, custom graph attributes will be automatically processed
+in any graph operation. However, some attributes may refer to node/edge/graph indexes,
+and their values need to be modified when the indexes change. TorchDrug provides a
+mechanism to deal with such cases.
+
+Inverse Edge Index
+------------------
+
+A typical example of reference is a mapping from each edge to its inverse edge.
+We first prepare an undirected graph with the indexes of inverse edges.
+
+.. code:: python
+
+    import torch
+    from torchdrug import data
+
+    edge_list = [[0, 1], [1, 0], [1, 2], [2, 1], [2, 0], [0, 2]]
+    inv_edge_index = [1, 0, 3, 2, 5, 4]
+    graph = data.Graph(edge_list, num_node=3)
+
+.. image:: ../../../asset/graph/inverse_edge.png
+    :align: center
+    :width: 33%
+
+If we assign the indexes as an edge attribute and apply an edge mask operation,
+the result is not desired. The edges are masked out correctly, but the values of
+inverse indexes are wrong.
+
+.. code:: python
+    with graph.edge():
+        graph.inv_edge_index = torch.tensor(inv_edge_index)
+    g1 = graph.edge_mask([0, 2, 3])
+
+.. image:: ../../../asset/graph/wrong_reference.png
+    :align: center
+    :width: 33%
+
+Instead, we need to explicitly tell TorchDrug that the attribute ``graph.inv_edge_index``
+is a reference to edge indexes. This is done by an additional context manager
+``graph.edge_reference()``. Now we get the correct inverse indexes. Note that missing
+references will be set to ``-1``. In this case, the inverse index of ``0`` is ``-1``,
+since the corresponding inverse edge has been masked out.
+
+.. code:: python
+
+    with graph.edge(), graph.edge_reference():
+        graph.inv_edge_index = torch.tensor(inv_edge_index)
+    g2 = graph.edge_mask([0, 2, 3])
+
+.. image:: ../../../asset/graph/correct_reference.png
+    :align: center
+    :width: 33%
+
+We can use ``graph.node_reference()`` and ``graph.graph_reference()`` for references
+to nodes and graphs respectively.
+
+Use Cases in Proteins
+---------------------
+
+In :class:`data.Protein`, the mapping ``atom2residue`` is implemented as
+references. The intuition is that references enable flexible indexing on either atoms
+or residues, while maintaining the correspondence between two views.
+
+The following example shows how to track a specific residue with ``atom2residue`` in
+the atom view. For a protein, we first create a mask for atoms in a glutamine (GLN).
+
+.. code:: python
+
+    protein = data.Protein.from_sequence("KALKQMLDMG")
+    is_glutamine = protein.residue_type[protein.atom2residue] == protein.residue2id["GLN"]
+    with protein.node():
+        protein.is_glutamine = is_glutamine
+
+We then apply a mask to the protein residue sequence. In the output protein,
+``atom2residue`` is able to map the masked atoms back to the glutamine residue.
+
+.. code:: python
+
+    p1 = protein[3:6]
+    residue_type = p1.residue_type[p1.atom2residue[p1.is_glutamine]]
+    print([p1.id2residue[r] for r in residue_type.tolist()])
+
+.. code:: bash
+
+    ['GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN']
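
The same mechanism works for node indexes. A minimal sketch mirroring the unit test added below (the attribute name `parent` is illustrative; the test calls it `dad`):

```python
import torch
from torchdrug import data

# a small binary tree: node 0 is the root, node i has parent (i - 1) // 2
num_node = 7
node_out = torch.arange(1, num_node)
node_in = (node_out - 1) // 2
tree = data.Graph(torch.stack([node_in, node_out], dim=-1), num_node=num_node)

# store parent indexes as a node reference so graph operations remap them
with tree.node(), tree.node_reference():
    tree.parent = (torch.arange(num_node) - 1) // 2

# drop the root: remaining parents are renumbered, and nodes whose parent
# was removed are set to -1, just like the missing inverse edges above
subtree = tree.subgraph(torch.arange(1, num_node))
print(subtree.parent)
```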

test/data/test_graph.py

Lines changed: 53 additions & 1 deletion
@@ -326,12 +326,64 @@ def test_match(self):
        index_results = index_result.split(num_match_result.tolist())
        match = ((graph.edge_list.unsqueeze(0) == edge.unsqueeze(1)) | (edge.unsqueeze(1) == -1)).all(dim=-1)
        query_index, index_truth = match.nonzero().t()
-        num_match_truth = torch.bincount(query_index, minlength=len(edge))
+        num_match_truth = query_index.bincount(minlength=len(edge))
        index_truths = index_truth.split(num_match_truth.tolist())
        self.assertTrue(torch.equal(num_match_result, num_match_truth), "Incorrect edge match")
        for index_result, index_truth in zip(index_results, index_truths):
            self.assertTrue(torch.equal(index_result.sort()[0], index_truth.sort()[0]), "Incorrect edge match")

+    def test_reference(self):
+        node_out = torch.arange(1, self.num_node)
+        node_in = (node_out - 1) // 2
+        edge_list = torch.stack([node_in, node_out], dim=-1)
+        tree = data.Graph(edge_list, num_node=self.num_node)
+        with tree.node(), tree.node_reference():
+            tree.dad = (torch.arange(self.num_node) - 1) // 2
+
+        mask = torch.arange(1, self.num_node)
+        graph = tree.subgraph(mask)
+        degree_in_result = graph.dad[graph.dad != -1].bincount(minlength=graph.num_node)
+        is_root_result = graph.dad == -1
+        node_in, node_out = graph.edge_list.t()
+        degree_in_truth = node_in.bincount(minlength=graph.num_node)
+        is_root_truth = node_out.bincount(minlength=graph.num_node) == 0
+        self.assertTrue(torch.equal(degree_in_result, degree_in_truth), "Incorrect node reference")
+        self.assertTrue(torch.equal(is_root_result, is_root_truth), "Incorrect node reference")
+
+        packed_graph = tree.repeat(4)
+        packed_graph2 = data.Graph.pack([tree] * 4)
+        self.assert_equal(packed_graph, packed_graph2, "node reference")
+
+        # special case: 0 repetition
+        repeats = [2, 0, 1, 2]
+        trees = []
+        for start in range(4):
+            index = torch.arange(start, self.num_node)
+            trees.append(tree.subgraph(index))
+        packed_graph = data.Graph.pack(trees)
+        repeat_graph = packed_graph.repeat_interleave(repeats)
+        true_graphs = []
+        for i, tree in zip(repeats, trees):
+            true_graphs += [tree] * i
+        true_graph = data.Graph.pack(true_graphs)
+        self.assert_equal(repeat_graph, true_graph, "node reference")
+
+    def test_line_graph(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node, edge_feature=self.edge_feature)
+        line_graph = graph.line_graph()
+        adj_result = line_graph.adjacency.to_dense()
+        feat_result = line_graph.node_feature
+        edge_index = torch.arange(graph.num_edge)
+        node_in, node_out = graph.edge_list.t()
+        edge2node_out = torch.zeros(graph.num_edge, graph.num_node)
+        node_in2edge = torch.zeros(graph.num_node, graph.num_edge)
+        edge2node_out[edge_index, node_out] = 1
+        node_in2edge[node_in, edge_index] = 1
+        adj_truth = edge2node_out @ node_in2edge
+        feat_truth = graph.edge_feature
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect line graph")
+        self.assertTrue(torch.equal(feat_result, feat_truth), "Incorrect line graph")
+

if __name__ == "__main__":
    unittest.main()
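
For readers unfamiliar with the construction checked by `test_line_graph`: every edge of the original graph becomes a node of the line graph, and edge `e1` connects to `e2` when the head of `e1` is the tail of `e2`. A small sketch with a hypothetical toy graph:

```python
from torchdrug import data

# a directed triangle 0 -> 1 -> 2 -> 0
graph = data.Graph([[0, 1], [1, 2], [2, 0]], num_node=3)
line_graph = graph.line_graph()
print(line_graph.num_node)   # one node per original edge
print(line_graph.edge_list)  # edge i -> j whenever edge i ends where edge j starts
```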

test/data/test_molecule.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def test_smiles(self):
        self.assertTrue((mols.num_edges == 0).all(), "Incorrect SMILES side case")

    def test_feature(self):
-        mol = data.Molecule.from_smiles(self.smiles, graph_feature="ecfp")
+        mol = data.Molecule.from_smiles(self.smiles, mol_feature="ecfp")
        self.assertTrue((mol.graph_feature > 0).any(), "Incorrect ECFP feature")
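
A minimal sketch of the renamed keyword this test now covers (assuming a build that includes this commit):

```python
from torchdrug import data

# molecule-level fingerprints are now requested with mol_feature instead of graph_feature
mol = data.Molecule.from_smiles("C1=CC=CC=C1", mol_feature="ecfp")
print(mol.graph_feature.shape)   # the ECFP fingerprint is still exposed as graph_feature
```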

test/layers/test_variadic.py

Lines changed: 15 additions & 0 deletions
@@ -45,6 +45,21 @@ def test_topk(self):
        self.assertTrue(torch.equal(result_value, truth_value), "Incorrect variadic topk")
        self.assertTrue(torch.equal(result_index, truth_index), "Incorrect variadic topk")

+        for _ in range(10):
+            k = torch.randint(self.size.min(), self.size.max(), (self.num_graph,))
+            result_value, result_index = functional.variadic_topk(self.input, self.size, k)
+            _truth_value, _truth_index = self.padded.topk(self.size.max(), dim=1)
+            truth_value, truth_index = [], []
+            for i, size in enumerate(self.size):
+                truth_value.append(_truth_value[i, :k[i]])
+                truth_index.append(_truth_index[i, :k[i]])
+                for j in range(size, k[i].item()):
+                    truth_value[i][j] = truth_value[i][j-1]
+                    truth_index[i][j] = truth_index[i][j-1]
+            truth_value = torch.cat(truth_value, dim=0)
+            truth_index = torch.cat(truth_index, dim=0)
+            self.assertTrue(torch.equal(result_value, truth_value), "Incorrect variadic topk")
+            self.assertTrue(torch.equal(result_index, truth_index), "Incorrect variadic topk")

if __name__ == "__main__":
    unittest.main()
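
A sketch of the behavior this new test covers: `variadic_topk` takes samples of different lengths packed along the first dimension together with their sizes, and (per this test) also accepts a per-sample `k`; when `k` exceeds a sample's size, the last entry is repeated. The toy numbers below are hypothetical:

```python
import torch
from torchdrug.layers import functional

# three variadic samples of sizes 3, 2 and 4, packed along the first dimension
input = torch.tensor([3.0, 1.0, 2.0, 5.0, 4.0, 9.0, 7.0, 8.0, 6.0]).unsqueeze(-1)
size = torch.tensor([3, 2, 4])
k = torch.tensor([2, 3, 3])

# per-sample top-k values and indexes; the second sample repeats its last
# value to reach k=3, as the test above checks
value, index = functional.variadic_topk(input, size, k)
print(value.squeeze(-1))
```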
