Description
Hello, I really appreciate your work and am trying to train the model on my own data. However, I am confused by the dataset processing, as follows:
The model needs a 4D tensor as input. After the `torch.cat(_ret["ref_imgs"]).unsqueeze(1)` operation, is `ret['ref_imgs']` a `[b*3, 1, h, w]` tensor? If so, the subsequent `repeat_interleave(ref_dec_lens, dim=0)` raises an error because the dimensions do not match.
For example, assume `batch_size=2` and `ref_dec_lens` is `[2, 3]`. Then the shape is `[6, 1, h, w]`, and `repeat_interleave(ref_dec_lens, dim=0)` cannot be applied to it, because `ref_dec_lens` has 2 entries while dimension 0 has size 6.
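A minimal reproduction of the shape mismatch described above. It assumes `_ret["ref_imgs"]` holds one `[3, h, w]` tensor per sample; the concrete sizes and the `per_item` tensor are hypothetical stand-ins, not the repository's code. It only demonstrates the rule `repeat_interleave` enforces along `dim=0` (one count per row), not the intended fix.

```python
import torch

# Hypothetical sizes matching the example: batch_size=2, 3 reference
# images per sample, each h x w (assumed layout of _ret["ref_imgs"]).
b, n_ref, h, w = 2, 3, 4, 4
ref_dec_lens = torch.tensor([2, 3])  # one repeat count per batch item

# cat over per-sample [3, h, w] stacks, then unsqueeze(1) -> [b*3, 1, h, w]
per_sample = [torch.zeros(n_ref, h, w) for _ in range(b)]
ref_imgs = torch.cat(per_sample).unsqueeze(1)
print(ref_imgs.shape)  # torch.Size([6, 1, 4, 4])

# repeat_interleave along dim 0 needs len(repeats) == input.size(0);
# here 2 != 6, so PyTorch raises a RuntimeError.
mismatch_raised = False
try:
    ref_imgs.repeat_interleave(ref_dec_lens, dim=0)
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# On a hypothetical tensor whose dim 0 is the batch size (2), the same
# counts work: row 0 is repeated 2 times, row 1 three times -> 2 + 3 = 5 rows.
per_item = torch.zeros(b, 1, h, w)
expanded = per_item.repeat_interleave(ref_dec_lens, dim=0)
print(expanded.shape)  # torch.Size([5, 1, 4, 4])
```

In other words, the counts in `ref_dec_lens` line up with a batch-sized dimension 0, not with the `b*3` rows produced by concatenating the reference images.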
I ran into this problem and cannot start training. I look forward to your help!