|
22 | 22 | "cell_type": "markdown",
|
23 | 23 | "metadata": {},
|
24 | 24 | "source": [
|
25 |
| - "## Imports" |
| 25 | + "### <font color='289C4E'>Table of contents<font><a class='anchor' id='top'></a>\n", |
| 26 | + " [1. Imports](##sec1)\n", |
| 27 | + "\n", |
| 28 | + " [2. Configurations and utilities](##sec2)\n", |
| 29 | + "\n", |
| 30 | + " [3. Defining the lifting](##sec2)\n", |
| 31 | + "\n", |
| 32 | + " [4. Loading the data](##sec3)\n", |
| 33 | + "\n", |
| 34 | + " [5. Model initialization](##sec4)\n", |
| 35 | + "\n", |
| 36 | + " [6. Training](##sec5)\n", |
| 37 | + "\n", |
| 38 | + " [7. Testing the model](##sec6)" |
| 39 | + ] |
| 40 | + }, |
| 41 | + { |
| 42 | + "cell_type": "markdown", |
| 43 | + "metadata": {}, |
| 44 | + "source": [ |
| 45 | + "## 1. Imports <a class=\"anchor\" id=\"sec1\"></a>" |
26 | 46 | ]
|
27 | 47 | },
|
28 | 48 | {
|
|
60 | 80 | "cell_type": "markdown",
|
61 | 81 | "metadata": {},
|
62 | 82 | "source": [
|
63 |
| - "## Configurations and utilities" |
| 83 | + "## 2. Configurations and utilities <a class=\"anchor\" id=\"sec2\"></a>" |
64 | 84 | ]
|
65 | 85 | },
|
66 | 86 | {
|
|
152 | 172 | "cell_type": "markdown",
|
153 | 173 | "metadata": {},
|
154 | 174 | "source": [
|
155 |
| - "## Defining the lifting" |
| 175 | + "## 3. Defining the lifting <a class=\"anchor\" id=\"sec3\"></a>" |
156 | 176 | ]
|
157 | 177 | },
|
158 | 178 | {
|
|
206 | 226 | "cell_type": "markdown",
|
207 | 227 | "metadata": {},
|
208 | 228 | "source": [
|
209 |
| - "## Loading the data" |
| 229 | + "## 4. Loading the data <a class=\"anchor\" id=\"sec4\"></a>" |
210 | 230 | ]
|
211 | 231 | },
|
212 | 232 | {
|
|
233 | 253 | "metadata": {},
|
234 | 254 | "outputs": [
|
235 | 255 | {
|
236 |
| - "name": "stderr", |
| 256 | + "name": "stdout", |
237 | 257 | "output_type": "stream",
|
238 | 258 | "text": [
|
239 |
| - "Processing...\n", |
240 |
| - "Done!\n" |
| 259 | + "Transform parameters are the same, using existing data_dir: ./data/MUTAG/MUTAG/clique_lifting/458544608\n" |
241 | 260 | ]
|
242 | 261 | }
|
243 | 262 | ],
|
|
255 | 274 | "cell_type": "markdown",
|
256 | 275 | "metadata": {},
|
257 | 276 | "source": [
|
258 |
| - "## Model initialization" |
| 277 | + "## 5. Model initialization <a class=\"anchor\" id=\"sec5\"></a>" |
259 | 278 | ]
|
260 | 279 | },
|
261 | 280 | {
|
|
304 | 323 | "cell_type": "markdown",
|
305 | 324 | "metadata": {},
|
306 | 325 | "source": [
|
307 |
| - "## Training" |
| 326 | + "## 6. Training <a class=\"anchor\" id=\"sec6\"></a>" |
308 | 327 | ]
|
309 | 328 | },
|
310 | 329 | {
|
|
316 | 335 | },
|
317 | 336 | {
|
318 | 337 | "cell_type": "code",
|
319 |
| - "execution_count": 10, |
| 338 | + "execution_count": 9, |
320 | 339 | "metadata": {},
|
321 | 340 | "outputs": [
|
322 | 341 | {
|
323 | 342 | "name": "stderr",
|
324 | 343 | "output_type": "stream",
|
325 | 344 | "text": [
|
326 |
| - "GPU available: True (mps), used: False\n", |
| 345 | + "GPU available: True (cuda), used: False\n", |
327 | 346 | "TPU available: False, using: 0 TPU cores\n",
|
328 | 347 | "IPU available: False, using: 0 IPUs\n",
|
329 | 348 | "HPU available: False, using: 0 HPUs\n",
|
|
344 | 363 | }
|
345 | 364 | ],
|
346 | 365 | "source": [
|
| 366 | + "%%capture\n", |
347 | 367 | "# Increase the number of epochs to get better results\n",
|
348 | 368 | "trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False)\n",
|
349 | 369 | "\n",
|
|
353 | 373 | },
|
354 | 374 | {
|
355 | 375 | "cell_type": "code",
|
356 |
| - "execution_count": 11, |
| 376 | + "execution_count": 10, |
357 | 377 | "metadata": {},
|
358 | 378 | "outputs": [
|
359 | 379 | {
|
360 | 380 | "name": "stdout",
|
361 | 381 | "output_type": "stream",
|
362 | 382 | "text": [
|
363 |
| - "train/accuracy : 0.7573964595794678\n", |
364 |
| - "train/precision : 0.737596869468689\n", |
365 |
| - "train/recall : 0.6920425891876221\n", |
366 |
| - "val/loss : 0.6638099551200867\n", |
367 |
| - "val/accuracy : 0.7894737124443054\n", |
368 |
| - "val/precision : 0.7749999761581421\n", |
369 |
| - "val/recall : 0.7115384340286255\n", |
370 |
| - "train/loss : 0.5858408808708191\n" |
| 383 | + " Training metrics\n", |
| 384 | + " --------------------------\n", |
| 385 | + "train/accuracy: 0.7633\n", |
| 386 | + "train/precision: 0.7352\n", |
| 387 | + "train/recall: 0.7310\n", |
| 388 | + "val/loss: 0.7276\n", |
| 389 | + "val/accuracy: 0.7895\n", |
| 390 | + "val/precision: 0.7750\n", |
| 391 | + "val/recall: 0.7115\n", |
| 392 | + "train/loss: 0.7212\n" |
371 | 393 | ]
|
372 | 394 | }
|
373 | 395 | ],
|
374 | 396 | "source": [
|
| 397 | + "print(' Training metrics\\n', '-'*26)\n", |
375 | 398 | "for key in train_metrics:\n",
|
376 |
| - " print(key,\": \", train_metrics[key].item())" |
| 399 | + " print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))" |
377 | 400 | ]
|
378 | 401 | },
|
379 | 402 | {
|
380 | 403 | "cell_type": "markdown",
|
381 | 404 | "metadata": {},
|
382 | 405 | "source": [
|
383 |
| - "## Testing the model" |
| 406 | + "## 7. Testing the model <a class=\"anchor\" id=\"sec7\"></a>" |
384 | 407 | ]
|
385 | 408 | },
|
386 | 409 | {
|
|
390 | 413 | "Finally, we can test the model and obtain the results."
|
391 | 414 | ]
|
392 | 415 | },
|
| 416 | + { |
| 417 | + "cell_type": "code", |
| 418 | + "execution_count": 11, |
| 419 | + "metadata": {}, |
| 420 | + "outputs": [], |
| 421 | + "source": [ |
| 422 | + "%%capture\n", |
| 423 | + "trainer.test(model, datamodule)\n", |
| 424 | + "test_metrics = trainer.callback_metrics" |
| 425 | + ] |
| 426 | + }, |
393 | 427 | {
|
394 | 428 | "cell_type": "code",
|
395 | 429 | "execution_count": 12,
|
|
399 | 433 | "name": "stdout",
|
400 | 434 | "output_type": "stream",
|
401 | 435 | "text": [
|
402 |
| - "\n" |
403 |
| - ] |
404 |
| - }, |
405 |
| - { |
406 |
| - "name": "stderr", |
407 |
| - "output_type": "stream", |
408 |
| - "text": [ |
409 |
| - "/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.\n" |
| 436 | + " Testing metrics\n", |
| 437 | + " -------------------------\n", |
| 438 | + "test/loss: 0.7276\n", |
| 439 | + "test/accuracy: 0.7895\n", |
| 440 | + "test/precision: 0.7750\n", |
| 441 | + "test/recall: 0.7115\n" |
410 | 442 | ]
|
411 |
| - }, |
412 |
| - { |
413 |
| - "data": { |
414 |
| - "text/html": [ |
415 |
| - "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
416 |
| - "┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", |
417 |
| - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
418 |
| - "│<span style=\"color: #008080; text-decoration-color: #008080\"> test/accuracy </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7894737124443054 </span>│\n", |
419 |
| - "│<span style=\"color: #008080; text-decoration-color: #008080\"> test/loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.6638099551200867 </span>│\n", |
420 |
| - "│<span style=\"color: #008080; text-decoration-color: #008080\"> test/precision </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7749999761581421 </span>│\n", |
421 |
| - "│<span style=\"color: #008080; text-decoration-color: #008080\"> test/recall </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7115384340286255 </span>│\n", |
422 |
| - "└───────────────────────────┴───────────────────────────┘\n", |
423 |
| - "</pre>\n" |
424 |
| - ], |
425 |
| - "text/plain": [ |
426 |
| - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
427 |
| - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", |
428 |
| - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
429 |
| - "│\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7894737124443054 \u001b[0m\u001b[35m \u001b[0m│\n", |
430 |
| - "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6638099551200867 \u001b[0m\u001b[35m \u001b[0m│\n", |
431 |
| - "│\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7749999761581421 \u001b[0m\u001b[35m \u001b[0m│\n", |
432 |
| - "│\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7115384340286255 \u001b[0m\u001b[35m \u001b[0m│\n", |
433 |
| - "└───────────────────────────┴───────────────────────────┘\n" |
434 |
| - ] |
435 |
| - }, |
436 |
| - "metadata": {}, |
437 |
| - "output_type": "display_data" |
438 | 443 | }
|
439 | 444 | ],
|
440 | 445 | "source": [
|
441 |
| - "trainer.test(model, datamodule)\n", |
442 |
| - "test_metrics = trainer.callback_metrics" |
| 446 | + "print(' Testing metrics\\n', '-'*25)\n", |
| 447 | + "for key in test_metrics:\n", |
| 448 | + " print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))" |
443 | 449 | ]
|
444 |
| - }, |
445 |
| - { |
446 |
| - "cell_type": "code", |
447 |
| - "execution_count": null, |
448 |
| - "metadata": {}, |
449 |
| - "outputs": [], |
450 |
| - "source": [] |
451 | 450 | }
|
452 | 451 | ],
|
453 | 452 | "metadata": {
|
|
0 commit comments