|
11 | 11 | "\n",
|
12 | 12 | "import os\n",
|
13 | 13 | "\n",
|
14 |
| - "os.chdir('..')" |
| 14 | + "os.chdir(\"..\")" |
15 | 15 | ]
|
16 | 16 | },
|
17 | 17 | {
|
|
108 | 108 | "source": [
|
109 | 109 | "import numpy as np\n",
|
110 | 110 | "from PIL import Image\n",
|
111 |
| - "import matplotlib.pyplot as plt \n", |
| 111 | + "import matplotlib.pyplot as plt\n", |
| 112 | + "\n", |
| 113 | + "MODE_LUMINANCE = \"L\"\n", |
| 114 | + "MODE_RGB = \"RGB\"\n", |
| 115 | + "MODE_RGBA = \"RGBA\"\n", |
| 116 | + "COLORMAP_VIRIDIS = \"viridis\"\n", |
112 | 117 | "\n",
|
113 |
| - "MODE_LUMINANCE = 'L'\n", |
114 |
| - "MODE_RGB = 'RGB'\n", |
115 |
| - "MODE_RGBA = 'RGBA'\n", |
116 |
| - "COLORMAP_VIRIDIS = 'viridis'\n", |
117 | 118 | "\n",
|
118 | 119 | "def visualize_array(array, colormap=None):\n",
|
119 | 120 | " \"\"\"\n",
|
|
132 | 133 | " array = np.clip(array, 0.0, 1.0)\n",
|
133 | 134 | " array = (array * 255).astype(np.uint8) # Convert to uint8 for PIL\n",
|
134 | 135 | " elif array.dtype != np.uint8:\n",
|
135 |
| - " print(\"Warning: Array data type is not uint8 or float. Attempting to scale to uint8.\")\n", |
| 136 | + " print(\n", |
| 137 | + " \"Warning: Array data type is not uint8 or float. Attempting to scale to uint8.\"\n", |
| 138 | + " )\n", |
136 | 139 | " try:\n",
|
137 |
| - " array = ((array - array.min()) / (array.max() - array.min()) * 255).astype(np.uint8)\n", |
| 140 | + " array = ((array - array.min()) / (array.max() - array.min()) * 255).astype(\n", |
| 141 | + " np.uint8\n", |
| 142 | + " )\n", |
138 | 143 | " except ValueError:\n",
|
139 |
| - " print(\"Error: Array contains NaN values or all values are the same. Cannot normalize. Returning.\")\n", |
| 144 | + " print(\n", |
| 145 | + " \"Error: Array contains NaN values or all values are the same. Cannot normalize. Returning.\"\n", |
| 146 | + " )\n", |
140 | 147 | " return\n",
|
141 | 148 | "\n",
|
142 | 149 | " if len(array.shape) == 2:\n",
|
143 | 150 | " mode = MODE_LUMINANCE\n",
|
144 | 151 | " if colormap:\n",
|
145 | 152 | " plt.imshow(array, cmap=colormap)\n",
|
146 |
| - " plt.axis('off')\n", |
| 153 | + " plt.axis(\"off\")\n", |
147 | 154 | " plt.show()\n",
|
148 | 155 | " return\n",
|
149 | 156 | " elif len(array.shape) == 3:\n",
|
|
152 | 159 | " elif array.shape[2] == 4: # RGBA\n",
|
153 | 160 | " mode = MODE_RGBA\n",
|
154 | 161 | " else:\n",
|
155 |
| - " raise ValueError(\"Array must have 2 (grayscale) or 3/4 (RGB/RGBA) channels.\")\n", |
| 162 | + " raise ValueError(\n", |
| 163 | + " \"Array must have 2 (grayscale) or 3/4 (RGB/RGBA) channels.\"\n", |
| 164 | + " )\n", |
156 | 165 | " else:\n",
|
157 | 166 | " raise ValueError(\"Array must be 2D or 3D.\")\n",
|
158 | 167 | "\n",
|
159 | 168 | " image = Image.fromarray(array, mode=mode)\n",
|
160 | 169 | "\n",
|
161 | 170 | " plt.imshow(image)\n",
|
162 |
| - " plt.axis('off')\n", |
| 171 | + " plt.axis(\"off\")\n", |
163 | 172 | " plt.show()\n",
|
164 | 173 | "\n",
|
165 | 174 | "\n",
|
|
178 | 187 | "visualize_array(rgba_array)\n",
|
179 | 188 | "\n",
|
180 | 189 | "# 4. Float array example\n",
|
181 |
| - "float_array = np.random.rand(50, 100) # values from 0.0 to 1.0\n", |
| 190 | + "float_array = np.random.rand(50, 100) # values from 0.0 to 1.0\n", |
182 | 191 | "visualize_array(float_array)\n",
|
183 | 192 | "\n",
|
184 | 193 | "# 5. Float array with larger values\n",
|
185 | 194 | "float_array2 = np.random.rand(50, 100) * 255.0 # Values from 0.0 to 255.0\n",
|
186 |
| - "visualize_array(float_array2) # Will need to be scaled.\n", |
| 195 | + "visualize_array(float_array2) # Will need to be scaled.\n", |
187 | 196 | "\n",
|
188 | 197 | "# 6. Array with NaN values\n",
|
189 | 198 | "nan_array = np.random.rand(50, 100)\n",
|
190 |
| - "nan_array[0,0] = np.nan #Introduce NaN\n", |
191 |
| - "visualize_array(nan_array) #Handles the nan value and prints an error message." |
| 199 | + "nan_array[0, 0] = np.nan # Introduce NaN\n", |
| 200 | + "visualize_array(nan_array) # Handles the nan value and prints an error message." |
192 | 201 | ]
|
193 | 202 | },
|
194 | 203 | {
|
|
760 | 769 | }
|
761 | 770 | ],
|
762 | 771 | "source": [
|
763 |
| - "print('Target:', dataset.target)\n", |
| 772 | + "print(\"Target:\", dataset.target)\n", |
764 | 773 | "train_df = dataset.data.train_df\n",
|
765 | 774 | "test_df = dataset.data.test_df\n",
|
766 | 775 | "test_target = test_df[dataset.target]\n",
|
|
1228 | 1237 | "outputs": [],
|
1229 | 1238 | "source": [
|
1230 | 1239 | "def percent(fraction: float) -> str:\n",
|
1231 |
| - " return f'{fraction * 100:.2f}%'" |
| 1240 | + " return f\"{fraction * 100:.2f}%\"" |
1232 | 1241 | ]
|
1233 | 1242 | },
|
1234 | 1243 | {
|
|
1266 | 1275 | "missclassified_count = diff[diff != 0].count()\n",
|
1267 | 1276 | "correct_count = len(y_pred) - missclassified_count\n",
|
1268 | 1277 | "\n",
|
1269 |
| - "print(f'Total: {correct_count}/{len(y_pred)} ({percent(score)})')\n", |
1270 |
| - "print(f'Predicted values:\\n{y_pred}')\n", |
1271 |
| - "print(f'Actual values:\\n{test_target.values}')\n", |
1272 |
| - "print(f'Diff: {diff.values}')" |
| 1278 | + "print(f\"Total: {correct_count}/{len(y_pred)} ({percent(score)})\")\n", |
| 1279 | + "print(f\"Predicted values:\\n{y_pred}\")\n", |
| 1280 | + "print(f\"Actual values:\\n{test_target.values}\")\n", |
| 1281 | + "print(f\"Diff: {diff.values}\")" |
1273 | 1282 | ]
|
1274 | 1283 | },
|
1275 | 1284 | {
|
|
2174 | 2183 | }
|
2175 | 2184 | ],
|
2176 | 2185 | "source": [
|
2177 |
| - "\n", |
2178 | 2186 | "import graphviz\n",
|
2179 | 2187 | "\n",
|
2180 | 2188 | "graphviz_source = tree.export_graphviz(\n",
|
2181 | 2189 | " decision_tree=model,\n",
|
2182 | 2190 | " feature_names=dataset.data.df.columns.drop(dataset.target),\n",
|
2183 |
| - " class_names=['foo', 'bar', 'baz', 'qux'],\n", |
| 2191 | + " class_names=[\"foo\", \"bar\", \"baz\", \"qux\"],\n", |
2184 | 2192 | " filled=True,\n",
|
2185 | 2193 | " leaves_parallel=True,\n",
|
2186 | 2194 | " node_ids=True,\n",
|
2187 | 2195 | " proportion=True,\n",
|
2188 | 2196 | " # rotate=True,\n",
|
2189 | 2197 | " rounded=True,\n",
|
2190 | 2198 | " special_characters=True,\n",
|
2191 |
| - " precision=6, # digits after point\n", |
2192 |
| - " fontname='Times New Roman',\n", |
| 2199 | + " precision=6, # digits after point\n", |
| 2200 | + " fontname=\"Times New Roman\",\n", |
2193 | 2201 | " # fontname='Big Caslon',\n",
|
2194 | 2202 | " # fontname='Brush Script MT',\n",
|
2195 | 2203 | ")\n",
|
|
2270 | 2278 | " ),\n",
|
2271 | 2279 | ")\n",
|
2272 | 2280 | "\n",
|
2273 |
| - "plot_list(losses, y_label='Loss')" |
| 2281 | + "plot_list(losses, y_label=\"Loss\")" |
2274 | 2282 | ]
|
2275 | 2283 | },
|
2276 | 2284 | {
|
|
2292 | 2300 | "source": [
|
2293 | 2301 | "import torch\n",
|
2294 | 2302 | "\n",
|
| 2303 | + "\n", |
2295 | 2304 | "def score(model, dataloader, loss_fn, metric=\"accuracy\", device=\"cpu\"):\n",
|
2296 | 2305 | " \"\"\"\n",
|
2297 | 2306 | " Calculates a score for a PyTorch model on a given dataset.\n",
|
|
2320 | 2329 | "\n",
|
2321 | 2330 | " outputs = model(inputs)\n",
|
2322 | 2331 | " loss = loss_fn(outputs, labels)\n",
|
2323 |
| - " total_loss += loss.item() * inputs.size(0) # weighted by batch size\n", |
| 2332 | + " total_loss += loss.item() * inputs.size(0) # weighted by batch size\n", |
2324 | 2333 | "\n",
|
2325 | 2334 | " if metric == \"accuracy\":\n",
|
2326 | 2335 | " _, predicted = torch.max(outputs.data, 1) # Get predicted class\n",
|
|
2339 | 2348 | " elif metric == \"loss\":\n",
|
2340 | 2349 | " return avg_loss\n",
|
2341 | 2350 | " elif callable(metric):\n",
|
2342 |
| - " accuracy = correct / total_samples\n", |
2343 |
| - " return accuracy\n", |
| 2351 | + " accuracy = correct / total_samples\n", |
| 2352 | + " return accuracy\n", |
2344 | 2353 | " else:\n",
|
2345 |
| - " raise ValueError(f\"Unsupported metric: {metric}. Choose 'accuracy', 'loss', or a callable.\")\n", |
| 2354 | + " raise ValueError(\n", |
| 2355 | + " f\"Unsupported metric: {metric}. Choose 'accuracy', 'loss', or a callable.\"\n", |
| 2356 | + " )\n", |
| 2357 | + "\n", |
2346 | 2358 | "\n",
|
2347 |
| - "score(model, dataset.data.test_loader, dataset.learning_task.criterion, metric=\"accuracy\")" |
| 2359 | + "score(\n", |
| 2360 | + " model, dataset.data.test_loader, dataset.learning_task.criterion, metric=\"accuracy\"\n", |
| 2361 | + ")" |
2348 | 2362 | ]
|
2349 | 2363 | },
|
2350 | 2364 | {
|
|
2427 | 2441 | " ),\n",
|
2428 | 2442 | ")\n",
|
2429 | 2443 | "\n",
|
2430 |
| - "plot_list(losses, y_label='Loss')" |
| 2444 | + "plot_list(losses, y_label=\"Loss\")" |
2431 | 2445 | ]
|
2432 | 2446 | },
|
2433 | 2447 | {
|
|
2447 | 2461 | }
|
2448 | 2462 | ],
|
2449 | 2463 | "source": [
|
2450 |
| - "score(model, dataset.data.test_loader, dataset.learning_task.criterion, metric=\"accuracy\")" |
| 2464 | + "score(\n", |
| 2465 | + " model, dataset.data.test_loader, dataset.learning_task.criterion, metric=\"accuracy\"\n", |
| 2466 | + ")" |
2451 | 2467 | ]
|
2452 | 2468 | },
|
2453 | 2469 | {
|
|
0 commit comments