Skip to content

Commit 89489b7

Browse files
committed
Reformat code using Black
1 parent 8ea88d8 commit 89489b7

17 files changed

+584
-444
lines changed

Examples/Binarizer.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"\n",
2222
"import os\n",
2323
"\n",
24-
"os.chdir('..')"
24+
"os.chdir(\"..\")"
2525
]
2626
},
2727
{

Examples/GradientChecker.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"\n",
1212
"import os\n",
1313
"\n",
14-
"os.chdir('..')"
14+
"os.chdir(\"..\")"
1515
]
1616
},
1717
{

Examples/LearningTree.ipynb

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"\n",
1212
"import os\n",
1313
"\n",
14-
"os.chdir('..')"
14+
"os.chdir(\"..\")"
1515
]
1616
},
1717
{
@@ -108,12 +108,13 @@
108108
"source": [
109109
"import numpy as np\n",
110110
"from PIL import Image\n",
111-
"import matplotlib.pyplot as plt \n",
111+
"import matplotlib.pyplot as plt\n",
112+
"\n",
113+
"MODE_LUMINANCE = \"L\"\n",
114+
"MODE_RGB = \"RGB\"\n",
115+
"MODE_RGBA = \"RGBA\"\n",
116+
"COLORMAP_VIRIDIS = \"viridis\"\n",
112117
"\n",
113-
"MODE_LUMINANCE = 'L'\n",
114-
"MODE_RGB = 'RGB'\n",
115-
"MODE_RGBA = 'RGBA'\n",
116-
"COLORMAP_VIRIDIS = 'viridis'\n",
117118
"\n",
118119
"def visualize_array(array, colormap=None):\n",
119120
" \"\"\"\n",
@@ -132,18 +133,24 @@
132133
" array = np.clip(array, 0.0, 1.0)\n",
133134
" array = (array * 255).astype(np.uint8) # Convert to uint8 for PIL\n",
134135
" elif array.dtype != np.uint8:\n",
135-
" print(\"Warning: Array data type is not uint8 or float. Attempting to scale to uint8.\")\n",
136+
" print(\n",
137+
" \"Warning: Array data type is not uint8 or float. Attempting to scale to uint8.\"\n",
138+
" )\n",
136139
" try:\n",
137-
" array = ((array - array.min()) / (array.max() - array.min()) * 255).astype(np.uint8)\n",
140+
" array = ((array - array.min()) / (array.max() - array.min()) * 255).astype(\n",
141+
" np.uint8\n",
142+
" )\n",
138143
" except ValueError:\n",
139-
" print(\"Error: Array contains NaN values or all values are the same. Cannot normalize. Returning.\")\n",
144+
" print(\n",
145+
" \"Error: Array contains NaN values or all values are the same. Cannot normalize. Returning.\"\n",
146+
" )\n",
140147
" return\n",
141148
"\n",
142149
" if len(array.shape) == 2:\n",
143150
" mode = MODE_LUMINANCE\n",
144151
" if colormap:\n",
145152
" plt.imshow(array, cmap=colormap)\n",
146-
" plt.axis('off')\n",
153+
" plt.axis(\"off\")\n",
147154
" plt.show()\n",
148155
" return\n",
149156
" elif len(array.shape) == 3:\n",
@@ -152,14 +159,16 @@
152159
" elif array.shape[2] == 4: # RGBA\n",
153160
" mode = MODE_RGBA\n",
154161
" else:\n",
155-
" raise ValueError(\"Array must have 2 (grayscale) or 3/4 (RGB/RGBA) channels.\")\n",
162+
" raise ValueError(\n",
163+
" \"Array must have 2 (grayscale) or 3/4 (RGB/RGBA) channels.\"\n",
164+
" )\n",
156165
" else:\n",
157166
" raise ValueError(\"Array must be 2D or 3D.\")\n",
158167
"\n",
159168
" image = Image.fromarray(array, mode=mode)\n",
160169
"\n",
161170
" plt.imshow(image)\n",
162-
" plt.axis('off')\n",
171+
" plt.axis(\"off\")\n",
163172
" plt.show()\n",
164173
"\n",
165174
"\n",
@@ -178,17 +187,17 @@
178187
"visualize_array(rgba_array)\n",
179188
"\n",
180189
"# 4. Float array example\n",
181-
"float_array = np.random.rand(50, 100) # values from 0.0 to 1.0\n",
190+
"float_array = np.random.rand(50, 100) # values from 0.0 to 1.0\n",
182191
"visualize_array(float_array)\n",
183192
"\n",
184193
"# 5. Float array with larger values\n",
185194
"float_array2 = np.random.rand(50, 100) * 255.0 # Values from 0.0 to 255.0\n",
186-
"visualize_array(float_array2) # Will need to be scaled.\n",
195+
"visualize_array(float_array2) # Will need to be scaled.\n",
187196
"\n",
188197
"# 6. Array with NaN values\n",
189198
"nan_array = np.random.rand(50, 100)\n",
190-
"nan_array[0,0] = np.nan #Introduce NaN\n",
191-
"visualize_array(nan_array) #Handles the nan value and prints an error message."
199+
"nan_array[0, 0] = np.nan # Introduce NaN\n",
200+
"visualize_array(nan_array) # Handles the nan value and prints an error message."
192201
]
193202
},
194203
{
@@ -760,7 +769,7 @@
760769
}
761770
],
762771
"source": [
763-
"print('Target:', dataset.target)\n",
772+
"print(\"Target:\", dataset.target)\n",
764773
"train_df = dataset.data.train_df\n",
765774
"test_df = dataset.data.test_df\n",
766775
"test_target = test_df[dataset.target]\n",
@@ -1228,7 +1237,7 @@
12281237
"outputs": [],
12291238
"source": [
12301239
"def percent(fraction: float) -> str:\n",
1231-
" return f'{fraction * 100:.2f}%'"
1240+
" return f\"{fraction * 100:.2f}%\""
12321241
]
12331242
},
12341243
{
@@ -1266,10 +1275,10 @@
12661275
"missclassified_count = diff[diff != 0].count()\n",
12671276
"correct_count = len(y_pred) - missclassified_count\n",
12681277
"\n",
1269-
"print(f'Total: {correct_count}/{len(y_pred)} ({percent(score)})')\n",
1270-
"print(f'Predicted values:\\n{y_pred}')\n",
1271-
"print(f'Actual values:\\n{test_target.values}')\n",
1272-
"print(f'Diff: {diff.values}')"
1278+
"print(f\"Total: {correct_count}/{len(y_pred)} ({percent(score)})\")\n",
1279+
"print(f\"Predicted values:\\n{y_pred}\")\n",
1280+
"print(f\"Actual values:\\n{test_target.values}\")\n",
1281+
"print(f\"Diff: {diff.values}\")"
12731282
]
12741283
},
12751284
{
@@ -2174,22 +2183,21 @@
21742183
}
21752184
],
21762185
"source": [
2177-
"\n",
21782186
"import graphviz\n",
21792187
"\n",
21802188
"graphviz_source = tree.export_graphviz(\n",
21812189
" decision_tree=model,\n",
21822190
" feature_names=dataset.data.df.columns.drop(dataset.target),\n",
2183-
" class_names=['foo', 'bar', 'baz', 'qux'],\n",
2191+
" class_names=[\"foo\", \"bar\", \"baz\", \"qux\"],\n",
21842192
" filled=True,\n",
21852193
" leaves_parallel=True,\n",
21862194
" node_ids=True,\n",
21872195
" proportion=True,\n",
21882196
" # rotate=True,\n",
21892197
" rounded=True,\n",
21902198
" special_characters=True,\n",
2191-
" precision=6, # digits after point\n",
2192-
" fontname='Times New Roman',\n",
2199+
" precision=6, # digits after point\n",
2200+
" fontname=\"Times New Roman\",\n",
21932201
" # fontname='Big Caslon',\n",
21942202
" # fontname='Brush Script MT',\n",
21952203
")\n",
@@ -2270,7 +2278,7 @@
22702278
" ),\n",
22712279
")\n",
22722280
"\n",
2273-
"plot_list(losses, y_label='Loss')"
2281+
"plot_list(losses, y_label=\"Loss\")"
22742282
]
22752283
},
22762284
{
@@ -2292,6 +2300,7 @@
22922300
"source": [
22932301
"import torch\n",
22942302
"\n",
2303+
"\n",
22952304
"def score(model, dataloader, loss_fn, metric=\"accuracy\", device=\"cpu\"):\n",
22962305
" \"\"\"\n",
22972306
" Calculates a score for a PyTorch model on a given dataset.\n",
@@ -2320,7 +2329,7 @@
23202329
"\n",
23212330
" outputs = model(inputs)\n",
23222331
" loss = loss_fn(outputs, labels)\n",
2323-
" total_loss += loss.item() * inputs.size(0) # weighted by batch size\n",
2332+
" total_loss += loss.item() * inputs.size(0) # weighted by batch size\n",
23242333
"\n",
23252334
" if metric == \"accuracy\":\n",
23262335
" _, predicted = torch.max(outputs.data, 1) # Get predicted class\n",
@@ -2339,12 +2348,17 @@
23392348
" elif metric == \"loss\":\n",
23402349
" return avg_loss\n",
23412350
" elif callable(metric):\n",
2342-
" accuracy = correct / total_samples\n",
2343-
" return accuracy\n",
2351+
" accuracy = correct / total_samples\n",
2352+
" return accuracy\n",
23442353
" else:\n",
2345-
" raise ValueError(f\"Unsupported metric: {metric}. Choose 'accuracy', 'loss', or a callable.\")\n",
2354+
" raise ValueError(\n",
2355+
" f\"Unsupported metric: {metric}. Choose 'accuracy', 'loss', or a callable.\"\n",
2356+
" )\n",
2357+
"\n",
23462358
"\n",
2347-
"score(model, dataset.data.test_loader, dataset.learning_task.criterion, metric=\"accuracy\")"
2359+
"score(\n",
2360+
" model, dataset.data.test_loader, dataset.learning_task.criterion, metric=\"accuracy\"\n",
2361+
")"
23482362
]
23492363
},
23502364
{
@@ -2427,7 +2441,7 @@
24272441
" ),\n",
24282442
")\n",
24292443
"\n",
2430-
"plot_list(losses, y_label='Loss')"
2444+
"plot_list(losses, y_label=\"Loss\")"
24312445
]
24322446
},
24332447
{
@@ -2447,7 +2461,9 @@
24472461
}
24482462
],
24492463
"source": [
2450-
"score(model, dataset.data.test_loader, dataset.learning_task.criterion, metric=\"accuracy\")"
2464+
"score(\n",
2465+
" model, dataset.data.test_loader, dataset.learning_task.criterion, metric=\"accuracy\"\n",
2466+
")"
24512467
]
24522468
},
24532469
{

Examples/LogisticRegression.ipynb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"\n",
2020
"import os\n",
2121
"\n",
22-
"os.chdir('..')"
22+
"os.chdir(\"..\")"
2323
]
2424
},
2525
{
@@ -92,16 +92,20 @@
9292
"k = 1\n",
9393
"j = 1\n",
9494
"\n",
95+
"\n",
9596
"def sigmoid(z):\n",
9697
" return 1 / (1 + np.exp(-z))\n",
9798
"\n",
99+
"\n",
98100
"def combine(x, y):\n",
99101
" return b + k * x + j * y\n",
100102
"\n",
103+
"\n",
101104
"def predict(x, y):\n",
102105
" print(x, y)\n",
103106
" return sigmoid(combine(x, y))\n",
104107
"\n",
108+
"\n",
105109
"i = 9\n",
106110
"predict(X3[i], y3[i])"
107111
]

