-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_ar.py
More file actions
74 lines (55 loc) · 2.1 KB
/
train_ar.py
File metadata and controls
74 lines (55 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import pytorch_lightning as pl
from argparse import ArgumentParser
from pytorch_lightning import Trainer
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from models import GInterface, MInterface
from dataset.get_datasets import build_dataloader,build_dataloader_var
from Utils.base_utils import load_model_path_by_config, ConfigLoader
import logging
from eval import eval
logger = logging.getLogger()
def load_callbacks(config):
    """Assemble the Lightning callbacks for VAR training.

    Always checkpoints the single best model by validation FID (lower is
    better); additionally attaches a per-step learning-rate monitor when the
    config enables an LR scheduler.
    """
    best_fid_ckpt = plc.ModelCheckpoint(
        monitor='fid',
        filename='{epoch:02d}-{fid:.4f}',
        save_top_k=1,
        mode='min',
    )
    cbs = [best_fid_ckpt]
    if config.lr_scheduler:
        cbs.append(plc.LearningRateMonitor(logging_interval='step'))
    return cbs
def main():
    """Train the autoregressive (VAR) generator on top of a frozen DualVQVAE.

    Loads the dataset-specific YAML config, restores the pretrained VQ-VAE
    checkpoint, freezes its weights, then fits the generative model with
    PyTorch Lightning, logging to TensorBoard under ``log/var_<data>``.
    """
    data = "stock"
    config = ConfigLoader.load_var_config(config=f"configs/train_var_{data}.yaml")
    # Renamed from `logger`: the original local shadowed the module-level
    # logging.Logger with a TensorBoardLogger, which is confusing to readers.
    tb_logger = TensorBoardLogger(save_dir="log/", name=f"var_{data}")
    pl.seed_everything(config.seed)

    vqvae_path = load_model_path_by_config(config)
    print('='*30 + f'{vqvae_path}' + '='*30)

    train_dataloader, val_dataloader = build_dataloader_var(config, data)

    # Restore the pretrained tokenizer and freeze it: only the AR model trains.
    vqname = "DualVQVAE"
    vqvae = MInterface(vqname, "mse", **config.vqvae_args)
    # map_location keeps tensors on CPU regardless of the device they were saved from.
    vq_checkpoint = torch.load(vqvae_path, map_location=lambda storage, loc: storage)
    vqvae.load_state_dict(vq_checkpoint['state_dict'], strict=True)
    vqvae.model.requires_grad_(False)

    genmodel = GInterface(vqvae.model, config.vqvae_args.v_patch_nums, data, "var", vqname, **config)
    if config.resume:
        # Optionally resume the generator itself from an earlier run.
        ckpt = torch.load(config.var_path.load_dir)
        genmodel.load_state_dict(ckpt['state_dict'], strict=True)
        print(f'Successfully load the model! {config.var_path.load_dir}')

    trainer = Trainer(
        devices=config.devices,
        max_epochs=900,
        strategy='auto',
        log_every_n_steps=1,
        callbacks=load_callbacks(config),
        check_val_every_n_epoch=50,  # validation (and FID checkpointing) every 50 epochs
        logger=tb_logger,
    )
    trainer.fit(genmodel, train_dataloader, val_dataloader)


if __name__ == '__main__':
    main()