diff --git a/train.py b/train.py
index 7f3a66bf00..9d0471bd3c 100644
--- a/train.py
+++ b/train.py
@@ -80,8 +80,14 @@ def train(opt):
     model.train()
     if opt.saved_model != '':
         print(f'loading pretrained model from {opt.saved_model}')
+        # For fine-tuning
         if opt.FT:
-            model.load_state_dict(torch.load(opt.saved_model), strict=False)
+            checkpoint = torch.load(opt.saved_model)
+            checkpoint = {k: v for k, v in checkpoint.items() if (k in model.state_dict().keys()) and (model.state_dict()[k].shape == checkpoint[k].shape)}
+            for name in model.state_dict().keys():
+                if name in checkpoint.keys():
+                    model.state_dict()[name].copy_(checkpoint[name])
+            #model.load_state_dict(torch.load(opt.saved_model), strict=False)
         else:
             model.load_state_dict(torch.load(opt.saved_model))
     print("Model:")
diff --git a/utils.py b/utils.py
index e576358418..f6cad8cbc6 100644
--- a/utils.py
+++ b/utils.py
@@ -32,8 +32,20 @@ def encode(self, text, batch_max_length=25):
         batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
         for i, t in enumerate(text):
             text = list(t)
-            text = [self.dict[char] for char in text]
-            batch_text[i][:len(text)] = torch.LongTensor(text)
+            # A KeyError can occur here, so check that each char is in self.dict.
+            # If any char is missing from self.dict, drop the whole label
+            # (leave it empty): a partially encoded label could make training worse.
+            text_index = []
+            for char in text:
+                if char not in self.dict:
+                    text_index = []
+                    break
+                text_index.append(self.dict[char])
+
+            batch_text[i][:len(text_index)] = torch.LongTensor(text_index)
+
+            #text = [self.dict[char] for char in text if char in self.dict]
+            #batch_text[i][:len(text)] = torch.LongTensor(text)
         return (batch_text.to(device), torch.IntTensor(length).to(device))

     def decode(self, text_index, length):
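The train.py change amounts to loading only those checkpoint tensors whose names and shapes match the current model, which lets fine-tuning proceed even when, say, the final prediction layer was resized for a new character set. A minimal standalone sketch of the same idea, assuming the checkpoint is a raw state_dict as the patch treats it; the helper name `load_matching_weights` and the use of `load_state_dict(..., strict=False)` are my own (the patch instead copies tensor-by-tensor via `state_dict()[name].copy_()`, which works because `state_dict()` returns references to the live parameter tensors):

```python
import torch
import torch.nn as nn

def load_matching_weights(model: nn.Module, checkpoint_path: str) -> list:
    """Load only the checkpoint entries whose key AND shape match the model.

    Returns the list of skipped keys, e.g. a prediction head whose
    size changed for a different character set.
    """
    # Assumes the file holds a raw state_dict, as in the patch.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model_state = model.state_dict()
    matched = {k: v for k, v in checkpoint.items()
               if k in model_state and model_state[k].shape == v.shape}
    # strict=False tolerates the keys we deliberately left out.
    model.load_state_dict(matched, strict=False)
    return [k for k in checkpoint if k not in matched]
```

With the shape check in place, every compatible weight is reused and any mismatched layer simply keeps its fresh initialization instead of crashing the load.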
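The utils.py change guards the label encoder against characters outside the training charset: a label containing any unknown character is left as all zeros (padding) rather than raising a KeyError or being partially encoded. A self-contained sketch of that behavior, with a hypothetical `encode_labels` helper and `char_to_index` dict standing in for the converter method and its `self.dict`:

```python
import torch

def encode_labels(labels, char_to_index, batch_max_length=25):
    """Encode text labels to index tensors, dropping any label that
    contains a character missing from the charset (row stays zero)."""
    lengths = [len(s) for s in labels]
    batch_text = torch.LongTensor(len(labels), batch_max_length).fill_(0)
    for i, label in enumerate(labels):
        if all(ch in char_to_index for ch in label):
            indices = [char_to_index[ch] for ch in label]
            batch_text[i][:len(indices)] = torch.LongTensor(indices)
        # else: row i stays zero, effectively dropping this sample's label
    return batch_text, torch.IntTensor(lengths)

# Example: 'b' is not in the charset, so the second label is zeroed out.
charset = {'a': 1, 'c': 2}
text, length = encode_labels(['aca', 'ab'], charset)
print(text)    # row 0 -> [1, 2, 1, 0, ...], row 1 -> all zeros
print(length)  # tensor([3, 2], dtype=torch.int32)
```

Note that, as in the patch, the returned lengths still reflect the original label lengths; only the index rows are zeroed.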