Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sbi/utils/nn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def check_net_device(

if isinstance(net, nn.Identity):
return net
if str(next(net.parameters()).device) != device:
if str(next(net.parameters()).device) != str(device):
warn(
message or f"Network is not on the correct device. Moving it to {device}.",
stacklevel=2,
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def validate_theta_and_x(
assert theta.dtype == float32, "Type of parameters must be float32."
assert x.dtype == float32, "Type of simulator outputs must be float32."

if str(x.device) != data_device:
if str(x.device) != str(data_device):
warnings.warn(
f"Data x has device '{x.device}'. "
f"Moving x to the data_device '{data_device}'. "
Expand All @@ -724,7 +724,7 @@ def validate_theta_and_x(
)
x = x.to(data_device)

if str(theta.device) != data_device:
if str(theta.device) != str(data_device):
warnings.warn(
f"Parameters theta has device '{theta.device}'. "
f"Moving theta to the data_device '{data_device}'. "
Expand Down