Skip to content

Commit 5a8d58e

Browse files
committed
small change to neurallp
1 parent 7b17e4d commit 5a8d58e

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

torchdrug/models/gin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
2929
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
3030
"""
3131

32-
def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
32+
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
3333
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
3434
readout="sum"):
3535
super(GraphIsomorphismNetwork, self).__init__()

torchdrug/models/neurallp.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ class NeuralLogicProgramming(nn.Module, core.Configurable):
1717
https://papers.nips.cc/paper/2017/file/0e55666a4ad822e0e34299df3591d979-Paper.pdf
1818
1919
Parameters:
20-
num_entity (int): number of entities
2120
num_relation (int): number of relations
2221
hidden_dim (int): dimension of hidden units in LSTM
2322
num_step (int): number of recurrent steps
@@ -26,17 +25,15 @@ class NeuralLogicProgramming(nn.Module, core.Configurable):
2625

2726
eps = 1e-10
2827

29-
def __init__(self, num_entity, num_relation, hidden_dim, num_step, num_lstm_layer=1):
28+
def __init__(self, num_relation, hidden_dim, num_step, num_lstm_layer=1):
3029
super(NeuralLogicProgramming, self).__init__()
3130

3231
num_relation = int(num_relation)
33-
self.num_entity = num_entity
3432
self.num_relation = num_relation
3533
self.num_step = num_step
3634

3735
self.query = nn.Embedding(num_relation * 2 + 1, hidden_dim)
3836
self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_lstm_layer)
39-
self.key_linear = nn.Linear(hidden_dim, hidden_dim)
4037
self.weight_linear = nn.Linear(hidden_dim, num_relation * 2)
4138
self.linear = nn.Linear(1, 1)
4239

@@ -56,7 +53,7 @@ def get_t_output(self, graph, h_index, r_index):
5653
query = self.query(q_index)
5754

5855
hidden, hx = self.lstm(query)
59-
memory = functional.one_hot(h_index, self.num_entity).unsqueeze(0)
56+
memory = functional.one_hot(h_index, graph.num_entity).unsqueeze(0)
6057

6158
for i in range(self.num_step):
6259
key = hidden[i]

0 commit comments

Comments
 (0)