From 7dc738105f24319e4c192bff5c595af5d1d6f014 Mon Sep 17 00:00:00 2001 From: Gavin Weiguang Ding Date: Wed, 19 Feb 2020 12:08:34 -0500 Subject: [PATCH] bpda cuda --- tests/test_bpda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_bpda.py b/tests/test_bpda.py index c0028a5..85100c4 100644 --- a/tests/test_bpda.py +++ b/tests/test_bpda.py @@ -121,7 +121,7 @@ def forward(self, x): z = bpda(x, y) z_ = z.detach().requires_grad_() - net = nn.Sequential(func, DummyNet()) + net = nn.Sequential(func, DummyNet().to(device)) with torch.enable_grad(): loss_ = net(z_).sum()