Skip to content

Commit f5d9689

Browse files
authored
fix: RuntimeError in GraphIsomorphismNetwork #92 (#93)
1 parent 26f15f6 commit f5d9689

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

torchdrug/models/chebnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F
3535
if not isinstance(hidden_dims, Sequence):
3636
hidden_dims = [hidden_dims]
3737
self.input_dim = input_dim
38-
self.output_dim = hidden_dims[-1] * (len(hidden_dims) if concat_hidden else 1)
38+
self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
3939
self.dims = [input_dim] + list(hidden_dims)
4040
self.short_cut = short_cut
4141
self.concat_hidden = concat_hidden

torchdrug/models/gat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega
3535
if not isinstance(hidden_dims, Sequence):
3636
hidden_dims = [hidden_dims]
3737
self.input_dim = input_dim
38-
self.output_dim = hidden_dims[-1] * (len(hidden_dims) if concat_hidden else 1)
38+
self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
3939
self.dims = [input_dim] + list(hidden_dims)
4040
self.short_cut = short_cut
4141
self.concat_hidden = concat_hidden

torchdrug/models/gcn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
3333
if not isinstance(hidden_dims, Sequence):
3434
hidden_dims = [hidden_dims]
3535
self.input_dim = input_dim
36-
self.output_dim = hidden_dims[-1] * (len(hidden_dims) if concat_hidden else 1)
36+
self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
3737
self.dims = [input_dim] + list(hidden_dims)
3838
self.short_cut = short_cut
3939
self.concat_hidden = concat_hidden

torchdrug/models/gin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml
3737
if not isinstance(hidden_dims, Sequence):
3838
hidden_dims = [hidden_dims]
3939
self.input_dim = input_dim
40-
self.output_dim = hidden_dims[-1] * (len(hidden_dims) if concat_hidden else 1)
40+
self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
4141
self.dims = [input_dim] + list(hidden_dims)
4242
self.short_cut = short_cut
4343
self.concat_hidden = concat_hidden

0 commit comments

Comments
 (0)