Examples/Means.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"\n",
1212
"import os\n",
1313
"\n",
14-
"os.chdir('..')"
14+
"os.chdir(\"..\")"
1515
]
1616
},
1717
{
@@ -39,11 +39,11 @@
3939
"data = [mine, 46.0, 47.2, 46.6, 47.3, 45.8, their]\n",
4040
"mean = np.mean(data)\n",
4141
"std = np.std(data)\n",
42-
"print('mean = ', mean)\n",
43-
"print('|mean - mine| =', np.abs(np.mean(data) - mine))\n",
44-
"print('std = ', std)\n",
45-
"print('3 * std = ', 3 * std)\n",
46-
"print('3 * std / (mean - mine) = ', 3 * std / (mean - mine))\n"
42+
"print(\"mean = \", mean)\n",
43+
"print(\"|mean - mine| =\", np.abs(np.mean(data) - mine))\n",
44+
"print(\"std = \", std)\n",
45+
"print(\"3 * std = \", 3 * std)\n",
46+
"print(\"3 * std / (mean - mine) = \", 3 * std / (mean - mine))"
4747
]
4848
},
4949
{

Examples/MoreLayers.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"\n",
1212
"import os\n",
1313
"\n",
14-
"os.chdir('..')"
14+
"os.chdir(\"..\")"
1515
]
1616
},
1717
{
@@ -248,7 +248,7 @@
248248
"from cgtnnlib.nn.AugmentedReLUNetworkMultilayer import AugmentedReLUNetworkMultilayer\n",
249249
"from cgtnnlib.datasets import datasets\n",
250250
"\n",
251-
"dataset = datasets['wine_quality']\n",
251+
"dataset = datasets[\"wine_quality\"]\n",
252252
"\n",
253253
"model = AugmentedReLUNetworkMultilayer(\n",
254254
" inputs_count=dataset.features_count,\n",
@@ -289,14 +289,18 @@
289289
"def compose(f, g):\n",
290290
" def wrapper(*args, **kwargs):\n",
291291
" return f(g(*args, **kwargs))\n",
292+
"\n",
292293
" return wrapper\n",
293294
"\n",
295+
"\n",
294296
"def a(x):\n",
295297
" return x * 2\n",
296298
"\n",
299+
"\n",
297300
"def b(x):\n",
298301
" return x / 3\n",
299302
"\n",
303+
"\n",
300304
"compose(b, a)(4)"
301305
]
302306
},

