Skip to content

Commit c0f6824

Browse files
update & fix notebook with recent changes (#276)
1 parent 1ecfaef commit c0f6824

File tree

1 file changed

+31
-31
lines changed

1 file changed

+31
-31
lines changed

notebooks/auxiliary_tools/assortment_example.ipynb

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,18 @@
2929
"source": [
3030
"# Install necessary requirements\n",
3131
"\n",
32-
"# If you run this notebook on Google Colab, or in standalone mode, you need to install the required packages.\n",
33-
"# Uncomment the following lines:\n",
32+
"# If you run this notebook on Google Colab, or in standalone mode, you need to install\n",
33+
"# the required packages. Just uncomment the following lines:\n",
3434
"\n",
3535
"# !pip install choice-learn\n",
3636
"\n",
37-
"# If you run the notebook within the GitHub repository, you need to run the following lines, that can skipped otherwise:\n",
37+
"# If you run the notebook within the GitHub repository, you need to run the following lines,\n",
38+
"# that can skipped otherwise:\n",
3839
"import os\n",
3940
"import sys\n",
4041
"\n",
41-
"sys.path.append(\"../../\")"
42-
]
43-
},
44-
{
45-
"cell_type": "code",
46-
"execution_count": null,
47-
"metadata": {},
48-
"outputs": [],
49-
"source": [
50-
"# Importing the right base libraries\n",
51-
"import os\n",
42+
"sys.path.append(\"../../\")\n",
43+
"\n",
5244
"# Remove GPU use\n",
5345
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
5446
"\n",
@@ -277,36 +269,44 @@
277269
"outputs": [],
278270
"source": [
279271
"import tensorflow as tf\n",
272+
"\n",
280273
"from choice_learn.models.base_model import ChoiceModel\n",
281274
"\n",
282275
"\n",
283276
"class TaFengMNL(ChoiceModel):\n",
284277
" \"\"\"Custom model for the TaFeng dataset.\"\"\"\n",
285278
"\n",
286279
" def __init__(self, **kwargs):\n",
287-
" \"\"\"Instantiation of our custom model.\"\"\"\n",
280+
" \"\"\"Instantiate of our custom model.\"\"\"\n",
288281
" # Standard inheritance stuff\n",
289282
" super().__init__(**kwargs)\n",
290-
"\n",
291-
" # Instantiation of base utilties weights\n",
292-
" # We have 25 items in the dataset making 25 weights\n",
293-
" self.base_utilities = tf.Variable(\n",
294-
" tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(1, 25))\n",
295-
" )\n",
296-
" # Instantiation of price elasticities weights\n",
297-
" # We have 3 age categories making 3 weights\n",
298-
" self.price_elasticities = tf.Variable(\n",
299-
" tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(1, 3))\n",
300-
" )\n",
301-
" # Don't forget to add the weights to be optimized in self.weights !\n",
302-
" self.trainable_weights = [self.base_utilities, self.price_elasticities]\n",
283+
" # Directly initialize weights in _trainable_weights\n",
284+
" self._trainable_weights = [\n",
285+
" tf.Variable(tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(1, 25))),\n",
286+
" tf.Variable(tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(1, 3)))\n",
287+
" ]\n",
288+
"\n",
289+
" @property\n",
290+
" def trainable_weights(self):\n",
291+
" \"\"\"Return all the models trainable weights.\"\"\"\n",
292+
" return self._trainable_weights\n",
293+
"\n",
294+
" @property\n",
295+
" def base_utilities(self):\n",
296+
" \"\"\"Return itemwise utilities.\"\"\"\n",
297+
" return self._trainable_weights[0]\n",
298+
"\n",
299+
" @property\n",
300+
" def price_elasticities(self):\n",
301+
" \"\"\"Return the price elasticities.\"\"\"\n",
302+
" return self._trainable_weights[1]\n",
303303
"\n",
304304
" def compute_batch_utility(self,\n",
305305
" shared_features_by_choice,\n",
306306
" items_features_by_choice,\n",
307307
" available_items_by_choice,\n",
308308
" choices):\n",
309-
" \"\"\"Method that defines how the model computes the utility of a product.\n",
309+
" \"\"\"Define how the model computes the utility of a product.\n",
310310
"\n",
311311
" Parameters\n",
312312
" ----------\n",
@@ -323,7 +323,7 @@
323323
" Choices\n",
324324
" Shape must be (n_choices, )\n",
325325
"\n",
326-
" Returns:\n",
326+
" Return:\n",
327327
" --------\n",
328328
" np.ndarray\n",
329329
" Utility of each product for each choice.\n",
@@ -375,7 +375,7 @@
375375
}
376376
],
377377
"source": [
378-
"model = TaFengMNL(optimizer=\"lbfgs\", epochs=1000, tolerance=1e-4)\n",
378+
"model = TaFengMNL(optimizer=\"lbfgs\", epochs=1000, lbfgs_tolerance=1e-4)\n",
379379
"history = model.fit(dataset, verbose=1)"
380380
]
381381
},

0 commit comments

Comments
 (0)