-
Notifications
You must be signed in to change notification settings - Fork 66
Open
Description
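Using `torch.jit.trace` together with `torch.jit.script` on a FasterViT model fails when the scripted module is moved to CUDA after tracing on the CPU.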
Error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Steps to reproduce:
import torch
import torch.nn as nn
from fastervit import create_model
class FasterVit(nn.Module):
    """Thin ``nn.Module`` wrapper around a model built by ``create_model``.

    The built model's ``head`` is replaced by a fresh ``nn.Linear`` with
    ``classes`` outputs, keeping the original head's ``in_features``.

    Args:
        encoder_name: model name passed to ``create_model``.
        classes: number of outputs of the replacement linear head.
        window_size: forwarded to ``create_model``. Defaults to
            ``(7, 7, 12, 6)``; it is always copied into a new ``list`` so no
            two instances share (or can mutate) the same default object.
        height: input height, forwarded as ``resolution[0]``.
        width: input width, forwarded as ``resolution[1]``.
        dim: forwarded to ``create_model`` as ``dim``.
        pretrained: forwarded to ``create_model`` as ``pretrained``.
        **kwargs: accepted for call-site compatibility but currently ignored.
    """

    def __init__(
        self,
        encoder_name: str = "faster_vit_0_any_res",
        classes: int = 1,
        window_size=(7, 7, 12, 6),
        height: int = 512,
        width: int = 512,
        dim: int = 64,
        pretrained: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = create_model(
            encoder_name,
            resolution=[height, width],
            # Fresh list per instance: avoids the shared-mutable-default pitfall
            # while still handing create_model the list type it was given before.
            window_size=list(window_size),
            ct_size=2,
            dim=dim,
            pretrained=pretrained,
        )
        # Swap the classifier for one sized to this task's number of outputs.
        self.model.head = nn.Linear(self.model.head.in_features, classes)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)
# Build the wrapper; each keyword below repeats FasterVit's own default,
# so this is equivalent to calling FasterVit() with no arguments.
model = FasterVit(
    encoder_name="faster_vit_0_any_res",
    classes=1,
    height=512,
    width=512,
)
class MyModule(nn.Module):
    """Hold a traced copy of ``model`` so the wrapper can be passed to ``torch.jit.script``.

    Args:
        model: module to trace; it is switched to eval mode (in place) first.
        classes: labels stored as ``self.classes``; not used by ``forward``.
        height: height of the random ``(1, 3, height, width)`` example input
            used for tracing.
        width: width of that example input.

    The example input is created on the device of ``model``'s first
    parameter (CPU when the model has no parameters). A model that was
    already moved to the GPU is therefore traced on the GPU instead of
    being fed a CPU tensor.

    NOTE(review): trace on the device you intend to run on. The traced
    graph may keep values created on the tracing device, so moving the
    scripted wrapper to another device afterwards can raise
    device-mismatch errors.
    """

    def __init__(
        self,
        model,
        classes,
        height,
        width,
    ):
        super().__init__()
        self.classes = classes
        self.height = height
        self.width = width
        model = model.eval()
        # Build the example input where the model's weights live, so tracing
        # records ops on the same device the caller prepared the model for.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        example_input = torch.rand(1, 3, height, width, device=device)
        self.model = torch.jit.trace(model, example_input)

    def forward(self, input):
        """Run the traced model on ``input``."""
        return self.model(input)
# Reproduction of the reported failure. The statement order is deliberate:
#   1. the FasterViT wrapper is moved to the CPU and traced inside
#      MyModule.__init__ (its example input is a CPU tensor),
#   2. the traced wrapper is compiled with torch.jit.script,
#   3. only then is the scripted module moved to CUDA and called.
# NOTE(review): the "cuda:0 and cpu" RuntimeError quoted above presumably
# comes from the final call. The trace was taken on the CPU, and values the
# model created during tracing may stay on the CPU after .to(device).
# Tracing on the target device (move the model to `device` before wrapping
# it) is the usual workaround -- TODO confirm against the FasterViT code.
my_script_module = torch.jit.script(
MyModule(
model.to("cpu"),
classes=["1", "2"],
height=512,
width=512,
)
)
# Requires a CUDA-capable GPU.
device = "cuda:0"
my_script_module.to(device)
example_input = torch.rand((1, 3, 512, 512)).to(device)
my_script_module(example_input)
Metadata
Metadata
Assignees
Labels
No labels