-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodel_def.py
70 lines (59 loc) · 2.3 KB
/
model_def.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import tensorflow as tf
from data import download, load_dataset
from pix2pix import Pix2Pix, make_discriminator_optimizer, make_generator_optimizer
from determined.keras import InputData, TFKerasTrial, TFKerasTrialContext
class Pix2PixTrial(TFKerasTrial):
    """Determined ``TFKerasTrial`` that trains a Pix2Pix GAN.

    Reads dataset location/shape settings from the experiment's data config
    and optimizer settings from the hyperparameters, downloads the dataset
    once at construction, and exposes wrapped training/validation pipelines.
    """

    def __init__(self, context: TFKerasTrialContext) -> None:
        self.context = context
        # Download the dataset up front so both data loaders can reuse the
        # same local path.
        data_cfg = self.context.get_data_config()
        self.path = download(data_cfg["base"], data_cfg["dataset"])

    def build_model(self) -> tf.keras.models.Model:
        """Build, wrap, and compile the Pix2Pix model.

        The model and both optimizers are wrapped through the Determined
        context (required for distributed training), and the generator and
        discriminator each get their own Adam-style optimizer configured
        from hyperparameters.
        """
        wrapped = self.context.wrap_model(Pix2Pix())

        hp = self.context.get_hparam
        gen_opt = self.context.wrap_optimizer(
            make_generator_optimizer(
                lr=hp("generator_lr"),
                beta_1=hp("generator_beta_1"),
            )
        )
        disc_opt = self.context.wrap_optimizer(
            make_discriminator_optimizer(
                lr=hp("discriminator_lr"),
                beta_1=hp("discriminator_beta_1"),
            )
        )

        wrapped.compile(
            discriminator_optimizer=disc_opt,
            generator_optimizer=gen_opt,
        )
        return wrapped

    def _get_wrapped_dataset(self, split) -> InputData:
        """Load one dataset split and wrap it for distributed use.

        `split` is forwarded to ``load_dataset`` (e.g. "train" or "test");
        image height/width come from the data config, jitter/mirror
        augmentation flags from the hyperparameters.
        """
        data_cfg = self.context.get_data_config()
        dataset = load_dataset(
            self.path,
            data_cfg["height"],
            data_cfg["width"],
            split,
            self.context.get_hparam("jitter"),
            self.context.get_hparam("mirror"),
        )
        return self.context.wrap_dataset(dataset)

    def build_training_data_loader(self) -> InputData:
        """Return the training pipeline: cache -> shuffle -> batch -> repeat -> prefetch."""
        ds = self._get_wrapped_dataset("train")
        ds = ds.cache()
        # NOTE(review): `.get("BUFFER_SIZE")` yields None if the key is
        # absent — presumably the config always provides it.
        ds = ds.shuffle(self.context.get_data_config().get("BUFFER_SIZE"))
        ds = ds.batch(self.context.get_per_slot_batch_size())
        ds = ds.repeat()
        return ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    def build_validation_data_loader(self) -> InputData:
        """Return the test split, batched at the per-slot batch size (no shuffling)."""
        batch_size = self.context.get_per_slot_batch_size()
        return self._get_wrapped_dataset("test").batch(batch_size)