Commit de089fa

support 1dcnn distilling

Signed-off-by: priscilla-pan <pan.jiayi@zte.com.cn>

1 parent 385fcf2 commit de089fa

27 files changed: +547 -19 lines changed

README.md

Lines changed: 4 additions & 0 deletions

@@ -65,6 +65,10 @@ The details are shown in the table below, and the code can refer to examples\res
 | + pruned + distill | 76.39 | 6954152 ( 72.8% pruned) | 1075M | 27M|
 | + pruned + distill + quantization(TF-Lite) | 75.938 | - | - | 7.1M|
 
+We also implement a 1D-CNN distillation, which shows that distillation is also effective for Encrypted Traffic
+Classification. You can find detailed instructions [here](doc/CNN-1D-tiny-Distillation.md). Following these
+instructions, you can build your own dataset and model to train and distill with the Adlik model optimizer.
+
 ## 1. Pruning and quantization principle
 
 ### 1.1 Filter pruning

doc/CNN-1D-tiny-Distillation.md

Lines changed: 123 additions & 0 deletions
# Tiny 1D-CNN Knowledge Distillation

The following uses a 1D-CNN on the 12-classes-session-all dataset as the teacher model to illustrate how to use the model
optimizer to improve the performance of a tiny 1D-CNN by knowledge distillation.

The 1D-CNN model is from Wang's paper [Wang, W.; Zhu, M.; Wang, J.; Zeng, X.; Yang, Z. End-to-end encrypted traffic
classification with one-dimensional convolution neural networks.]. The tiny 1D-CNN model is a slim version of that
1D-CNN model. Using the 1D-CNN model as the teacher to distill the tiny 1D-CNN model, accuracy improves by 5.66%.

The details are shown in the table below; the code is in examples/cnn1d_tiny_iscx_session_all_distill.py.

| Model                | Accuracy | Params  | Model Size |
| -------------------- | -------- | ------- | ---------- |
| cnn1d                | 92.67%   | 5832588 | 23M        |
| cnn1d_tiny           | 87.62%   | 134988  | 546K       |
| cnn1d_tiny + distill | 93.28%   | 134988  | 546K       |

## 1 Create custom dataset

Using the [ISCX dataset](https://www.unb.ca/cic/datasets/vpn.html), you can get the processed 12-classes-session-all dataset
from [wang's github](https://github.com/echowei/DeepTraffic/blob/master/2.encrypted_traffic_classification/3.PerprocessResults/12class.zip).
We name this dataset iscx_session_all. It contains 35501 training samples of shape (35501, 28, 28) and 3945 test samples.

Now that you have the dataset, implement your custom dataset by extending
model_optimizer.pruner.dataset.dataset_base.DatasetBase and implementing:

1. \__init__, required, where you do all dataset initialization
2. parse_fn, required, the map function of the dataset
3. parse_fn_distill, required, the map function of the dataset used in distillation
4. build, optional, the process that builds the dataset. If your dataset is not in tfrecord format, you must
implement this function.

Here, in the custom dataset, we reshape the samples from (None, 28, 28, 1) to (None, 1, 784, 1) for the following 1D-CNN
models, as in the sketch below.
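
A minimal sketch of such a dataset class follows. The tfrecord feature keys (`image`, `label`) and the
`parse_fn_distill` behavior are assumptions for illustration; take the exact contract from dataset_base.DatasetBase.

```python
# Sketch only: feature keys and the distill-label format are assumptions,
# not the repository's exact contract.
import tensorflow as tf

from .dataset_base import DatasetBase


class ISCXDataset(DatasetBase):
    def __init__(self, config, is_training, num_shards=1, shard_index=0):
        super().__init__(config, is_training, num_shards, shard_index)
        self.num_samples_of_train = 35501
        self.num_samples_of_val = 3945
        self.data_shape = (1, 784, 1)  # (28, 28) samples flattened for the 1D-CNNs

    def parse_fn(self, example_serialized):
        # Map one serialized example to a (feature, label) pair.
        features = tf.io.parse_single_example(
            example_serialized,
            {'image': tf.io.FixedLenFeature([784], tf.float32),
             'label': tf.io.FixedLenFeature([], tf.int64)})
        image = tf.reshape(features['image'], self.data_shape)
        return image, features['label']

    def parse_fn_distill(self, example_serialized):
        # Same features; the distill scheduler pairs them with the
        # teacher's soft targets during training.
        return self.parse_fn(example_serialized)
```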

After that, all you need to do is register the dataset name in the following files:

1. src/model_optimizer/pruner/config_schema.json: add the dataset name to the "enum" list
2. src/model_optimizer/pruner/dataset/\__init__.py: add the dataset name in Line 19 and add the dataset instance in the
if-else clause.

## 2 Create custom model

Create your own model using the Keras functional API in model_optimizer.pruner.models.
Here we create the cnn1d model and the cnn1d_tiny model in model_optimizer.pruner.models.cnn1d.
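
As an illustration, a functional-API 1D-CNN might look like the sketch below; the layer widths and kernel sizes are
placeholders, not the exact cnn1d_tiny architecture from the paper.

```python
# Illustrative functional-API model; layer sizes are placeholders.
import tensorflow as tf


def get_cnn1d_tiny(input_shape=(1, 784, 1), num_classes=12):
    inputs = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Reshape((784, 1))(inputs)  # 784 bytes as a 1-D sequence
    x = tf.keras.layers.Conv1D(16, 25, activation='relu', padding='same')(x)
    x = tf.keras.layers.MaxPooling1D(3)(x)
    x = tf.keras.layers.Conv1D(32, 25, activation='relu', padding='same')(x)
    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name='cnn1d_tiny')
```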

After that, all you need to do is register the model name and initialize the model in the following file:

1. src/model_optimizer/pruner/models/\__init__.py: add the model name in Line 21 and add the model instance in the
if-else clause.

## 3 Create custom learner

Implement your own learner by extending model_optimizer.pruner.learner.learner_base.LearnerBase and implementing
the methods listed here (a sketch follows the list):

1. \__init__, required, where you can define your own learning rate callback
2. get_optimizer, required, where you can define your own optimizer
3. get_losses, required, where you can define your own loss function
4. get_metrics, required, where you can define your own metrics
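
The optimizer, loss, and callback choices below are illustrative only; check LearnerBase for the exact method
signatures.

```python
# Illustrative learner; base-class signatures may differ slightly.
import tensorflow as tf

from .learner_base import LearnerBase


class Cnn1dLearner(LearnerBase):
    def __init__(self, config):
        super().__init__(config)
        # A learning-rate callback of your choice, e.g. decay on plateau.
        self.callbacks = [tf.keras.callbacks.ReduceLROnPlateau(factor=0.5)]

    def get_optimizer(self):
        return tf.keras.optimizers.Adam(learning_rate=1e-3)

    def get_losses(self):
        return tf.keras.losses.SparseCategoricalCrossentropy()

    def get_metrics(self):
        return [tf.keras.metrics.SparseCategoricalAccuracy()]
```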

After that, register the model name and dataset name and initialize the learner in the following file:

1. src/model_optimizer/pruner/learner/\__init__.py

## 4 Create the training process of the teacher model, and train the teacher model

Enter the examples directory and create cnn1d_iscx_session_all_train.py for the cnn1d model.

> Note
>
> The "model_name" and "dataset" in the request must be the same as you defined before.

Execute:

```shell
cd examples
python3 cnn1d_iscx_session_all_train.py
```

After execution, the default checkpoint files will be generated in ./models_ckpt/cnn1d, and the inference
checkpoint file will be generated in ./models_eval_ckpt/cnn1d. You can also modify the checkpoint_path
and checkpoint_eval_path in cnn1d_iscx_session_all_train.py to change the generated file paths.

## 5 Convert the teacher model to logits output

Enter the tools directory and execute:

```shell
cd tools
python3 convert_softmax_model_to_logits.py
```

After execution, the default checkpoint file of the logits model will be generated at
examples/models_eval_ckpt/cnn1d/checkpoint-60-logits.h5.
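
Conceptually, this conversion replaces the final softmax activation with a linear one so the distillation loss can work
on raw logits. A minimal sketch of the idea, not the tool's actual implementation; the input checkpoint name is
assumed:

```python
# Conceptual sketch only: swap the trained model's final softmax for a
# linear activation and re-save it as a logits model. File names assumed.
import tensorflow as tf

model = tf.keras.models.load_model(
    '../examples/models_eval_ckpt/cnn1d/checkpoint-60.h5')
model.layers[-1].activation = tf.keras.activations.linear
model.save(
    '../examples/models_eval_ckpt/cnn1d/checkpoint-60-logits.h5')
```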

## 6 Create the distilling process and distill the cnn1d_tiny model

Create the configuration file in src/model_optimizer/pruner/scheduler/distill, e.g. "cnn1d_tiny_0.3.yaml", where the
distillation parameters are configured.
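
For background, knowledge distillation typically mixes a hard-label loss with a soft-target loss computed from the
teacher's temperature-scaled logits. The sketch below shows that standard combination; the weight `alpha` and
temperature `t` are illustrative and not necessarily the parameters configured in cnn1d_tiny_0.3.yaml.

```python
# Standard knowledge-distillation loss; alpha and t are illustrative.
import tensorflow as tf


def distillation_loss(labels, student_logits, teacher_logits, alpha=0.3, t=4.0):
    # Hard-label cross-entropy on the student's raw logits.
    hard = tf.keras.losses.sparse_categorical_crossentropy(
        labels, student_logits, from_logits=True)
    # Soft-target cross-entropy against the teacher's temperature-scaled logits.
    soft = tf.keras.losses.categorical_crossentropy(
        tf.nn.softmax(teacher_logits / t), student_logits / t, from_logits=True)
    # Weighted mix; t**2 keeps the soft-loss gradients on the same scale.
    return alpha * hard + (1.0 - alpha) * (t ** 2) * soft
```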

Enter the examples directory and create cnn1d_tiny_iscx_session_all_distill.py for the cnn1d_tiny model. In the
distilling process, the teacher is cnn1d and the student is cnn1d_tiny.

> Note
>
> The "model_name" and "dataset" in the request must be the same as you defined before.

```shell
python3 cnn1d_tiny_iscx_session_all_distill.py
```

After execution, the default checkpoint files will be generated in ./models_ckpt/cnn1d_tiny_distill, and the inference
checkpoint file will be generated in ./models_eval_ckpt/cnn1d_tiny_distill. You can also modify the checkpoint_path and
checkpoint_eval_path in cnn1d_tiny_iscx_session_all_distill.py to change the generated file paths.

> Note
>
> i. The model in checkpoint_path is not the pure cnn1d_tiny model; it is a hybrid of cnn1d_tiny (student) and
> cnn1d (teacher).
>
> ii. The model in checkpoint_eval_path is the distilled model, i.e. the pure cnn1d_tiny model.
examples/cnn1d_iscx_session_all_train.py

Lines changed: 36 additions & 0 deletions
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a cnn1d model on the iscx_session_all dataset
"""
import os
# If you did not execute setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model  # noqa: E402


def _main():
    base_dir = os.path.dirname(__file__)
    request = {
        "dataset": "iscx_session_all",
        "model_name": "cnn1d",
        "data_dir": "/data/12class/SessionAllLayers",
        "batch_size": 500,
        "batch_size_val": 100,
        "learning_rate": 1e-3,
        "epochs": 60,
        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d"),
        "checkpoint_save_period": 1,  # save a checkpoint every epoch
        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d"),
        "scheduler": "train"
    }
    prune_model(request)


if __name__ == "__main__":
    _main()
examples/cnn1d_tiny_iscx_session_all_distill.py

Lines changed: 37 additions & 0 deletions
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Distill a cnn1d_tiny model from a trained cnn1d model on the iscx_session_all dataset
"""
import os
# If you did not execute setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model  # noqa: E402


def _main():
    base_dir = os.path.dirname(__file__)
    request = {
        "dataset": "iscx_session_all",
        "model_name": "cnn1d_tiny",
        "data_dir": "/data/12class/SessionAllLayers",
        "batch_size": 500,
        "batch_size_val": 100,
        "learning_rate": 1e-3,
        "epochs": 200,
        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny_distill"),
        "checkpoint_save_period": 10,  # save a checkpoint every 10 epochs
        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny_distill"),
        "scheduler": "distill",
        "scheduler_file_name": "cnn1d_tiny_0.3.yaml"
    }
    prune_model(request)


if __name__ == "__main__":
    _main()
examples/cnn1d_tiny_iscx_session_all_train.py

Lines changed: 36 additions & 0 deletions
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a cnn1d_tiny model on the iscx_session_all dataset
"""
import os
# If you did not execute setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model  # noqa: E402


def _main():
    base_dir = os.path.dirname(__file__)
    request = {
        "dataset": "iscx_session_all",
        "model_name": "cnn1d_tiny",
        "data_dir": "/data/12class/SessionAllLayers",
        "batch_size": 500,
        "batch_size_val": 100,
        "learning_rate": 1e-3,
        "epochs": 60,
        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny"),
        "checkpoint_save_period": 10,  # save a checkpoint every 10 epochs
        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny"),
        "scheduler": "train"
    }
    prune_model(request)


if __name__ == "__main__":
    _main()

src/model_optimizer/pruner/config.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def create_config_from_obj(obj) -> object:
     :return:
     """
     schema_path = os.path.join(os.path.dirname(__file__), 'config_schema.json')
-    with open(schema_path) as schema_file:
+    with open(schema_path, encoding='utf-8') as schema_file:
         body_schema = json.load(schema_file)
 
     jsonschema.validate(obj, body_schema)

src/model_optimizer/pruner/config_schema.json

Lines changed: 2 additions & 1 deletion

@@ -6,7 +6,8 @@
       "enum": [
         "mnist",
         "cifar10",
-        "imagenet"
+        "imagenet",
+        "iscx_session_all"
       ],
       "description": "dataset name"
     },

src/model_optimizer/pruner/core/pruner.py

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ def get_network(model):
     for i, layer in enumerate(model.layers):
         digraph.add_node(i, name=layer.name, type=str(type(layer)))
     for i, layer in enumerate(model.layers):
-        for j in range(0, len(model.layers)):
+        for j, _ in enumerate(model.layers):
             _inputs = model.layers[j].input
             if isinstance(_inputs, list):
                 for _input in _inputs:

src/model_optimizer/pruner/dataset/__init__.py

Lines changed: 4 additions & 1 deletion

@@ -16,7 +16,7 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
     :return: class of Dataset
     """
     dataset_name = config.get_attribute('dataset')
-    if dataset_name not in ['mnist', 'cifar10', 'imagenet']:
+    if dataset_name not in ['mnist', 'cifar10', 'imagenet', 'iscx_session_all']:
         raise Exception('Not support dataset %s' % dataset_name)
     if dataset_name == 'mnist':
         from .mnist import MnistDataset
@@ -27,5 +27,8 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
     elif dataset_name == 'imagenet':
         from .imagenet import ImagenetDataset
         return ImagenetDataset(config, is_training, num_shards, shard_index)
+    elif dataset_name == 'iscx_session_all':
+        from .iscx_session_all import ISCXDataset
+        return ISCXDataset(config, is_training, num_shards, shard_index)
     else:
         raise Exception('Not support dataset {}'.format(dataset_name))

src/model_optimizer/pruner/dataset/cifar10.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ def __init__(self, config, is_training):
         self.buffer_size = 10000
         self.num_samples_of_train = 50000
         self.num_samples_of_val = 10000
+        self.data_shape = (32, 32, 3)
 
     # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
     def parse_fn(self, example_serialized):

src/model_optimizer/pruner/dataset/dataset_base.py

Lines changed: 2 additions & 2 deletions

@@ -78,9 +78,9 @@ def build(self, is_distill=False):
             dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
         else:
             dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-        return self.__build_batch(dataset)
+        return self.build_batch(dataset)
 
-    def __build_batch(self, dataset):
+    def build_batch(self, dataset):
         """
         Make a batch from tf.data.Dataset.
         :param dataset: tf.data.Dataset object

src/model_optimizer/pruner/dataset/imagenet.py

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0):
         self.buffer_size = 10000
         self.num_samples_of_train = 1281167
         self.num_samples_of_val = 50000
+        self.data_shape = (224, 224, 3)
 
     # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
     def parse_fn(self, example_serialized):
