Skip to content

Commit f6c6080

Browse files
committed
add test attention.py
1 parent a2df15d commit f6c6080

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

openrl/configs/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def create_config_parser():
618618
)
619619
parser.add_argument(
620620
"--use_average_pool",
621-
action="store_false",
621+
type=bool,
622622
default=True,
623623
help="by default True, use average pooling for attn model.",
624624
)

openrl/modules/networks/utils/attention.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,13 @@ def forward(self, x, self_idx=-1):
234234
K = self.split_shape[i][0]
235235
L = self.split_shape[i][1]
236236
for j in range(K):
237-
torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
238-
exec("x1.append(self.fc_{}(temp))".format(i))
239-
x[self_idx]
240-
exec("x1.append(self.fc_{}(temp))".format(N - 1))
237+
# torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
238+
# exec("x1.append(self.fc_{}(temp))".format(i))
239+
temp = torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
240+
x1.append(getattr(self, "fc_" + str(i))(temp))
241+
x1.append(getattr(self, "fc_" + str(N - 1))(self_x))
242+
# x[self_idx]
243+
# exec("x1.append(self.fc_{}(temp))".format(N - 1))
241244

242245
out = torch.stack(x1, 1)
243246

@@ -278,8 +281,10 @@ def forward(self, x, self_idx=None):
278281
K = self.split_shape[i][0]
279282
L = self.split_shape[i][1]
280283
for j in range(K):
281-
x[i][:, (L * j) : (L * j + L)]
282-
exec("x1.append(self.fc_{}(temp))".format(i))
284+
# x[i][:, (L * j) : (L * j + L)]
285+
# exec("x1.append(self.fc_{}(temp))".format(i))
286+
temp = x[i][:, (L * j) : (L * j + L)]
287+
x1.append(getattr(self, "fc_" + str(i))(temp))
283288

284289
out = torch.stack(x1, 1)
285290

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
import os
20+
import sys
21+
22+
import pytest
23+
import torch
24+
25+
from openrl.configs.config import create_config_parser
26+
from openrl.modules.networks.utils.attention import Encoder
27+
28+
29+
@pytest.fixture(
30+
scope="module", params=["--use_average_pool True", "--use_average_pool False"]
31+
)
32+
def config(request):
33+
cfg_parser = create_config_parser()
34+
cfg = cfg_parser.parse_args(request.param.split())
35+
return cfg
36+
37+
38+
@pytest.mark.unittest
39+
def test_attention(config):
40+
for cat_self in [False, True]:
41+
net = Encoder(cfg=config, split_shape=[[1, 1], [1, 1]], cat_self=cat_self)
42+
net(torch.zeros((1, 1)))
43+
44+
45+
if __name__ == "__main__":
46+
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 commit comments

Comments
 (0)