Skip to content

Commit 23d3a58

Browse files
committed
Deploying to main from @ geometric-intelligence/TopoBench@0579351 🚀
1 parent eb5b540 commit 23d3a58

File tree

108 files changed

+1153
-1124
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

108 files changed

+1153
-1124
lines changed
-64.7 KB
Binary file not shown.

topobenchmarkx/.doctrees/nbsphinx/notebooks/tutorial_dataset.ipynb

+87-91
Large diffs are not rendered by default.

topobenchmarkx/.doctrees/nbsphinx/notebooks/tutorial_lifting.ipynb

+65-66
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,27 @@
2222
"cell_type": "markdown",
2323
"metadata": {},
2424
"source": [
25-
"## Imports"
25+
"### <font color='289C4E'>Table of contents<font><a class='anchor' id='top'></a>\n",
26+
"&emsp;[1. Imports](##sec1)\n",
27+
"\n",
28+
"&emsp;[2. Configurations and utilities](##sec2)\n",
29+
"\n",
30+
"&emsp;[3. Defining the lifting](##sec2)\n",
31+
"\n",
32+
"&emsp;[4. Loading the data](##sec3)\n",
33+
"\n",
34+
"&emsp;[5. Model initialization](##sec4)\n",
35+
"\n",
36+
"&emsp;[6. Training](##sec5)\n",
37+
"\n",
38+
"&emsp;[7. Testing the model](##sec6)"
39+
]
40+
},
41+
{
42+
"cell_type": "markdown",
43+
"metadata": {},
44+
"source": [
45+
"## 1. Imports <a class=\"anchor\" id=\"sec1\"></a>"
2646
]
2747
},
2848
{
@@ -60,7 +80,7 @@
6080
"cell_type": "markdown",
6181
"metadata": {},
6282
"source": [
63-
"## Configurations and utilities"
83+
"## 2. Configurations and utilities <a class=\"anchor\" id=\"sec2\"></a>"
6484
]
6585
},
6686
{
@@ -152,7 +172,7 @@
152172
"cell_type": "markdown",
153173
"metadata": {},
154174
"source": [
155-
"## Defining the lifting"
175+
"## 3. Defining the lifting <a class=\"anchor\" id=\"sec3\"></a>"
156176
]
157177
},
158178
{
@@ -206,7 +226,7 @@
206226
"cell_type": "markdown",
207227
"metadata": {},
208228
"source": [
209-
"## Loading the data"
229+
"## 4. Loading the data <a class=\"anchor\" id=\"sec4\"></a>"
210230
]
211231
},
212232
{
@@ -233,11 +253,10 @@
233253
"metadata": {},
234254
"outputs": [
235255
{
236-
"name": "stderr",
256+
"name": "stdout",
237257
"output_type": "stream",
238258
"text": [
239-
"Processing...\n",
240-
"Done!\n"
259+
"Transform parameters are the same, using existing data_dir: ./data/MUTAG/MUTAG/clique_lifting/458544608\n"
241260
]
242261
}
243262
],
@@ -255,7 +274,7 @@
255274
"cell_type": "markdown",
256275
"metadata": {},
257276
"source": [
258-
"## Model initialization"
277+
"## 5. Model initialization <a class=\"anchor\" id=\"sec5\"></a>"
259278
]
260279
},
261280
{
@@ -304,7 +323,7 @@
304323
"cell_type": "markdown",
305324
"metadata": {},
306325
"source": [
307-
"## Training"
326+
"## 6. Training <a class=\"anchor\" id=\"sec6\"></a>"
308327
]
309328
},
310329
{
@@ -316,14 +335,14 @@
316335
},
317336
{
318337
"cell_type": "code",
319-
"execution_count": 10,
338+
"execution_count": 9,
320339
"metadata": {},
321340
"outputs": [
322341
{
323342
"name": "stderr",
324343
"output_type": "stream",
325344
"text": [
326-
"GPU available: True (mps), used: False\n",
345+
"GPU available: True (cuda), used: False\n",
327346
"TPU available: False, using: 0 TPU cores\n",
328347
"IPU available: False, using: 0 IPUs\n",
329348
"HPU available: False, using: 0 HPUs\n",
@@ -344,6 +363,7 @@
344363
}
345364
],
346365
"source": [
366+
"%%capture\n",
347367
"# Increase the number of epochs to get better results\n",
348368
"trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False)\n",
349369
"\n",
@@ -353,34 +373,37 @@
353373
},
354374
{
355375
"cell_type": "code",
356-
"execution_count": 11,
376+
"execution_count": 10,
357377
"metadata": {},
358378
"outputs": [
359379
{
360380
"name": "stdout",
361381
"output_type": "stream",
362382
"text": [
363-
"train/accuracy : 0.7573964595794678\n",
364-
"train/precision : 0.737596869468689\n",
365-
"train/recall : 0.6920425891876221\n",
366-
"val/loss : 0.6638099551200867\n",
367-
"val/accuracy : 0.7894737124443054\n",
368-
"val/precision : 0.7749999761581421\n",
369-
"val/recall : 0.7115384340286255\n",
370-
"train/loss : 0.5858408808708191\n"
383+
" Training metrics\n",
384+
" --------------------------\n",
385+
"train/accuracy: 0.7633\n",
386+
"train/precision: 0.7352\n",
387+
"train/recall: 0.7310\n",
388+
"val/loss: 0.7276\n",
389+
"val/accuracy: 0.7895\n",
390+
"val/precision: 0.7750\n",
391+
"val/recall: 0.7115\n",
392+
"train/loss: 0.7212\n"
371393
]
372394
}
373395
],
374396
"source": [
397+
"print(' Training metrics\\n', '-'*26)\n",
375398
"for key in train_metrics:\n",
376-
" print(key,\": \", train_metrics[key].item())"
399+
" print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))"
377400
]
378401
},
379402
{
380403
"cell_type": "markdown",
381404
"metadata": {},
382405
"source": [
383-
"## Testing the model"
406+
"## 7. Testing the model <a class=\"anchor\" id=\"sec7\"></a>"
384407
]
385408
},
386409
{
@@ -390,6 +413,17 @@
390413
"Finally, we can test the model and obtain the results."
391414
]
392415
},
416+
{
417+
"cell_type": "code",
418+
"execution_count": 11,
419+
"metadata": {},
420+
"outputs": [],
421+
"source": [
422+
"%%capture\n",
423+
"trainer.test(model, datamodule)\n",
424+
"test_metrics = trainer.callback_metrics"
425+
]
426+
},
393427
{
394428
"cell_type": "code",
395429
"execution_count": 12,
@@ -399,55 +433,20 @@
399433
"name": "stdout",
400434
"output_type": "stream",
401435
"text": [
402-
"\n"
403-
]
404-
},
405-
{
406-
"name": "stderr",
407-
"output_type": "stream",
408-
"text": [
409-
"/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.\n"
436+
" Testing metrics\n",
437+
" -------------------------\n",
438+
"test/loss: 0.7276\n",
439+
"test/accuracy: 0.7895\n",
440+
"test/precision: 0.7750\n",
441+
"test/recall: 0.7115\n"
410442
]
411-
},
412-
{
413-
"data": {
414-
"text/html": [
415-
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
416-
"┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
417-
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
418-
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/accuracy </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7894737124443054 </span>│\n",
419-
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.6638099551200867 </span>│\n",
420-
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/precision </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7749999761581421 </span>│\n",
421-
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/recall </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.7115384340286255 </span>│\n",
422-
"└───────────────────────────┴───────────────────────────┘\n",
423-
"</pre>\n"
424-
],
425-
"text/plain": [
426-
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
427-
"\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
428-
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
429-
"\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7894737124443054 \u001b[0m\u001b[35m \u001b[0m│\n",
430-
"\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6638099551200867 \u001b[0m\u001b[35m \u001b[0m│\n",
431-
"\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7749999761581421 \u001b[0m\u001b[35m \u001b[0m│\n",
432-
"\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7115384340286255 \u001b[0m\u001b[35m \u001b[0m│\n",
433-
"└───────────────────────────┴───────────────────────────┘\n"
434-
]
435-
},
436-
"metadata": {},
437-
"output_type": "display_data"
438443
}
439444
],
440445
"source": [
441-
"trainer.test(model, datamodule)\n",
442-
"test_metrics = trainer.callback_metrics"
446+
"print(' Testing metrics\\n', '-'*25)\n",
447+
"for key in test_metrics:\n",
448+
" print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))"
443449
]
444-
},
445-
{
446-
"cell_type": "code",
447-
"execution_count": null,
448-
"metadata": {},
449-
"outputs": [],
450-
"source": []
451450
}
452451
],
453452
"metadata": {

0 commit comments

Comments
 (0)