|
22 | 22 | "cell_type": "markdown", |
23 | 23 | "metadata": {}, |
24 | 24 | "source": [ |
25 | | - "## Imports" |
| 25 | + "### <font color='289C4E'>Table of contents<font><a class='anchor' id='top'></a>\n", |
| 26 | + " [1. Imports](##sec1)\n", |
| 27 | + "\n", |
| 28 | + " [2. Configurations and utilities](##sec2)\n", |
| 29 | + "\n", |
| 30 | + " [3. Defining the lifting](##sec2)\n", |
| 31 | + "\n", |
| 32 | + " [4. Loading the data](##sec3)\n", |
| 33 | + "\n", |
| 34 | + " [5. Model initialization](##sec4)\n", |
| 35 | + "\n", |
| 36 | + " [6. Training](##sec5)\n", |
| 37 | + "\n", |
| 38 | + " [7. Testing the model](##sec6)" |
| 39 | + ] |
| 40 | + }, |
| 41 | + { |
| 42 | + "cell_type": "markdown", |
| 43 | + "metadata": {}, |
| 44 | + "source": [ |
| 45 | + "## 1. Imports <a class=\"anchor\" id=\"sec1\"></a>" |
26 | 46 | ] |
27 | 47 | }, |
28 | 48 | { |
|
60 | 80 | "cell_type": "markdown", |
61 | 81 | "metadata": {}, |
62 | 82 | "source": [ |
63 | | - "## Configurations and utilities" |
| 83 | + "## 2. Configurations and utilities <a class=\"anchor\" id=\"sec2\"></a>" |
64 | 84 | ] |
65 | 85 | }, |
66 | 86 | { |
|
152 | 172 | "cell_type": "markdown", |
153 | 173 | "metadata": {}, |
154 | 174 | "source": [ |
155 | | - "## Defining the lifting" |
| 175 | + "## 3. Defining the lifting <a class=\"anchor\" id=\"sec3\"></a>" |
156 | 176 | ] |
157 | 177 | }, |
158 | 178 | { |
|
206 | 226 | "cell_type": "markdown", |
207 | 227 | "metadata": {}, |
208 | 228 | "source": [ |
209 | | - "## Loading the data" |
| 229 | + "## 4. Loading the data <a class=\"anchor\" id=\"sec4\"></a>" |
210 | 230 | ] |
211 | 231 | }, |
212 | 232 | { |
|
233 | 253 | "metadata": {}, |
234 | 254 | "outputs": [ |
235 | 255 | { |
236 | | - "name": "stderr", |
| 256 | + "name": "stdout", |
237 | 257 | "output_type": "stream", |
238 | 258 | "text": [ |
239 | | - "Processing...\n", |
240 | | - "Done!\n" |
| 259 | + "Transform parameters are the same, using existing data_dir: ./data/MUTAG/MUTAG/clique_lifting/458544608\n" |
241 | 260 | ] |
242 | 261 | } |
243 | 262 | ], |
|
255 | 274 | "cell_type": "markdown", |
256 | 275 | "metadata": {}, |
257 | 276 | "source": [ |
258 | | - "## Model initialization" |
| 277 | + "## 5. Model initialization <a class=\"anchor\" id=\"sec5\"></a>" |
259 | 278 | ] |
260 | 279 | }, |
261 | 280 | { |
|
304 | 323 | "cell_type": "markdown", |
305 | 324 | "metadata": {}, |
306 | 325 | "source": [ |
307 | | - "## Training" |
| 326 | + "## 6. Training <a class=\"anchor\" id=\"sec6\"></a>" |
308 | 327 | ] |
309 | 328 | }, |
310 | 329 | { |
|
316 | 335 | }, |
317 | 336 | { |
318 | 337 | "cell_type": "code", |
319 | | - "execution_count": 10, |
| 338 | + "execution_count": 9, |
320 | 339 | "metadata": {}, |
321 | 340 | "outputs": [ |
322 | 341 | { |
323 | 342 | "name": "stderr", |
324 | 343 | "output_type": "stream", |
325 | 344 | "text": [ |
326 | | - "GPU available: True (mps), used: False\n", |
| 345 | + "GPU available: True (cuda), used: False\n", |
327 | 346 | "TPU available: False, using: 0 TPU cores\n", |
328 | 347 | "IPU available: False, using: 0 IPUs\n", |
329 | 348 | "HPU available: False, using: 0 HPUs\n", |
|
344 | 363 | } |
345 | 364 | ], |
346 | 365 | "source": [ |
| 366 | + "%%capture\n", |
347 | 367 | "# Increase the number of epochs to get better results\n", |
348 | 368 | "trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False)\n", |
349 | 369 | "\n", |
|
353 | 373 | }, |
354 | 374 | { |
355 | 375 | "cell_type": "code", |
356 | | - "execution_count": 11, |
| 376 | + "execution_count": 10, |
357 | 377 | "metadata": {}, |
358 | 378 | "outputs": [ |
359 | 379 | { |
360 | 380 | "name": "stdout", |
361 | 381 | "output_type": "stream", |
362 | 382 | "text": [ |
363 | | - "train/accuracy : 0.7573964595794678\n", |
364 | | - "train/precision : 0.737596869468689\n", |
365 | | - "train/recall : 0.6920425891876221\n", |
366 | | - "val/loss : 0.6638099551200867\n", |
367 | | - "val/accuracy : 0.7894737124443054\n", |
368 | | - "val/precision : 0.7749999761581421\n", |
369 | | - "val/recall : 0.7115384340286255\n", |
370 | | - "train/loss : 0.5858408808708191\n" |
| 383 | + " Training metrics\n", |
| 384 | + " --------------------------\n", |
| 385 | + "train/accuracy: 0.7633\n", |
| 386 | + "train/precision: 0.7352\n", |
| 387 | + "train/recall: 0.7310\n", |
| 388 | + "val/loss: 0.7276\n", |
| 389 | + "val/accuracy: 0.7895\n", |
| 390 | + "val/precision: 0.7750\n", |
| 391 | + "val/recall: 0.7115\n", |
| 392 | + "train/loss: 0.7212\n" |
371 | 393 | ] |
372 | 394 | } |
373 | 395 | ], |
374 | 396 | "source": [ |
| 397 | + "print(' Training metrics\\n', '-'*26)\n", |
375 | 398 | "for key in train_metrics:\n", |
376 | | - " print(key,\": \", train_metrics[key].item())" |
| 399 | + " print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))" |
377 | 400 | ] |
378 | 401 | }, |
379 | 402 | { |
380 | 403 | "cell_type": "markdown", |
381 | 404 | "metadata": {}, |
382 | 405 | "source": [ |
383 | | - "## Testing the model" |
| 406 | + "## 7. Testing the model <a class=\"anchor\" id=\"sec7\"></a>" |
384 | 407 | ] |
385 | 408 | }, |
386 | 409 | { |
|
390 | 413 | "Finally, we can test the model and obtain the results." |
391 | 414 | ] |
392 | 415 | }, |
| 416 | + { |
| 417 | + "cell_type": "code", |
| 418 | + "execution_count": 11, |
| 419 | + "metadata": {}, |
| 420 | + "outputs": [], |
| 421 | + "source": [ |
| 422 | + "%%capture\n", |
| 423 | + "trainer.test(model, datamodule)\n", |
| 424 | + "test_metrics = trainer.callback_metrics" |
| 425 | + ] |
| 426 | + }, |
393 | 427 | { |
394 | 428 | "cell_type": "code", |
395 | 429 | "execution_count": 12, |
|
399 | 433 | "name": "stdout", |
400 | 434 | "output_type": "stream", |
401 | 435 | "text": [ |
402 | | - "\n" |
403 | | - ] |
404 | | - }, |
405 | | - { |
406 | | - "name": "stderr", |
407 | | - "output_type": "stream", |
408 | | - "text": [ |
409 | | - "/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.\n" |
| 436 | + " Testing metrics\n", |
| 437 | + " -------------------------\n", |
| 438 | + "test/loss: 0.7276\n", |
| 439 | + "test/accuracy: 0.7895\n", |
| 440 | + "test/precision: 0.7750\n", |
| 441 | + "test/recall: 0.7115\n" |
410 | 442 | ] |
411 | | - }, |
412 | | - { |
413 | | - "data": { |
414 | | - "text/html": [ |
415 | | - "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
416 | | - "┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", |
417 | | - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
418 | | - "│<span style=\"color: #008080; text-decoration-color: #008080\"> test/accuracy </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7894737124443054 </span>│\n", |
419 | | - "│<span style=\"color: #008080; text-decoration-color: #008080\"> test/loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.6638099551200867 </span>│\n", |
420 | | - "│<span style=\"color: #008080; text-decoration-color: #008080\"> test/precision </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7749999761581421 </span>│\n", |
421 | | - "│<span style=\"color: #008080; text-decoration-color: #008080\"> test/recall </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7115384340286255 </span>│\n", |
422 | | - "└───────────────────────────┴───────────────────────────┘\n", |
423 | | - "</pre>\n" |
424 | | - ], |
425 | | - "text/plain": [ |
426 | | - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
427 | | - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", |
428 | | - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
429 | | - "│\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7894737124443054 \u001b[0m\u001b[35m \u001b[0m│\n", |
430 | | - "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6638099551200867 \u001b[0m\u001b[35m \u001b[0m│\n", |
431 | | - "│\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7749999761581421 \u001b[0m\u001b[35m \u001b[0m│\n", |
432 | | - "│\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7115384340286255 \u001b[0m\u001b[35m \u001b[0m│\n", |
433 | | - "└───────────────────────────┴───────────────────────────┘\n" |
434 | | - ] |
435 | | - }, |
436 | | - "metadata": {}, |
437 | | - "output_type": "display_data" |
438 | 443 | } |
439 | 444 | ], |
440 | 445 | "source": [ |
441 | | - "trainer.test(model, datamodule)\n", |
442 | | - "test_metrics = trainer.callback_metrics" |
| 446 | + "print(' Testing metrics\\n', '-'*25)\n", |
| 447 | + "for key in test_metrics:\n", |
| 448 | + " print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))" |
443 | 449 | ] |
444 | | - }, |
445 | | - { |
446 | | - "cell_type": "code", |
447 | | - "execution_count": null, |
448 | | - "metadata": {}, |
449 | | - "outputs": [], |
450 | | - "source": [] |
451 | 450 | } |
452 | 451 | ], |
453 | 452 | "metadata": { |
|
0 commit comments