Skip to content

Commit c40b642

Browse files
committed
fix: making sure elements in pytorch files are tensors
1 parent 83961ea commit c40b642

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/core/handlers/pytorch/inspect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import argparse
5+
import numpy as np
56

67

78
def main():
@@ -51,6 +52,13 @@ def main():
5152
model = model["model"]
5253

5354
for tensor_name, tensor in model.items():
55+
# make sure it's a tensor
56+
if not isinstance(tensor, torch.Tensor):
57+
try:
58+
tensor = torch.tensor(tensor)
59+
except:
60+
continue
61+
5462
inspection["data_size"] += tensor.shape.numel() * tensor.element_size()
5563

5664
shape = list(tensor.shape)

0 commit comments

Comments
 (0)