
Commit 91c7597

Update README.md
1 parent b39c384 · commit 91c7597

File tree: 1 file changed (+18 -18 lines)

README.md

Lines changed: 18 additions & 18 deletions
@@ -13,9 +13,9 @@ Torch takes care of our autograd needs. The documentation is available at https:

### Task 1: Denoising a cosine

- To get a notion of how function learning of a dense layer network works on given data, we will first have a look at the example from the lecture. In the following task you will implement gradient descent learning of a dense neural network using `jax` and use it to learn a function, e.g. a cosine.
+ To get a feel for how a dense-layer network learns a function from given data, we will first have a look at the example from the lecture. In the following task you will implement gradient descent learning of a dense neural network using `torch` and use it to learn a function, e.g. a cosine.

- - As a first step, create a cosine function in Jax and add some noise with `jax.random.normal`. Use, for example, a signal length of $n = 200$ samples and a period of your choosing. This will be the noisy signal that the model is supposed to learn the underlaying cosine from.
+ - As a first step, create a cosine function in torch and add some noise with `torch.randn`. Use, for example, a signal length of $n = 200$ samples and a period of your choosing. This will be the noisy signal from which the model is supposed to learn the underlying cosine.
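A minimal sketch of how this first step could look, assuming a signal length of 200 samples and one full period; the noise scale and variable names are illustrative:

```python
import math
import torch

n = 200                                    # signal length in samples
x = torch.linspace(0, 2 * math.pi, n)      # one full period (a free choice)
y_clean = torch.cos(x)                     # the underlying cosine
y_noisy = y_clean + 0.2 * torch.randn(n)   # noisy training signal
```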

- Recall the definition of the sigmoid function $\sigma$

@@ -33,19 +33,19 @@ To get a notion of how function learning of a dense layer network works on given
```
where $\mathbf{W}_1\in \mathbb{R}^{m,n}, \mathbf{x}\in\mathbb{R}^n, \mathbf{b}\in\mathbb{R}^m$ and $m$ denotes the number of neurons and $n$ the input signal length. Suppose that the input parameters are stored in a [python dictionary](https://docs.python.org/3/tutorial/datastructures.html#dictionaries) with the keys `W_1`, `W_2` and `b`. Use Python's `@` operator for the matrix product. A sketch of this forward pass is given after the next bullet.

- - Use `jax.random.uniform` to initialize your weights. For a signal length of $200$ the $W_2$ matrix should have e.g. have the shape [200, `hidden_neurons`] and $W_1$ a shape of [`hidden_neurons`, 200]. Start with $\mathcal{U}[-0.1, 0.1]$ for example. `jax.random.PRNGKey` allows you to create a seed for the random number generator.
+ - Use `torch.randn` to initialize your weights. For a signal length of $200$, the $W_2$ matrix should e.g. have the shape [200, `hidden_neurons`] and $W_1$ a shape of [`hidden_neurons`, 200].
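A possible sketch of the initialization together with the forward pass from the formula above; the number of hidden neurons and the 0.1 scaling are illustrative choices:

```python
import torch

n = 200                # signal length
hidden_neurons = 50    # illustrative choice

# parameters stored in a dictionary with the keys suggested above
params = {
    "W_1": 0.1 * torch.randn(hidden_neurons, n),
    "W_2": 0.1 * torch.randn(n, hidden_neurons),
    "b": 0.1 * torch.randn(hidden_neurons),
}

def sigmoid(x):
    return 1.0 / (1.0 + torch.exp(-x))

def net(params, x):
    # o = W_2 @ sigma(W_1 @ x + b)
    return params["W_2"] @ sigmoid(params["W_1"] @ x + params["b"])
```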

- Implement and test a squared error cost

```math
C_{\text{se}} = \frac{1}{2} \sum_{k=1}^{n} (\mathbf{y}_k - \mathbf{o}_k)^2
```

- - `**` denotes squares in Python, `jnp.sum` allows you to sum up all terms.
+ - `**` is Python's power operator (so `x**2` squares elementwise); `torch.sum` allows you to sum up all terms.
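A small sketch of what this cost could look like; the factor of one half follows the formula above:

```python
import torch

def squared_error(y, out):
    # C_se = 0.5 * sum_k (y_k - o_k)^2
    return 0.5 * torch.sum((y - out) ** 2)
```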

- Define the forward pass in the `net_cost` function. The forward pass evaluates the network and the cost function.

- - Train your network to denoise a cosine. To do so, implement gradient descent on the noisy input signal and use e.g. `jax.value_and_grad` to compute cost and gradient at the same time. Remember the gradient descent update rule
+ - Train your network to denoise a cosine. To do so, implement gradient descent on the noisy input signal and use e.g. `torch.func.grad_and_value` to compute the gradient and the cost at the same time. Remember the gradient descent update rule

```math
\mathbf{W}_{\tau + 1} = \mathbf{W}_\tau - \epsilon \cdot \delta\mathbf{W}_{\tau}.
```
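Assuming a `net_cost(params, x, y)` function as described above that returns the scalar cost, the training loop could be sketched as follows; `torch.func.grad_and_value` differentiates with respect to the first argument by default, and using the noisy signal as both input and target is just one possible setup:

```python
import torch
from torch.func import grad_and_value

epsilon = 0.01                              # step size, illustrative
grad_fn = grad_and_value(net_cost)          # returns (gradients, cost)

for step in range(150):                     # number of steps is illustrative
    grads, cost = grad_fn(params, y_noisy, y_noisy)
    # W_{tau+1} = W_tau - epsilon * dW_tau, applied to every entry of the dict
    params = {key: value - epsilon * grads[key] for key, value in params.items()}
```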
@@ -61,36 +61,36 @@ C_{\text{se}} = \frac{1}{2} \sum_{k=1}^{n} (\mathbf{y}_k - \mathbf{o}_k)^2


### Task 2: MNIST
- In this task we will go one step further. Instead of a cosine function, our neural network will learn how to identify handwritten digits from the [MNSIT dataset](http://yann.lecun.com/exdb/mnist/). For that, we will be using the [linen api](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html) of the [flax](https://flax.readthedocs.io/en/latest/) package. To get started familiarize yourself with the linen api to train a fully connected network in `src/mnist.py`. In this script, some functions are already implemented and can be reused. Use `jax.numpy.array_split` to create a list of batches from your training set. [Broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) is an elegant way to deal with data batches. This task aims to compute gradients and update steps for all batches in the list. If you are coding on bender the function `matplotlib.pyplot.show` doesn't work if you are not connected to the X server of bender. Use e.g. `plt.savefig` to save the figure and view it in vscode.
+ In this task we will go one step further. Instead of a cosine function, our neural network will learn to identify handwritten digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). For that, we will be using the [torch.nn](https://pytorch.org/docs/stable/nn.html) module. To get started, familiarize yourself with `torch.nn` and train a fully connected network in `src/mnist.py`. In this script, some functions are already implemented and can be reused. [Broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) is an elegant way to deal with data batches (torch takes care of this for us). This task aims to compute gradients and update steps for all batches. If you are coding on bender, `matplotlib.pyplot.show` does not work unless you are connected to bender's X server. Use e.g. `plt.savefig` to save the figure and view it in VS Code.

- - Implement the `normalize` function to ensure approximate standard-normal inputs. Make use of handy numpy methods that you already know. Normalization requires subtraction of the mean and division by the standard deviation with $i = 1, \dots w$ and $j = 1, \dots h$ with $w$ the image width and $h$ the image height and $k$ running through the batch dimension:
+ - Implement the `normalize_batch` function to ensure approximately standard-normal inputs. Make use of handy built-in torch methods. Normalization requires subtracting the mean and dividing by the standard deviation, with $i = 1, \dots, w$ and $j = 1, \dots, h$, where $w$ denotes the image width, $h$ the image height, and $k$ runs over the batch dimension:

```math
\tilde{x}_{ijk} = \frac{x_{ijk} - \mu}{\sigma}
```
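One way the normalization could be written, assuming the input is a float tensor of images and that mean and standard deviation are taken over the whole batch:

```python
import torch

def normalize_batch(imgs: torch.Tensor) -> torch.Tensor:
    # subtract the global mean, divide by the global standard deviation
    return (imgs - imgs.mean()) / imgs.std()
```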

- - The forward step requires the `Net` object from its [class](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module). It is your fully connected neural network model. Applying weights to a `flax.linen.Module` is comparable to calculating the forward pass of the network in task 1. Implement a dense network in `Net` of your choosing using a combination of `flax.linen.Dense` and `flax.linen.activation.relu` or `flax.linen.sigmoid`.
+ - The forward step requires the `Net` object from its [class](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html#define-the-class). It is your fully connected neural network model. Implement a dense network in `Net` of your choosing using a combination of `torch.nn.Linear` and `torch.nn.ReLU` or `torch.nn.Sigmoid`.
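A minimal sketch of what such a `Net` could look like; the hidden size is an illustrative choice (MNIST images have 28 x 28 = 784 pixels, and there are ten classes):

```python
import torch

class Net(torch.nn.Module):
    def __init__(self, hidden: int = 128):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, hidden),   # 784 input pixels
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 10),        # ten output neurons
            torch.nn.Sigmoid(),                 # outputs in (0, 1) for the cost below
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # flatten each image to a vector and run the dense layers
        return self.layers(x.flatten(start_dim=1))
```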

- - The forward pass ends with the evaluation of a cost function. Write a `cross_entropy` cost function with $n_o$ the number of labels and $n_b$ in the batched case using
+ - Additionally, implement the `forward` function in the `Net` class to compute the network's forward pass (compare the sketch above).
+
+ - Write a `cross_entropy` cost function with $n_o$ the number of labels and $n_b$ the batch size, using

```math
C_{\text{ce}}(\mathbf{y},\mathbf{o})=-\frac{1}{n_b}\sum_{i=1}^{n_b}\sum_{k=1}^{n_o}[(\mathbf{y}_{i,k}\ln\mathbf{o}_{i,k})+(\mathbf{1}-\mathbf{y}_{i,k})\ln(\mathbf{1}-\mathbf{o}_{i,k})].
```
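A direct translation of this formula into torch could look as follows; the small constant inside the logarithms is an added safeguard against `log(0)`:

```python
import torch

def cross_entropy(y: torch.Tensor, out: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # y and out have shape [n_b, n_o]; y is one-hot encoded
    terms = y * torch.log(out + eps) + (1.0 - y) * torch.log(1.0 - out + eps)
    return -terms.sum(dim=-1).mean()    # sum over the labels, average over the batch
```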

- - If you have chosen to work with ten output neurons. Use `jax.nn.one_hot` to encode the labels.
-
- - Now implement the `forward_step` function. Calculate the network output first. Then compute the loss. It should return a scalar cost term you can use to compute gradients. Make use of the cross entropy.
+ - If you have chosen to work with ten output neurons, use `torch.nn.functional.one_hot` to encode the labels.
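For example, integer labels can be converted and cast to float so that they fit the cost function above:

```python
import torch

labels = torch.tensor([3, 0, 9])                                 # a small example batch
y = torch.nn.functional.one_hot(labels, num_classes=10).float()  # shape [3, 10]
```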

- - Next we want to be able to do an optimization step with stochastic gradient descent (sgd). Implement `sgd_step`. Use the gradients to update the weights. Consider `jax.tree_util.tree_map` for this task. Treemaps work best with a lambda expression.
+ - Next we want to be able to do an optimization step with stochastic gradient descent (SGD). Implement `sgd_step`. One way to do this is to iterate over `model.parameters()` and update each parameter individually with its gradient. One can access the gradient of each parameter via `<param>.grad`.
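A sketch of such an update, assuming the gradients were populated by a preceding `loss.backward()`; the update runs under `torch.no_grad()` so that autograd does not track it:

```python
import torch

def sgd_step(model: torch.nn.Module, learning_rate: float = 0.1) -> None:
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is not None:
                param -= learning_rate * param.grad   # W <- W - epsilon * dW
                param.grad.zero_()                    # reset for the next batch
```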

- - To evaluate the network we calculate the accuracy of the network output. Implement `get_acc` to calculate the accuracy given a batch of images and the corresponding labels for these images.
+ - To evaluate the network, we calculate the accuracy of the network output. Implement `get_acc` to calculate the accuracy given a dataloader containing batches of images and the corresponding labels. More about dataloaders is available [here](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).
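One possible shape for `get_acc`, reusing the `normalize_batch` sketch from above and taking the class with the highest score as the prediction:

```python
import torch

def get_acc(model: torch.nn.Module, loader) -> float:
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            preds = model(normalize_batch(imgs)).argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.shape[0]
    return correct / total
```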

- - Now is the time to move back to the main procedure. First, the train data is fetched via the function `get_mnist_train_data`. To be able to evaluate the network while it is being trained, we use a validation set. Here the train set is split into two disjoint sets: the training and the validation set. Both sets must be normalized.
+ - Now it is time to move back to the main procedure. First, the train data is fetched via `torchvision.datasets.MNIST`. To be able to evaluate the network while it is being trained, we use a validation set. Here, the train set is split into two disjoint sets, the training and the validation set, using `torch.utils.data.random_split`.
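A sketch of this part of the main procedure; the `./data` path, the 50000/10000 split, and the batch size are illustrative choices:

```python
import torch
import torchvision

train_full = torchvision.datasets.MNIST(
    "./data", train=True, download=True,
    transform=torchvision.transforms.ToTensor(),
)
train_set, val_set = torch.utils.data.random_split(train_full, [50_000, 10_000])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64)
```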

- - Define your loss and gradient function with jax (see task 1). Next, initialize the network with the `Net` object (see the `flax` documentation for help).
+ - Initialize the network with the `Net` object (see the `torch` documentation for help).

- - Train your network for a fixed number of `epochs` over the entire dataset.
+ - Train your network for a fixed number of `EPOCHS` over the entire dataset.
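Putting the sketches above together (`Net`, `normalize_batch`, `cross_entropy`, `sgd_step`, `train_loader`), one training run might look roughly like this; `EPOCHS` is an assumed constant:

```python
import torch

EPOCHS = 10          # illustrative
model = Net()

for epoch in range(EPOCHS):
    for imgs, labels in train_loader:
        out = model(normalize_batch(imgs))
        y = torch.nn.functional.one_hot(labels, num_classes=10).float()
        loss = cross_entropy(y, out)
        loss.backward()     # fill in .grad for every parameter
        sgd_step(model)     # manual SGD update from the sketch above
```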

- - Last, load the test data with `get_mnist_test_data` and calculate the test accuracy. Save it to a list.
+ - Last, load the test data via a `test_loader` and calculate the test accuracy. Save it to a list.

- Optional: Plot the training and validation accuracies and add the test accuracy at the end.
