Torch takes care of our autograd needs. The documentation is available at https:
### Task 1: Denoising a cosine
To get a notion of how function learning with a dense-layer network works on given data, we will first have a look at the example from the lecture. In the following task you will implement gradient-descent learning of a dense neural network using `torch` and use it to learn a function, e.g. a cosine.
- As a first step, create a cosine function in `torch` and add some noise with `torch.randn`. Use, for example, a signal length of $n = 200$ samples and a period of your choosing. This will be the noisy signal from which the model is supposed to learn the underlying cosine.
- Recall the definition of the sigmoid function $\sigma$
where $\mathbf{W}_1\in \mathbb{R}^{m,n}, \mathbf{x}\in\mathbb{R}^n, \mathbf{b}\in\mathbb{R}^m$ and $m$ denotes the number of neurons and $n$ the input signal length. Suppose that the input parameters are stored in a [python dictionary](https://docs.python.org/3/tutorial/datastructures.html#dictionaries) with the keys `W_1`, `W_2` and `b`. Use Python's `@` operator for the matrix product.
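A minimal sketch of how the network evaluation could look with such a parameter dictionary (the function name `net` and the exact composition $\mathbf{W}_2\,\sigma(\mathbf{W}_1\mathbf{x} + \mathbf{b})$ are illustrative assumptions, not the required interface):

```python
import torch

def net(params: dict, x: torch.Tensor) -> torch.Tensor:
    """Evaluate a single-hidden-layer dense network on one input signal.

    params holds 'W_1' of shape [hidden_neurons, n], 'W_2' of shape
    [n, hidden_neurons] and 'b' of shape [hidden_neurons]; x has shape [n].
    """
    hidden = torch.sigmoid(params["W_1"] @ x + params["b"])  # dense layer + sigmoid
    return params["W_2"] @ hidden                            # map back to signal length
```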
- Use `torch.randn` to initialize your weights. For a signal length of $200$, the $W_2$ matrix should, for example, have the shape [200, `hidden_neurons`] and $W_1$ a shape of [`hidden_neurons`, 200].
- `**` denotes squaring in Python; `torch.sum` allows you to sum up all terms.
- Define the forward pass in the `net_cost` function. The forward pass evaluates the network and the cost function.
- Train your network to denoise a cosine. To do so, implement gradient descent on the noisy input signal and use e.g. `torch.func.grad_and_value` to compute the gradient and the cost at the same time. Remember the gradient descent update rule (a minimal sketch of the training loop follows below).
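A minimal sketch of the full Task 1 loop, assuming a mean-squared-error `net_cost` and using the noisy signal both as network input and as regression target (all names, shapes and hyperparameters here are illustrative, not the required interface):

```python
import torch
from torch.func import grad_and_value

n, hidden_neurons = 200, 50
t = torch.linspace(0.0, 2.0 * torch.pi, n)
y_noisy = torch.cos(t) + 0.1 * torch.randn(n)   # noisy cosine from the first step

params = {
    "W_1": 0.1 * torch.randn(hidden_neurons, n),
    "W_2": 0.1 * torch.randn(n, hidden_neurons),
    "b": torch.zeros(hidden_neurons),
}

def net_cost(params, x, y):
    hidden = torch.sigmoid(params["W_1"] @ x + params["b"])
    out = params["W_2"] @ hidden
    return torch.sum((out - y) ** 2) / n        # mean squared error

step_size = 0.01
for step in range(500):
    grads, cost = grad_and_value(net_cost)(params, y_noisy, y_noisy)
    # gradient descent update rule: theta <- theta - step_size * grad
    params = {key: value - step_size * grads[key] for key, value in params.items()}
```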
### Task 2: MNIST

In this task we will go one step further. Instead of a cosine function, our neural network will learn to identify handwritten digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). For that, we will be using the [torch.nn](https://pytorch.org/docs/stable/nn.html) module. To get started, familiarize yourself with `torch.nn` to train a fully connected network in `src/mnist.py`. In this script, some functions are already implemented and can be reused. [Broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) is an elegant way to deal with data batches (torch takes care of this for us). This task aims to compute gradients and update steps for all batches. If you are coding on bender, `matplotlib.pyplot.show` does not work unless you are connected to bender's X server. Use e.g. `plt.savefig` to save the figure and view it in vscode.
- Implement the `normalize_batch` function to ensure approximately standard-normal inputs. Make use of handy built-in torch methods. Normalization requires subtracting the mean and dividing by the standard deviation, with $i = 1, \dots, w$ and $j = 1, \dots, h$, where $w$ is the image width, $h$ the image height, and $k$ runs through the batch dimension (a sketch follows the formula below):
```math
\tilde{x}_{ijk} = \frac{x_{ijk} - \mu}{\sigma}
```
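A possible `normalize_batch`, sketched under the assumption that a whole batch tensor is normalized with a single mean and standard deviation (the exact signature in `src/mnist.py` may differ):

```python
import torch

def normalize_batch(images: torch.Tensor) -> torch.Tensor:
    """Shift and scale a batch of images to roughly zero mean and unit variance."""
    mu = images.mean()
    sigma = images.std()
    return (images - mu) / sigma
```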
- The forward step requires the `Net` object from its [class](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html#define-the-class). It is your fully connected neural network model. Implement a dense network of your choosing in `Net` using a combination of `torch.nn.Linear` and `torch.nn.ReLU` or `torch.nn.Sigmoid`.
- Additionally, implement the `forward` function in the `Net` class to compute the network forward pass.
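One possible `Net`, sketched with illustrative layer sizes (flattened 28×28 MNIST images in, ten output neurons out):

```python
import torch

class Net(torch.nn.Module):
    def __init__(self, hidden_neurons: int = 128):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Flatten(),                       # [batch, 28, 28] -> [batch, 784]
            torch.nn.Linear(28 * 28, hidden_neurons),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_neurons, 10),      # one output neuron per digit
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
```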
- Write a `cross_entropy` cost function, with $n_o$ the number of labels and $n_b$ the batch size in the batched case (see the sketch below).
- If you have chosen to work with ten output neurons, use `torch.nn.functional.one_hot` to encode the labels.
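The exact formula from the task sheet is not repeated here, but a common batched cross entropy with one-hot encoded labels could be sketched as follows (it assumes raw network scores of shape `[n_b, n_o]` and integer labels of shape `[n_b]`):

```python
import torch

def cross_entropy(labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    """Mean cross entropy over a batch of n_b samples with n_o classes."""
    targets = torch.nn.functional.one_hot(labels, num_classes=logits.shape[-1])
    log_probs = torch.log_softmax(logits, dim=-1)     # turn scores into log-probabilities
    return -(targets * log_probs).sum(dim=-1).mean()  # average over the batch
```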
- Next we want to be able to do an optimization step with stochastic gradient descent (SGD). Implement `sgd_step`. One way to do this is to iterate over `model.parameters()` and update each parameter individually with its gradient. You can access the gradient of each parameter via `<param>.grad`.
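A sketch of such an `sgd_step` (the argument names and the decision to zero the gradients inside the step are illustrative choices):

```python
import torch

def sgd_step(model: torch.nn.Module, learning_rate: float) -> None:
    """Apply one plain gradient descent update to every parameter of the model."""
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is not None:
                param -= learning_rate * param.grad  # update rule: w <- w - lr * grad
                param.grad.zero_()                   # clear the gradient for the next step
```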
- To evaluate the network, we calculate the accuracy of the network output. Implement `get_acc` to calculate the accuracy given a dataloader containing batches of images and the corresponding labels. More about dataloaders is available [here](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).
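A sketch of `get_acc`, assuming the dataloader yields `(images, labels)` pairs and that batches are normalized with a `normalize_batch` like the one sketched above:

```python
import torch

def get_acc(model: torch.nn.Module, loader) -> float:
    """Fraction of correctly classified images over all batches in the dataloader."""
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            logits = model(normalize_batch(images))
            predictions = logits.argmax(dim=-1)          # most likely digit per image
            correct += (predictions == labels).sum().item()
            total += labels.shape[0]
    return correct / total
```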
- Now it is time to move back to the main procedure. First, the training data is fetched via `torchvision.datasets.MNIST`. To be able to evaluate the network while it is being trained, we use a validation set. Here the train set is split into two disjoint sets, the training and the validation set, using `torch.utils.data.random_split`.
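Fetching and splitting the data might look like the following sketch (the 50,000/10,000 split, the batch size and the `./data` folder are illustrative choices):

```python
import torch
import torchvision

transform = torchvision.transforms.ToTensor()
train_full = torchvision.datasets.MNIST("./data", train=True, download=True, transform=transform)
train_set, val_set = torch.utils.data.random_split(train_full, [50_000, 10_000])

train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64)
```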
- Initialize the network with the `Net` object (see the `torch` documentation for help).
- Train your network for a fixed number of `EPOCHS` over the entire dataset.
- Last, load the test data with `test_loader` and calculate the test accuracy. Save it to a list.
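Putting the pieces together, the main procedure could be sketched as below; it reuses the `Net`, `normalize_batch`, `cross_entropy`, `sgd_step`, `get_acc` and dataloader sketches from above, and the hyperparameters are illustrative:

```python
EPOCHS = 10
learning_rate = 0.01
model = Net()

for epoch in range(EPOCHS):
    for images, labels in train_loader:
        loss = cross_entropy(labels, model(normalize_batch(images)))
        loss.backward()                      # populate .grad for every parameter
        sgd_step(model, learning_rate)       # update the weights and clear the gradients
    print(f"epoch {epoch}: validation accuracy {get_acc(model, val_loader):.3f}")

test_set = torchvision.datasets.MNIST("./data", train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)
test_accuracies = [get_acc(model, test_loader)]  # keep the test accuracy in a list
```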
- Optional: Plot the training and validation accuracies and add the test accuracy at the end.