Skip to content

Commit 1ecfaef

Browse files
ADD: optimizer state can now be saved & reinstantiated (#279)
1 parent 6da27aa commit 1ecfaef

File tree

3 files changed

+343
-7
lines changed

3 files changed

+343
-7
lines changed

choice_learn/models/base_model.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def batch_predict(
516516
}
517517
return batch_loss, probabilities
518518

519-
def save_model(self, path):
519+
def save_model(self, path, save_opt=True):
520520
"""Save the different models on disk.
521521
522522
Parameters
@@ -538,16 +538,34 @@ def save_model(self, path):
538538
elif isinstance(v, (list, tuple)):
539539
if all(isinstance(item, (int, float, str, dict)) for item in v):
540540
params[k] = v
541-
else:
541+
elif k != "_trainable_weights":
542542
logging.warning(
543543
"""Attribute '%s' is a list with non-serializable
544544
types and will not be saved.""",
545545
k,
546546
)
547-
with open(os.path.join(path, "params.json"), "w") as f:
547+
with open(Path(path) / "params.json", "w") as f:
548548
json.dump(params, f)
549549

550550
# Save optimizer state
551+
if save_opt and not isinstance(self.optimizer, str):
552+
(Path(path) / "optimizer").mkdir(parents=True, exist_ok=True)
553+
config = self.optimizer.get_config()
554+
weights_store = {}
555+
self.optimizer.save_own_variables(weights_store)
556+
for key, value in weights_store.items():
557+
if isinstance(value, tf.Variable):
558+
value = value.numpy()
559+
weights_store[key] = value.tolist()
560+
if "learning_rate" in config.keys():
561+
if isinstance(config["learning_rate"], tf.Variable):
562+
config["learning_rate"] = config["learning_rate"].numpy()
563+
if isinstance(config["learning_rate"], np.float32):
564+
config["learning_rate"] = config["learning_rate"].tolist()
565+
with open(Path(path) / "optimizer" / "config.json", "w") as f:
566+
json.dump(config, f)
567+
with open(Path(path) / "optimizer" / "weights_store.json", "w") as f:
568+
json.dump(weights_store, f)
551569

552570
@classmethod
553571
def load_model(cls, path):
@@ -563,7 +581,11 @@ def load_model(cls, path):
563581
ChoiceModel
564582
Loaded ChoiceModel
565583
"""
566-
obj = cls()
584+
# To improve for non string attributes
585+
with open(Path(path) / "params.json") as f:
586+
params = json.load(f)
587+
588+
obj = cls(optimizer=params["optimizer_name"])
567589
obj._trainable_weights = []
568590

569591
i = 0
@@ -576,11 +598,22 @@ def load_model(cls, path):
576598
i += 1
577599
weight_path = f"weight_{i}.npy"
578600

579-
# To improve for non string attributes
580-
params = json.load(open(Path(path) / "params.json"))
581601
for k, v in params.items():
582602
setattr(obj, k, v)
583603

604+
if Path.is_dir(Path(path) / "optimizer"):
605+
with open(Path(path) / "optimizer" / "config.json") as f:
606+
config = json.load(f)
607+
# obj.optimizer = tf.keras.optimizers.get(params["optimizer_name"]).from_config(config)
608+
obj.optimizer = obj.optimizer.from_config(config)
609+
obj.optimizer.build(var_list=obj.trainable_weights)
610+
611+
with open(Path(path) / "optimizer" / "weights_store.json") as f:
612+
store = json.load(f)
613+
for key, value in store.items():
614+
store[key] = np.array(value, dtype=np.float32)
615+
obj.optimizer.load_own_variables(store)
616+
584617
# Load optimizer step
585618
return obj
586619

notebooks/auxiliary_tools/assortment_example.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@
961961
"name": "python",
962962
"nbconvert_exporter": "python",
963963
"pygments_lexer": "ipython3",
964-
"version": "3.8.18"
964+
"version": "3.11.4"
965965
}
966966
},
967967
"nbformat": 4,
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "0",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"# Install necessary requirements\n",
11+
"\n",
12+
"# If you run this notebook on Google Colab, or in standalone mode, you need to install the required packages.\n",
13+
"# Uncomment the following lines:\n",
14+
"\n",
15+
"# !pip install choice-learn\n",
16+
"\n",
17+
"# If you run the notebook within the GitHub repository, you need to run the following lines, that can skipped otherwise:\n",
18+
"import os\n",
19+
"import sys\n",
20+
"\n",
21+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
22+
"sys.path.append(\"../../\")"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": null,
28+
"id": "1",
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"\n",
33+
"import numpy as np\n",
34+
"import tensorflow as tf\n",
35+
"\n",
36+
"# Enabling eager execution sometimes decreases fitting time\n",
37+
"tf.compat.v1.enable_eager_execution()"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": null,
43+
"id": "2",
44+
"metadata": {},
45+
"outputs": [],
46+
"source": [
47+
"from choice_learn.models import ConditionalLogit"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"id": "3",
54+
"metadata": {},
55+
"outputs": [],
56+
"source": [
57+
"from choice_learn.datasets import load_swissmetro\n",
58+
"\n",
59+
"swiss_dataset = load_swissmetro(preprocessing=\"tutorial\")\n",
60+
"print(swiss_dataset.summary())"
61+
]
62+
},
63+
{
64+
"cell_type": "code",
65+
"execution_count": null,
66+
"id": "4",
67+
"metadata": {},
68+
"outputs": [],
69+
"source": [
70+
"# Initialization of the model\n",
71+
"swiss_model = ConditionalLogit(optimizer=\"Adam\", epochs=25, lr=0.01)\n",
72+
"\n",
73+
"# Intercept for train & sm\n",
74+
"swiss_model.add_coefficients(feature_name=\"intercept\", items_indexes=[0, 1])\n",
75+
"# beta_he for train & sm\n",
76+
"swiss_model.add_coefficients(feature_name=\"headway\",\n",
77+
" items_indexes=[0, 1],\n",
78+
" coefficient_name=\"beta_he\")\n",
79+
"# beta_co for all items\n",
80+
"swiss_model.add_coefficients(feature_name=\"cost\",\n",
81+
" items_indexes=[0, 1, 2])\n",
82+
"# beta first_class for train\n",
83+
"swiss_model.add_coefficients(feature_name=\"regular_class\",\n",
84+
" items_indexes=[0])\n",
85+
"# beta seats for train\n",
86+
"swiss_model.add_coefficients(feature_name=\"seats\", items_indexes=[1])\n",
87+
"# betas luggage for car\n",
88+
"swiss_model.add_coefficients(feature_name=\"single_luggage_piece\",\n",
89+
" items_indexes=[2],\n",
90+
" coefficient_name=\"beta_luggage=1\")\n",
91+
"swiss_model.add_coefficients(feature_name=\"multiple_luggage_piece\",\n",
92+
" items_indexes=[2],\n",
93+
" coefficient_name=\"beta_luggage>1\")\n",
94+
"# beta TT only for car\n",
95+
"swiss_model.add_coefficients(feature_name=\"travel_time\",\n",
96+
" items_indexes=[2],\n",
97+
" coefficient_name=\"beta_tt_car\")\n",
98+
"\n",
99+
"# betas TT and HE shared by train and sm\n",
100+
"swiss_model.add_shared_coefficient(feature_name=\"travel_time\",\n",
101+
" items_indexes=[0, 1])\n",
102+
"swiss_model.add_shared_coefficient(feature_name=\"train_survey\",\n",
103+
" items_indexes=[0, 1],\n",
104+
" coefficient_name=\"beta_survey\")\n"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": null,
110+
"id": "5",
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"# Estimation of the model\n",
115+
"history = swiss_model.fit(swiss_dataset, get_report=False)"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"id": "6",
122+
"metadata": {},
123+
"outputs": [],
124+
"source": [
125+
"isinstance(swiss_model.optimizer.get_config()[\"learning_rate\"], np.float32), isinstance(swiss_model.optimizer.get_config()[\"learning_rate\"], np.ndarray)"
126+
]
127+
},
128+
{
129+
"cell_type": "code",
130+
"execution_count": null,
131+
"id": "7",
132+
"metadata": {},
133+
"outputs": [],
134+
"source": [
135+
"swiss_model.save_model(\"test_save\")"
136+
]
137+
},
138+
{
139+
"cell_type": "code",
140+
"execution_count": null,
141+
"id": "8",
142+
"metadata": {},
143+
"outputs": [],
144+
"source": [
145+
"swiss_model2 = ConditionalLogit.load_model(\"test_save\")"
146+
]
147+
},
148+
{
149+
"cell_type": "code",
150+
"execution_count": null,
151+
"id": "9",
152+
"metadata": {},
153+
"outputs": [],
154+
"source": [
155+
"hist = swiss_model2.fit(swiss_dataset)"
156+
]
157+
},
158+
{
159+
"cell_type": "code",
160+
"execution_count": null,
161+
"id": "10",
162+
"metadata": {},
163+
"outputs": [],
164+
"source": [
165+
"import shutil\n",
166+
"\n",
167+
"shutil.rmtree(\"test_save\")"
168+
]
169+
},
170+
{
171+
"cell_type": "markdown",
172+
"id": "11",
173+
"metadata": {},
174+
"source": [
175+
"## Save every n epochs with a custom tf.Callback"
176+
]
177+
},
178+
{
179+
"cell_type": "code",
180+
"execution_count": null,
181+
"id": "12",
182+
"metadata": {},
183+
"outputs": [],
184+
"source": [
185+
"class SaveCallback(tf.keras.callbacks.Callback):\n",
186+
" \"\"\"Callback to save regularly the model during training.\"\"\"\n",
187+
"\n",
188+
" def __init__(self, base_dir, save_every_n, *args, **kwargs):\n",
189+
" \"\"\"Instantiate callback.\"\"\"\n",
190+
" self.base_dir = base_dir\n",
191+
" self.save_every_n = save_every_n\n",
192+
" super().__init__(*args, **kwargs)\n",
193+
"\n",
194+
" def on_epoch_end(self, epoch, logs=None):\n",
195+
" \"\"\"Define saving at the end of each epoch.\"\"\"\n",
196+
" _ = logs\n",
197+
" if (epoch + 1) % self.save_every_n == 0:\n",
198+
" self._save_model(epoch=epoch)\n",
199+
"\n",
200+
" def _save_model(self, epoch):\n",
201+
" \"\"\"Handle model saving internally.\"\"\"\n",
202+
" dirname = os.path.join(self.base_dir, f\"epoch_{epoch}\")\n",
203+
" self.model.save_model(dirname)"
204+
]
205+
},
206+
{
207+
"cell_type": "code",
208+
"execution_count": null,
209+
"id": "13",
210+
"metadata": {},
211+
"outputs": [],
212+
"source": [
213+
"# Initialization of the model\n",
214+
"swiss_model = ConditionalLogit(optimizer=\"Adam\", epochs=25, lr=0.01, callbacks=[SaveCallback(base_dir=\"test_save_cb\", save_every_n=2)])\n",
215+
"\n",
216+
"# Intercept for train & sm\n",
217+
"swiss_model.add_coefficients(feature_name=\"intercept\", items_indexes=[0, 1])\n",
218+
"# beta_he for train & sm\n",
219+
"swiss_model.add_coefficients(feature_name=\"headway\",\n",
220+
" items_indexes=[0, 1],\n",
221+
" coefficient_name=\"beta_he\")\n",
222+
"# beta_co for all items\n",
223+
"swiss_model.add_coefficients(feature_name=\"cost\",\n",
224+
" items_indexes=[0, 1, 2])\n",
225+
"# beta first_class for train\n",
226+
"swiss_model.add_coefficients(feature_name=\"regular_class\",\n",
227+
" items_indexes=[0])\n",
228+
"# beta seats for train\n",
229+
"swiss_model.add_coefficients(feature_name=\"seats\", items_indexes=[1])\n",
230+
"# betas luggage for car\n",
231+
"swiss_model.add_coefficients(feature_name=\"single_luggage_piece\",\n",
232+
" items_indexes=[2],\n",
233+
" coefficient_name=\"beta_luggage=1\")\n",
234+
"swiss_model.add_coefficients(feature_name=\"multiple_luggage_piece\",\n",
235+
" items_indexes=[2],\n",
236+
" coefficient_name=\"beta_luggage>1\")\n",
237+
"# beta TT only for car\n",
238+
"swiss_model.add_coefficients(feature_name=\"travel_time\",\n",
239+
" items_indexes=[2],\n",
240+
" coefficient_name=\"beta_tt_car\")\n",
241+
"\n",
242+
"# betas TT and HE shared by train and sm\n",
243+
"swiss_model.add_shared_coefficient(feature_name=\"travel_time\",\n",
244+
" items_indexes=[0, 1])\n",
245+
"swiss_model.add_shared_coefficient(feature_name=\"train_survey\",\n",
246+
" items_indexes=[0, 1],\n",
247+
" coefficient_name=\"beta_survey\")\n"
248+
]
249+
},
250+
{
251+
"cell_type": "code",
252+
"execution_count": null,
253+
"id": "14",
254+
"metadata": {},
255+
"outputs": [],
256+
"source": [
257+
"\n",
258+
"# Estimation of the model\n",
259+
"history = swiss_model.fit(swiss_dataset, get_report=True)"
260+
]
261+
},
262+
{
263+
"cell_type": "code",
264+
"execution_count": null,
265+
"id": "15",
266+
"metadata": {},
267+
"outputs": [],
268+
"source": [
269+
"# remove\n",
270+
"shutil.rmtree(\"test_save_cb\")"
271+
]
272+
},
273+
{
274+
"cell_type": "code",
275+
"execution_count": null,
276+
"id": "16",
277+
"metadata": {},
278+
"outputs": [],
279+
"source": []
280+
}
281+
],
282+
"metadata": {
283+
"kernelspec": {
284+
"display_name": "tf_env",
285+
"language": "python",
286+
"name": "python3"
287+
},
288+
"language_info": {
289+
"codemirror_mode": {
290+
"name": "ipython",
291+
"version": 3
292+
},
293+
"file_extension": ".py",
294+
"mimetype": "text/x-python",
295+
"name": "python",
296+
"nbconvert_exporter": "python",
297+
"pygments_lexer": "ipython3",
298+
"version": "3.11.4"
299+
}
300+
},
301+
"nbformat": 4,
302+
"nbformat_minor": 5
303+
}

0 commit comments

Comments
 (0)