|
29 | 29 | "source": [ |
30 | 30 | "# Install necessary requirements\n", |
31 | 31 | "\n", |
32 | | - "# If you run this notebook on Google Colab, or in standalone mode, you need to install the required packages.\n", |
33 | | - "# Uncomment the following lines:\n", |
| 32 | + "# If you run this notebook on Google Colab, or in standalone mode, you need to install\n", |
| 33 | + "# the required packages. Just uncomment the following lines:\n", |
34 | 34 | "\n", |
35 | 35 | "# !pip install choice-learn\n", |
36 | 36 | "\n", |
37 | | - "# If you run the notebook within the GitHub repository, you need to run the following lines, that can skipped otherwise:\n", |
| 37 | + "# If you run the notebook within the GitHub repository, you need to run the following lines,\n", |
| 38 | + "# that can skipped otherwise:\n", |
38 | 39 | "import os\n", |
39 | 40 | "import sys\n", |
40 | 41 | "\n", |
41 | | - "sys.path.append(\"../../\")" |
42 | | - ] |
43 | | - }, |
44 | | - { |
45 | | - "cell_type": "code", |
46 | | - "execution_count": null, |
47 | | - "metadata": {}, |
48 | | - "outputs": [], |
49 | | - "source": [ |
50 | | - "# Importing the right base libraries\n", |
51 | | - "import os\n", |
| 42 | + "sys.path.append(\"../../\")\n", |
| 43 | + "\n", |
52 | 44 | "# Remove GPU use\n", |
53 | 45 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n", |
54 | 46 | "\n", |
|
277 | 269 | "outputs": [], |
278 | 270 | "source": [ |
279 | 271 | "import tensorflow as tf\n", |
| 272 | + "\n", |
280 | 273 | "from choice_learn.models.base_model import ChoiceModel\n", |
281 | 274 | "\n", |
282 | 275 | "\n", |
283 | 276 | "class TaFengMNL(ChoiceModel):\n", |
284 | 277 | " \"\"\"Custom model for the TaFeng dataset.\"\"\"\n", |
285 | 278 | "\n", |
286 | 279 | " def __init__(self, **kwargs):\n", |
287 | | - " \"\"\"Instantiation of our custom model.\"\"\"\n", |
| 280 | + " \"\"\"Instantiate of our custom model.\"\"\"\n", |
288 | 281 | " # Standard inheritance stuff\n", |
289 | 282 | " super().__init__(**kwargs)\n", |
290 | | - "\n", |
291 | | - " # Instantiation of base utilties weights\n", |
292 | | - " # We have 25 items in the dataset making 25 weights\n", |
293 | | - " self.base_utilities = tf.Variable(\n", |
294 | | - " tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(1, 25))\n", |
295 | | - " )\n", |
296 | | - " # Instantiation of price elasticities weights\n", |
297 | | - " # We have 3 age categories making 3 weights\n", |
298 | | - " self.price_elasticities = tf.Variable(\n", |
299 | | - " tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(1, 3))\n", |
300 | | - " )\n", |
301 | | - " # Don't forget to add the weights to be optimized in self.weights !\n", |
302 | | - " self.trainable_weights = [self.base_utilities, self.price_elasticities]\n", |
| 283 | + " # Directly initialize weights in _trainable_weights\n", |
| 284 | + " self._trainable_weights = [\n", |
| 285 | + " tf.Variable(tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(1, 25))),\n", |
| 286 | + " tf.Variable(tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(1, 3)))\n", |
| 287 | + " ]\n", |
| 288 | + "\n", |
| 289 | + " @property\n", |
| 290 | + " def trainable_weights(self):\n", |
| 291 | + " \"\"\"Return all the models trainable weights.\"\"\"\n", |
| 292 | + " return self._trainable_weights\n", |
| 293 | + "\n", |
| 294 | + " @property\n", |
| 295 | + " def base_utilities(self):\n", |
| 296 | + " \"\"\"Return itemwise utilities.\"\"\"\n", |
| 297 | + " return self._trainable_weights[0]\n", |
| 298 | + "\n", |
| 299 | + " @property\n", |
| 300 | + " def price_elasticities(self):\n", |
| 301 | + " \"\"\"Return the price elasticities.\"\"\"\n", |
| 302 | + " return self._trainable_weights[1]\n", |
303 | 303 | "\n", |
304 | 304 | " def compute_batch_utility(self,\n", |
305 | 305 | " shared_features_by_choice,\n", |
306 | 306 | " items_features_by_choice,\n", |
307 | 307 | " available_items_by_choice,\n", |
308 | 308 | " choices):\n", |
309 | | - " \"\"\"Method that defines how the model computes the utility of a product.\n", |
| 309 | + " \"\"\"Define how the model computes the utility of a product.\n", |
310 | 310 | "\n", |
311 | 311 | " Parameters\n", |
312 | 312 | " ----------\n", |
|
323 | 323 | " Choices\n", |
324 | 324 | " Shape must be (n_choices, )\n", |
325 | 325 | "\n", |
326 | | - " Returns:\n", |
| 326 | + " Return:\n", |
327 | 327 | " --------\n", |
328 | 328 | " np.ndarray\n", |
329 | 329 | " Utility of each product for each choice.\n", |
|
375 | 375 | } |
376 | 376 | ], |
377 | 377 | "source": [ |
378 | | - "model = TaFengMNL(optimizer=\"lbfgs\", epochs=1000, tolerance=1e-4)\n", |
| 378 | + "model = TaFengMNL(optimizer=\"lbfgs\", epochs=1000, lbfgs_tolerance=1e-4)\n", |
379 | 379 | "history = model.fit(dataset, verbose=1)" |
380 | 380 | ] |
381 | 381 | }, |
|
0 commit comments