Examples/StableNoiseTest.ipynb

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"\n",
1212
"import os\n",
1313
"\n",
14-
"os.chdir('..')"
14+
"os.chdir(\"..\")"
1515
]
1616
},
1717
{
@@ -147,26 +147,31 @@
147147
"\n",
148148
"from cgtnnlib.NoiseGenerator import stable_noise_func\n",
149149
"\n",
150+
"\n",
150151
"def plot_histogram_pdf(data, title=\"Histogram PDF\", bins=10):\n",
151152
" \"\"\"Plots a histogram of the data, normalized to approximate the PDF.\"\"\"\n",
152-
" plt.hist(data, bins=bins, density=True, alpha=0.6, color='skyblue') # density=True normalizes\n",
153+
" plt.hist(\n",
154+
" data, bins=bins, density=True, alpha=0.6, color=\"skyblue\"\n",
155+
" ) # density=True normalizes\n",
153156
" plt.title(title)\n",
154157
" plt.xlabel(\"Value\")\n",
155158
" plt.ylabel(\"Probability Density\")\n",
156159
" plt.show()\n",
157160
"\n",
161+
"\n",
158162
"def display_stats(data, title: str):\n",
159163
" print(title)\n",
160-
" print('Mean:', np.mean(data))\n",
161-
" print('Stddev:', np.std(data))\n",
164+
" print(\"Mean:\", np.mean(data))\n",
165+
" print(\"Stddev:\", np.std(data))\n",
162166
" plot_histogram_pdf(data, title=title)\n",
163167
"\n",
168+
"\n",
164169
"data = np.random.normal(size=2000)\n",
165-
"display_stats(data, 'Normal distribution')\n",
170+
"display_stats(data, \"Normal distribution\")\n",
166171
"\n",
167172
"for alpha in [1, 1.25, 1.5, 1.75, 2]:\n",
168173
" data = stable_noise_func(alpha=2, beta=1, size=2000)\n",
169-
" display_stats(data, f'Stable distribution \\\\alpha = {alpha}')"
174+
" display_stats(data, f\"Stable distribution \\\\alpha = {alpha}\")"
170175
]
171176
}
172177
],

Examples/Statsmodel.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"\n",
1212
"import os\n",
1313
"\n",
14-
"os.chdir('..')"
14+
"os.chdir(\"..\")"
1515
]
1616
},
1717
{

0 commit comments

Comments
 (0)