Commit f215901 (1 parent: 311a4ba)

Update README for new CIFAR10 dataset support

File tree: 1 file changed (+24 −5 lines)


README.md

Lines changed: 24 additions & 5 deletions
````diff
@@ -30,12 +30,15 @@ The model was trained on the standard [MNIST](http://yann.lecun.com/exdb/mnist/)
 
 *Note: you don't have to manually download, preprocess, and load the MNIST dataset as [TorchVision](https://github.com/pytorch/vision) will take care of this step for you.*
 
+I have tried other datasets as well. See the [Other Datasets](#other-datasets) section below for more details.
+
 ## Requirements
 - Python 3
   - Tested with version 3.6.4
 - [PyTorch](http://pytorch.org/)
-  - Tested with version 0.2.0.post4 and 0.3.0.post4
+  - Tested with version 0.3.0.post4
   - Code will not run with version 0.1.2 because `keepdim` is not available in that version.
+  - Code will not run with version 0.2.0 because its `softmax` function does not take a dimension argument.
 - CUDA 8 and above
   - Tested with CUDA 8 and CUDA 9.
 - [TorchVision](https://github.com/pytorch/vision)
````
````diff
@@ -106,6 +109,9 @@ Uncompress and put the weights (.pth files) into `./results/trained_model/`.
 | Num. routing iteration | 3 | --num-routing 3 |
 | Use reconstruction loss | true | --use-reconstruction-loss |
 | Regularization coefficient for reconstruction loss | 0.0005 | --regularization-scale 0.0005 |
+| Dataset name (mnist, cifar10) | mnist | --dataset mnist |
+| Input image width to the convolution | 28 | --input-width 28 |
+| Input image height to the convolution | 28 | --input-height 28 |
 
 ## Results
 
````
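The new table rows map to command-line flags. A hypothetical `argparse` declaration consistent with the documented defaults (the actual flag parsing in main.py may differ):

```python
# Hypothetical argparse setup matching the documented defaults;
# the real main.py may declare these flags differently.
import argparse

parser = argparse.ArgumentParser(description='CapsNet')
parser.add_argument('--dataset', default='mnist', choices=['mnist', 'cifar10'])
parser.add_argument('--input-width', type=int, default=28)
parser.add_argument('--input-height', type=int, default=28)

# Overriding the MNIST defaults for CIFAR10:
args = parser.parse_args(['--dataset', 'cifar10',
                          '--input-width', '32', '--input-height', '32'])
print(args.dataset, args.input_width, args.input_height)  # cifar10 32 32
```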
````diff
@@ -160,7 +166,8 @@ Test loss. Lowest test error: 0.2002%
 
 ### Training Speed
 
-Around `3.25s / batch` or `25min / epoch` on a single Tesla K80 GPU.
+- Around `5.97s / batch` or `8min / epoch` on a single Tesla K80 GPU with a batch size of 704.
+- Around `3.25s / batch` or `25min / epoch` on a single Tesla K80 GPU with a batch size of 128.
 
 ![](results/training_speed.png)
 
````
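A quick sanity check of the timings above, assuming MNIST's 60,000 training images and rounding the batch count up:

```python
import math

def epoch_minutes(seconds_per_batch, batch_size, num_images=60_000):
    """Estimate minutes per epoch from the per-batch time."""
    batches = math.ceil(num_images / batch_size)
    return seconds_per_batch * batches / 60

print(f"{epoch_minutes(5.97, 704):.1f}")  # ~8.6 min/epoch at batch size 704
print(f"{epoch_minutes(3.25, 128):.1f}")  # ~25.4 min/epoch at batch size 128
```

Both estimates agree with the reported `8min / epoch` and `25min / epoch` figures.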
````diff
@@ -251,7 +258,7 @@ decoder.fc2.bias: [1024]
 decoder.fc3.weight: [784, 1024]
 decoder.fc3.bias: [784]
 
-Total number of parameters (with reconstruction network): 8227088 (8 million)
+Total number of parameters on (with reconstruction network): 8227088 (8 million)
 ```
 
 ### TensorBoard
````
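As a spot check on the parameter total above, the decoder's `fc3` layer alone accounts for its weight matrix (784 × 1024) plus its bias vector (784):

```python
# Parameter count of the decoder's fc3 layer, from the shapes listed
# above: weight [784, 1024] and bias [784].
fc3_weight = 784 * 1024
fc3_bias = 784
print(fc3_weight + fc3_bias)  # 803600
```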
````diff
@@ -271,15 +278,27 @@ $ tensorboard --logdir runs
 ```
 5. Open TensorBoard dashboard in your web browser using this URL: http://localhost:6006
 
+### Other Datasets
+
+#### CIFAR10
+
+In the spirit of experimentation, I have tried other datasets and updated the implementation to support CIFAR10. Note that I have not thoroughly tested the capsule model on CIFAR10.
+
+Train and test the model on CIFAR10 with the following command:
+
+```bash
+python main.py --dataset cifar10 --num-conv-in-channel 3 --input-width 32 --input-height 32 --primary-unit-size 2048 --epochs 50 --num-routing 1 --use-reconstruction-loss yes --regularization-scale 0.0005
+```
+
 ## TODO
 - [x] Publish results.
 - [x] More testing.
 - [ ] Inference mode - command to test a pre-trained model.
 - [ ] Jupyter Notebook version.
 - [ ] Create a sample to show how we can apply CapsNet to a real-world application.
 - [ ] Experiment with CapsNet:
-  * Try using another dataset.
-  * Come out a more creative model structure.
+  * [x] Try using another dataset.
+  * [ ] Come up with a more creative model structure.
 - [x] Pre-trained model and weights.
 - [x] Add visualization for training and evaluation metrics.
 - [x] Implement reconstruction loss.
````
