Commit 1d86e9d

updated README.md
1 parent 1e2f924 commit 1d86e9d

6 files changed: +139 -76 lines changed


README.md

Lines changed: 2 additions & 1 deletion
@@ -75,7 +75,8 @@ If you use software from this library in your work, please use the BibTex entry
 ```
 @software{edgeml01,
    author = {{Dennis, Don Kurian and Gaurkar, Yash and Gopinath, Sridhar and Gupta, Chirag and
-   Kumar, Ashish and Kusupati, Aditya and Lovett, Chris and Patil, Shishir G and Simhadri, Harsha Vardhan}},
+   Jain, Moksh and Kumar, Ashish and Kusupati, Aditya and Lovett, Chris
+   and Patil, Shishir G and Simhadri, Harsha Vardhan}},
    title = {{EdgeML: Machine Learning for resource-constrained edge devices}},
    url = {https://github.com/Microsoft/EdgeML},
    version = {0.2},

examples/pytorch/Bonsai/README.md

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@ use-case on the USPS10 public dataset.
 `edgeml_pytorch.graph.bonsai` implements the Bonsai prediction graph in pytorch.
 The three-phase training routine for Bonsai is decoupled from the forward graph
 to facilitate a plug and play behaviour wherein Bonsai can be combined with or
-used as a final layer classifier for other architectures (RNNs, CNNs).
+used as a final layer classifier for other architectures (RNNs, CNNs).
+See `edgeml_pytorch.trainer.bonsaiTrainer` for 3-phase training.
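To make the plug-and-play behaviour concrete, here is a minimal sketch (not part of this commit) of Bonsai as a final-layer classifier on top of a small CNN. The `Bonsai` constructor arguments and its forward contract are assumptions for illustration; check `edgeml_pytorch.graph.bonsai` and `bonsai_example.py` for the exact interfaces.

```python
import torch.nn as nn
from edgeml_pytorch.graph.bonsai import Bonsai

class CNNWithBonsai(nn.Module):
    """Toy CNN feature extractor feeding a Bonsai final-layer classifier."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4))                 # -> [batch, 8, 4, 4]
        # Assumed argument order: numClasses, dataDimension,
        # projectionDimension, treeDepth, sigma.
        self.bonsai = Bonsai(num_classes, 8 * 4 * 4, 10, 3, 1.0)

    def forward(self, x):
        z = self.features(x).flatten(1)              # [batch, 128]
        # The real forward may take an extra annealing argument during training.
        return self.bonsai(z)
```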
 
 Note that `bonsai_example.py` assumes that data is in a specific format. It is
 assumed that train and test data is contained in two files, `train.npy` and

examples/pytorch/FastCells/README.md

Lines changed: 45 additions & 41 deletions
@@ -1,86 +1,90 @@
 # EdgeML FastCells on a sample public dataset
 
-This directory includes example notebook and general execution script of
-FastCells (FastRNN & FastGRNN) developed as part of EdgeML along with modified
+This directory includes example notebooks and scripts of
+FastCells (FastRNN & FastGRNN) along with modified
 UGRNN, GRU and LSTM to support the LSQ training routine.
-Also, we include a sample cleanup and use-case on the USPS10 public dataset.
-
-`edgeml_pytorch.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../pytorch_edgeml/graph/rnn.py#L226)) and **FastGRNN** ([`FastGRNNCell`](../../pytorch_edgeml/graph/rnn.py#L80)) with
-multiple additional features like Low-Rank parameterisation, custom
-non-linearities etc., Similar to Bonsai and ProtoNN, the three-phase training
-routine for FastRNN and FastGRNN is decoupled from the custom cells to
-facilitate a plug and play behaviour of the custom RNN cells in other
-architectures (NMT, Encoder-Decoder etc.,) in place of the inbuilt `RNNCell`, `GRUCell`, `BasicLSTMCell` etc.,
-`edgeml_pytorch.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../pytorch_edgeml/graph/rnn.py#L742)),
-**GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L565)) and **LSTM** ([`LSTMLRCell`](../../pytorch_edgeml/graph/rnn.py#L369)). These cells also can be substituted for FastCells where ever feasible.
-
-`edgeml_pytorch.graph.rnn` also contains fully wrapped RNNs which are equivalent to `nn.LSTM` and `nn.GRU`. Implemented cells:
-**FastRNN** ([`FastRNN`](../../pytorch_edgeml/graph/rnn.py#L968)), **FastGRNN** ([`FastGRNN`](../../pytorch_edgeml/graph/rnn.py#L993)), **UGRNN** ([`UGRNN`](../../edgeml_pytorch/graph/rnn.py#L945)), **GRU** ([`GRU`](../../edgeml/graph/rnn.py#L922)) and **LSTM** ([`LSTM`](../../pytorch_edgeml/graph/rnn.py#L899)).
-
-Note that all the cells and wrappers (when used independently from `fastcell_example.py` or `edgeml_pytorch.trainer.fastTrainer`) take in data in a batch first format ie., [batchSize, timeSteps, inputDims] by default but it can also support [timeSteps, batchSize, inputDims] format by setting `batch_first` argument to False when used. `fast_example.py` automatically takes care it while assuming the standard format between tf, c++ and pytorch.
+There is also a sample cleanup and train/test script for the USPS10 public dataset.
+
+[`edgeml_pytorch.graph.rnn`](../../../pytorch/pytorch_edgeml/graph/rnn.py)
+provides two RNN cells **FastRNNCell** and **FastGRNNCell** with additional
+features like low-rank parameterisation and custom non-linearities. Akin to
+Bonsai and ProtoNN, the three-phase training routine for FastRNN and FastGRNN
+is decoupled from the custom cells to facilitate a plug and play behaviour of
+the custom RNN cells in other architectures (NMT, Encoder-Decoder etc.).
+Additionally, numerically equivalent CUDA-based implementations `FastRNNCuda`
+and `FastGRNNCuda` are provided for faster training.
+`edgeml_pytorch.graph.rnn` also contains modified RNN cells of **UGRNNCell**,
+**GRUCell**, and **LSTMCell**, which can be substituted for Fast(G)RNN,
+as well as unrolled RNNs which are equivalent to `nn.LSTM` and `nn.GRU`.
+
+Note that all the cells and wrappers, when used independently from `fastcell_example.py`
+or `edgeml_pytorch.trainer.fastTrainer`, take in data in a batch-first format, i.e.,
+[batchSize, timeSteps, inputDims] by default, but can also support the [timeSteps,
+batchSize, inputDims] format if the `batch_first` argument is set to False.
+`fast_example.py` automatically adjusts to the correct format across tf, c++ and pytorch.
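For illustration, a minimal batch-first sketch using the `FastGRNN` wrapper (the constructor keywords below mirror those used elsewhere in this commit; the exact signatures live in `rnn.py`):

```python
import torch
from edgeml_pytorch.graph.rnn import FastGRNN

batch, timesteps, input_dims, hidden_dims = 32, 16, 16, 32

# wRank/uRank enable the low-rank parameterisation mentioned above.
rnn = FastGRNN(input_dims, hidden_dims, wRank=8, uRank=8)

x = torch.randn(batch, timesteps, input_dims)  # [batchSize, timeSteps, inputDims]
out = rnn(x)                                   # hidden states over the sequence

# Time-major input instead: set batch_first=False.
rnn_tm = FastGRNN(input_dims, hidden_dims, batch_first=False)
out_tm = rnn_tm(torch.randn(timesteps, batch, input_dims))
```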
 
 For training FastCells, `edgeml_pytorch.trainer.fastTrainer` implements the three-phase
-FastCell training routine in PyTorch. A simple example,
-`examples/fastcell_example.py` is provided to illustrate its usage.
-
-Note that `fastcell_example.py` assumes that data is in a specific format. It
-is assumed that train and test data is contained in two files, `train.npy` and
-`test.npy`. Each containing a 2D numpy array of dimension `[numberOfExamples,
+FastCell training routine in PyTorch. A simple example `fastcell_example.py` is provided
+to illustrate its usage. Note that `fastcell_example.py` assumes that data is in a specific format.
+It is assumed that train and test data is contained in two files, `train.npy` and
+`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples,
 numberOfFeatures]`. numberOfFeatures is `timesteps x inputDims`, flattened
-across timestep dimension. So the input of 1st timestep followed by second and
-so on. For an N-Class problem, we assume the labels are integers from 0
+across the timestep dimension, with the input of the first time step followed by the second
+and so on. For an N-Class problem, we assume the labels are integers from 0
 through N-1. Lastly, the training data, `train.npy`, is assumed to be well shuffled
 as the training routine doesn't shuffle internally.
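As a small sketch of that layout (pure feature rows are assumed here; if the files carry a label column, slice it off first), a reshape recovers the per-example sequences:

```python
import numpy as np

timesteps, input_dims = 16, 16                  # see the usps10 note below
train = np.load("usps10/train.npy")             # [numberOfExamples, numberOfFeatures]
assert train.shape[1] == timesteps * input_dims
# Features are flattened time step by time step, so reshaping restores the
# [numberOfExamples, timeSteps, inputDims] sequence view.
sequences = train.reshape(-1, timesteps, input_dims)
```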
 
 **Tested With:** PyTorch = 1.1 with Python 3.6
 
 ## Download and clean up sample dataset
 
-We will be testing out the validation of the code by using the USPS dataset.
-The download and cleanup of the dataset to match the above-mentioned format is
-done by the script [fetch_usps.py](fetch_usps.py) and
+To validate the code with the USPS dataset, first download and clean up the dataset to match
+the required format using the scripts [fetch_usps.py](fetch_usps.py) and
 [process_usps.py](process_usps.py)
 
 ```
 python fetch_usps.py
 python process_usps.py
 ```
 
+Note: Even though usps10 is not a time-series dataset, it can be regarded as a time-series
+dataset where each time step sees a new row. So the number of timesteps = 16 and inputDims = 16.
 
 ## Sample command for FastCells on USPS10
-The following sample run on usps10 should validate your library:
-
-Note: Even though usps10 is not a time-series dataset, it can be assumed as, a time-series where each row is coming in at one single time.
-So the number of timesteps = 16 and inputDims = 16
+The following is a sample run on usps10:
 
 ```bash
 python fastcell_example.py -dir usps10/ -id 16 -hd 32
 ```
-This command should give you a final output screen which reads roughly similar to (might not be exact numbers due to various version mismatches):
+This command should give you a final output that reads roughly similar to
+(might not be exact numbers due to various version mismatches):
 
 ```
 Maximum Test accuracy at compressed model size(including early stopping): 0.9407075 at Epoch: 262
 Final Test Accuracy: 0.93721974
 
 Non-Zeros: 1932 Model Size: 7.546875 KB hasSparse: False
 ```
-`usps10/` directory will now have a consolidated results file called `FastRNNResults.txt` or `FastGRNNResults.txt` depending on the choice of the RNN cell.
-A directory `FastRNNResults` or `FastGRNNResults` with the corresponding models with each run of the code on the `usps10` dataset.
+The `usps10/` directory will now have a consolidated results file called `FastRNNResults.txt` or
+`FastGRNNResults.txt` depending on the choice of the RNN cell. A directory `FastRNNResults` or
+`FastGRNNResults` containing the corresponding models is also created with each run on the `usps10` dataset.
 
-Note that the scalars like `alpha`, `beta`, `zeta` and `nu` are all before the application of the sigmoid function over them.
+Note that the scalars like `alpha`, `beta`, `zeta` and `nu` correspond to the values before
+the application of the sigmoid function.
 
 ## Byte Quantization (Q) for model compression
-If you wish to quantize the generated model to use byte quantized integers use `quantizeFastModels.py`. Usage Instructions:
+If you wish to quantize the generated model, use `quantizeFastModels.py`. Usage Instructions:
 
 ```
 python quantizeFastModels.py -h
 ```
 
-This will generate quantized models with a suffix of `q` before every param stored in a new directory `QuantizedFastModel` inside the model directory.
-One can use this model further on edge devices.
+This will generate quantized models with a suffix of `q` before every param, stored in a
+new directory `QuantizedFastModel` inside the model directory.
 
-Note that the scalars like `qalpha`, `qbeta`, `qzeta` and `qnu` are all after the application of the sigmoid function over them and quantization, they can be directly plugged into the inference pipleines.
+Note that the scalars like `qalpha`, `qbeta`, `qzeta` and `qnu` correspond to values
+after the application of the sigmoid function and quantization;
+they can be directly plugged into the inference pipelines.
 
 Copyright (c) Microsoft Corporation. All rights reserved.
-
 Licensed under the MIT license.

pytorch/README.md

Lines changed: 31 additions & 16 deletions
@@ -1,24 +1,39 @@
 ## Edge Machine Learning: Pytorch Library
 
-This directory includes PyTorch implementations of various techniques and
-algorithms developed as part of EdgeML. Currently, the following algorithms are
-available in Tensorflow:
-
-1. [Bonsai](/docs/publications/Bonsai.pdf)
-2. S-RNN
-3. [FastRNN & FastGRNN](/docs/publications/FastGRNN.pdf)
-4. [ProtoNN](/docs/publications/ProtoNN.pdf)
-
-The PyTorch graphs for these algoriths are packaged as `edgeml_pytorch.graph`.
-Trainers for these algorithms are in `edgeml_pytorch.trainer`.
-Usage directions and examples for these algorithms are provided in
-`$EDGEML_ROOT/examples/pytorch` directory. To get started with any
-of the provided algorithms, please follow the notebooks in the the
-`examples/pytorch` directory.
+This package includes PyTorch implementations of the following algorithms and training
+techniques developed as part of EdgeML. The PyTorch graphs for the forward/backward
+pass of these algorithms are packaged as `edgeml_pytorch.graph` and the trainers
+for these algorithms are in `edgeml_pytorch.trainer`.
 
-## Installation
+1. [Bonsai](/docs/publications/Bonsai.pdf): `edgeml_pytorch.graph.bonsai` implements
+   the Bonsai prediction graph. The three-phase training routine for Bonsai is decoupled
+   from the forward graph to facilitate a plug and play behaviour wherein Bonsai can be
+   combined with or used as a final layer classifier for other architectures (RNNs, CNNs).
+   See `edgeml_pytorch.trainer.bonsaiTrainer` for 3-phase training.
+2. [ProtoNN](/docs/publications/ProtoNN.pdf): `edgeml_pytorch.graph.protoNN` implements the
+   ProtoNN prediction functions. The training routine for ProtoNN is decoupled from the forward
+   graph to facilitate a plug and play behaviour wherein ProtoNN can be combined with or used
+   as a final layer classifier for other architectures (RNNs, CNNs). The training routine is
+   implemented in `edgeml_pytorch.trainer.protoNNTrainer`.
+3. [FastRNN & FastGRNN](/docs/publications/FastGRNN.pdf): `edgeml_pytorch.graph.rnn` provides
+   various RNN cells --- including new cells `FastRNNCell` and `FastGRNNCell` as well as
+   `UGRNNCell`, `GRUCell`, and `LSTMCell` --- with features like low-rank parameterisation
+   of weight matrices and custom non-linearities. Akin to Bonsai and ProtoNN, the three-phase
+   training routine for FastRNN and FastGRNN is decoupled from the custom cells to enable plug and
+   play behaviour of the custom RNN cells in other architectures (NMT, Encoder-Decoder etc.).
+   Additionally, numerically equivalent CUDA-based implementations `FastRNNCUDACell` and
+   `FastGRNNCUDACell` are provided for faster training.
+   `edgeml_pytorch.graph.rnn.Fast(G)RNN(CUDA)` provides unrolled RNNs equivalent to `nn.LSTM` and `nn.GRU`.
+   `edgeml_pytorch.trainer.fastmodel` presents a sample multi-layer RNN + multi-class classifier model.
+4. [S-RNN](/docs/publications/SRNN.pdf): `edgeml_pytorch.graph.rnn.SRNN2` implements a
+   2-layer SRNN network which can be instantiated with a choice of RNN cell. The training
+   routine for SRNN is in `edgeml_pytorch.trainer.srnnTrainer`.
+
+Usage directions and example notebooks for this package are provided [here](/examples/pytorch).
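As an indicative sketch of pulling one of the pieces listed above into user code (the `SRNN2` keyword names and cell choice below are assumptions; the exact signature is in `edgeml_pytorch.graph.rnn`):

```python
from edgeml_pytorch.graph.rnn import SRNN2

# A 2-layer S-RNN: 16-dimensional inputs, two 32-unit hidden layers,
# 10 output classes, built from FastGRNN cells (illustrative values only).
srnn = SRNN2(inputDim=16, numClasses=10, hiddenDim0=32, hiddenDim1=32,
             cellType="FastGRNN")
```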
 
 
+## Installation
+
 It is highly recommended that EdgeML be installed in a virtual environment.
 Please create a new virtual environment using your environment manager
 ([virtualenv](https://virtualenv.pypa.io/en/stable/userguide/#usage) or

pytorch/edgeml_pytorch/trainer/fastmodel.py

Lines changed: 55 additions & 11 deletions
@@ -38,6 +38,7 @@ def __init__(self, rnn_name, input_dim, num_layers, hidden_units_list,
         self.linear = linear
         self.batch_first = batch_first
         self.apply_softmax = apply_softmax
+        self.rnn_name = rnn_name
 
         if self.linear:
             if not self.num_classes:
@@ -57,6 +58,18 @@ def __init__(self, rnn_name, input_dim, num_layers, hidden_units_list,
                      batch_first = self.batch_first)
             for l in range(self.num_layers)])
 
+        if rnn_name == "FastGRNNCUDA":
+            RNN_ = getattr(getattr(getattr(__import__('edgeml_pytorch'), 'graph'), 'rnn'), 'FastGRNN')
+            self.rnn_list_ = nn.ModuleList([
+                RNN_(self.input_dim if l == 0 else self.hidden_units_list[l-1],
+                     self.hidden_units_list[l],
+                     gate_nonlinearity=self.gate_nonlinearity,
+                     update_nonlinearity=self.update_nonlinearity,
+                     wRank=self.wRank_list[l], uRank=self.uRank_list[l],
+                     wSparsity=self.wSparsity_list[l],
+                     uSparsity=self.uSparsity_list[l],
+                     batch_first = self.batch_first)
+                for l in range(self.num_layers)])
         # The linear layer is a fully connected layer that maps from hidden state space
         # to number of expected keywords
         if self.linear:
@@ -66,16 +79,30 @@ def __init__(self, rnn_name, input_dim, num_layers, hidden_units_list,
 
     def sparsify(self):
         for rnn in self.rnn_list:
-            rnn.cell.sparsify()
+            if self.rnn_name == "FastGRNNCUDA":
+                rnn.to(torch.device("cpu"))
+                rnn.sparsify()
+                rnn.to(torch.device("cuda"))
+            else:
+                rnn.cell.sparsify()
 
     def sparsifyWithSupport(self):
         for rnn in self.rnn_list:
-            rnn.cell.sparsifyWithSupport()
+            if self.rnn_name == "FastGRNNCUDA":
+                rnn.to(torch.device("cpu"))
+                rnn.sparsifyWithSupport()
+                rnn.to(torch.device("cuda"))
+            else:
+                rnn.cell.sparsifyWithSupport()
 
     def get_model_size(self):
         total_size = 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
+        print(self.rnn_name)
         for rnn in self.rnn_list:
-            total_size += rnn.cell.get_model_size()
+            if self.rnn_name == "FastGRNNCUDA":
+                total_size += rnn.get_model_size()
+            else:
+                total_size += rnn.cell.get_model_size()
         return total_size
 
     def normalize(self, mean, std):
@@ -130,15 +157,32 @@ def forward(self, input):
             input = (input - self.mean) / self.std
 
         rnn_in = input
-        for l in range(self.num_layers):
-            rnn = self.rnn_list[l]
-            model_output = rnn(rnn_in, hiddenState=self.hidden_states[l])
-            self.hidden_states[l] = model_output.detach()[-1, :, :]
+        if self.rnn_name == "FastGRNNCUDA":
             if self.tracking:
-                weights = rnn.getVars()
-                model_output = onnx_exportable_rnn(rnn_in, weights,
-                                                   rnn.cell, output=model_output)
-            rnn_in = model_output
+                for l in range(self.num_layers):
+                    print("Layer: ", l)
+                    rnn_ = self.rnn_list_[l]
+                    model_output = rnn_(rnn_in, hiddenState=self.hidden_states[l])
+                    self.hidden_states[l] = model_output.detach()[-1, :, :]
+                    weights = self.rnn_list[l].getVars()
+                    weights = [weight.clone() for weight in weights]
+                    model_output = onnx_exportable_rnn(rnn_in, weights, rnn_.cell, output=model_output)
+                    rnn_in = model_output
+            else:
+                for l in range(self.num_layers):
+                    rnn = self.rnn_list[l]
+                    model_output = rnn(rnn_in, hiddenState=self.hidden_states[l])
+                    self.hidden_states[l] = model_output.detach()[-1, :, :]
+                    rnn_in = model_output
+        else:
+            for l in range(self.num_layers):
+                rnn = self.rnn_list[l]
+                model_output = rnn(rnn_in, hiddenState=self.hidden_states[l])
+                self.hidden_states[l] = model_output.detach()[-1, :, :]
+                if self.tracking:
+                    weights = rnn.getVars()
+                    model_output = onnx_exportable_rnn(rnn_in, weights, rnn.cell, output=model_output)
+                rnn_in = model_output
 
         if self.linear:
             model_output = self.hidden2keyword(model_output[-1, :, :])
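An aside on the triple-`getattr` lookup in the constructor above: it is a deferred import of `edgeml_pytorch.graph.rnn.FastGRNN`, resolved only when the `FastGRNNCUDA` branch runs. A plainer equivalent, assuming the submodule is importable at that point, is:

```python
import importlib

# Resolve edgeml_pytorch.graph.rnn.FastGRNN lazily, only when needed.
RNN_ = importlib.import_module("edgeml_pytorch.graph.rnn").FastGRNN
```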

tf/README.md

Lines changed: 4 additions & 6 deletions
@@ -9,12 +9,10 @@ available in Tensorflow:
 3. [FastRNN & FastGRNN](/docs/publications/FastGRNN.pdf)
 4. [ProtoNN](/docs/publications/ProtoNN.pdf)
 
-The TensorFlow compute graphs for these algoriths are packaged as
-`edgeml_tf.graph`. Trainers for these algorithms are in `edgeml_tf.trainer`.
-Usage directions and examples for these algorithms are provided in
-`$EDGEML_ROOT/examples/tf` directory.
-To get started with any of the provided algorithms, please follow
-the notebooks in the `examples/tf` directory.
+The TensorFlow compute graphs for these algorithms are packaged as `edgeml_tf.graph`
+and trainers are in `edgeml_tf.trainer`. Usage directions and example notebooks for
+these algorithms are provided in the [examples/tf directory](/examples/tf).
+
 
 ## Installation
 