Commit 29bf95b

support 1dcnn distilling
Signed-off-by: priscilla-pan <pan.jiayi@zte.com.cn>
1 parent 385fcf2 commit 29bf95b

21 files changed: +519 −12 lines

README.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -65,6 +65,10 @@ The details are shown in the table below, and the code can refer to examples\res
 | + pruned + distill | 76.39 | 6954152 ( 72.8% pruned) | 1075M | 27M|
 | + pruned + distill + quantization(TF-Lite) | 75.938 | - | - | 7.1M|
 
+We also implement a 1D-CNN distillation, which shows that distillation is also effective for Encrypted Traffic
+Classification. You can get detailed instructions [here](doc/CNN-1D-tiny-Distillation.md). Following these
+instructions, you can build your own dataset and model to train and distill under the Adlik model optimizer.
+
 ## 1. Pruning and quantization principle
 
 ### 1.1 Filter pruning
```

doc/CNN-1D-tiny-Distillation.md

Lines changed: 113 additions & 0 deletions (new file)
# Tiny 1D-CNN Knowledge Distillation

The following uses a 1D-CNN trained on the 12-class session all dataset as the teacher model to illustrate how to use
the model optimizer to improve the performance of a tiny 1D-CNN by knowledge distillation.

The 1D-CNN model is from Wang's paper [Wang, W.; Zhu, M.; Wang, J.; Zeng, X.; Yang, Z. End-to-end encrypted traffic
classification with one-dimensional convolution neural networks]. The tiny 1D-CNN model is a slim version of that
1D-CNN model. Using the 1D-CNN model as the teacher to distill the tiny 1D-CNN model improves accuracy by 5.66
percentage points.

The details are shown in the table below, and the code can refer to examples/cnn1d_tiny_iscx_session_all_distill.py.

| Model                | Accuracy | Params  | Model Size |
| -------------------- | -------- | ------- | ---------- |
| cnn1d                | 92.67%   | 5832588 | 23M        |
| cnn1d_tiny           | 87.62%   | 134988  | 546K       |
| cnn1d_tiny + distill | 93.28%   | 134988  | 546K       |
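The distillation step trains the student against a mix of the ground-truth labels and the teacher's
temperature-softened predictions. As a rough sketch of that loss (the standard Hinton-style formulation; the
`temperature` and `alpha` values below are illustrative assumptions, not the optimizer's actual configuration):

```python
import tensorflow as tf


def distillation_loss(labels, student_logits, teacher_logits,
                      temperature=4.0, alpha=0.3):
    """Hinton-style knowledge-distillation loss (sketch only)."""
    # Hard loss: ordinary cross-entropy against the integer class labels.
    hard = tf.keras.losses.sparse_categorical_crossentropy(
        labels, student_logits, from_logits=True)
    # Soft loss: cross-entropy against the temperature-softened teacher output.
    soft = tf.keras.losses.categorical_crossentropy(
        tf.nn.softmax(teacher_logits / temperature),
        student_logits / temperature, from_logits=True)
    # The T**2 factor keeps the soft-gradient scale comparable across temperatures.
    return alpha * hard + (1.0 - alpha) * (temperature ** 2) * soft
```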
## 1 Create custom dataset

Using the [ISCX dataset](https://www.unb.ca/cic/datasets/vpn.html), you can get the processed 12-class session all
dataset from [Wang's GitHub](https://github.com/echowei/DeepTraffic/blob/master/2.encrypted_traffic_classification/3.PerprocessResults/12class.zip).
We name the dataset iscx_session_all. It contains 35501 training samples of shape (35501, 28, 28) and 3945 testing
samples.

Now that you have the dataset, you can implement your custom dataset by extending
model_optimizer.pruner.dataset.dataset_base.DatasetBase and implementing:

1. \__init__, required, where you do all dataset initialization
2. parse_fn, required, the map function of the dataset
3. parse_fn_distill, required, the map function of the dataset used in distillation
4. build, optional, the process of building the dataset. If your dataset is not in tfrecord format, you must
implement this function.

Here in the custom dataset, we reshape the samples from (None, 28, 28, 1) to (None, 1, 784, 1) for the following
1D-CNN models. A minimal skeleton is sketched after the registration steps below; the full implementation is
src/model_optimizer/pruner/dataset/iscx_session_all.py.

After that, all you need to do is register the dataset name in the following files:

1. src/model_optimizer/pruner/config_schema.json: the "enum" list
2. src/model_optimizer/pruner/dataset/\__init__.py: add the dataset name in line 19 and add the dataset instance in
the if-else clause.
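A minimal sketch of such a dataset class, assuming the interface described above (the class name and the stand-in
loading code are hypothetical):

```python
import numpy as np
import tensorflow as tf

from .dataset_base import DatasetBase


class MyDataset(DatasetBase):  # hypothetical; see iscx_session_all.py for the real one
    def __init__(self, config, is_training, num_shards=1, shard_index=0):
        super().__init__(config, is_training, num_shards, shard_index)
        self.num_samples_of_train = 35501
        self.num_samples_of_val = 3945
        self.data_shape = (1, 784, 1)  # samples reshaped for the 1D-CNN models

    def parse_fn(self, *content):
        data, label = content
        return data, label

    def parse_fn_distill(self, *content):
        # Distillation feeds the label in alongside the image.
        image, label = self.parse_fn(*content)
        return {"image": image, "label": label}, {}

    def build(self, is_distill=False):
        # Required here because the data is not in tfrecord format.
        x_data = np.zeros((100, 1, 784), np.uint8)  # stand-in for real loading
        y_data = np.zeros((100,), np.uint8)
        dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
        parse = self.parse_fn_distill if is_distill else self.parse_fn
        dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        return self.build_batch(dataset)
```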
## 2 Create custom model

Create your own model using the Keras functional API in model_optimizer.pruner.models.

After that, all you need to do is register the model name and initialize the model in the following file (a sketch
follows this list):

1. src/model_optimizer/pruner/models/\__init__.py: add the model name in line 21 and add the model instance in the
if-else clause.
51+
Implement your own learner by extending model_optimizer.prunner.learner.learner_base.LearnerBase and implementing:
52+
1. \__init__, required, where you can define your own learning rate callback
53+
2. get_optimizer, required, where you can define your own optimizer
54+
3. get_losses, required, where you can define your own loss function
55+
4. get_metrics, required, where you can define your own metrics
56+
57+
After that, all you need is put the model name and dataset name and initialize the learner in the following files:
58+
1. src/model_optimizer/prunner/learner/\__init__.py
59+
60+
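A rough sketch of such a learner, assuming Keras-style return values. The class name is hypothetical, and the exact
constructor and method signatures are defined by learner_base.LearnerBase, so check that base class first:

```python
import tensorflow as tf

from .learner_base import LearnerBase


class MyLearner(LearnerBase):  # hypothetical name
    def __init__(self, config):
        # The real base class may expect additional arguments.
        super().__init__(config)

    def get_optimizer(self):
        return tf.keras.optimizers.Adam(learning_rate=1e-3)

    def get_losses(self):
        # iscx_session_all labels are integer class ids.
        return tf.keras.losses.SparseCategoricalCrossentropy()

    def get_metrics(self):
        return ['accuracy']
```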
## 4 Create the training process of the teacher model, and train the teacher model

Enter the examples directory and create cnn1d_iscx_session_all_train.py for the cnn1d model.

> Note
>
> The "model_name" and "dataset" in the request must be the same as you defined before.

Execute:

```shell
cd examples
python3 cnn1d_iscx_session_all_train.py
```

After execution, the default checkpoint files will be generated in ./models_ckpt/cnn1d, and the inference
checkpoint file will be generated in ./models_eval_ckpt/cnn1d. You can also modify the checkpoint_path
and checkpoint_eval_path of the cnn1d_iscx_session_all_train.py file to change the generated file paths.
## 5 Convert the teacher model to logits output

Enter the tools directory and execute:

```shell
cd tools
python3 convert_softmax_model_to_logits.py
```

After execution, the logits checkpoint of the teacher model will be generated as
examples/models_eval_ckpt/cnn1d/checkpoint-60-logits.h5.
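Conceptually, the conversion re-wires the trained model to emit its pre-softmax values, which distillation needs. A
rough sketch of that idea, assuming the softmax is a separate final layer and that the trained checkpoint is named
checkpoint-60.h5 (both are assumptions; the actual logic lives in convert_softmax_model_to_logits.py):

```python
import tensorflow as tf

# Load the softmax-output teacher (file name assumed from epochs=60).
model = tf.keras.models.load_model('../examples/models_eval_ckpt/cnn1d/checkpoint-60.h5')
# Re-wire the graph to stop at the input of the final softmax layer.
logits_model = tf.keras.Model(inputs=model.input, outputs=model.layers[-1].input)
logits_model.save('../examples/models_eval_ckpt/cnn1d/checkpoint-60-logits.h5')
```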
## 6 Create the distilling process and distill the cnn1d_tiny model

Create the configuration file in src/model_optimizer/pruner/scheduler/distill, like "cnn1d_tiny_0.3.yaml", where the
distillation parameters are configured.
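As an illustration of the kind of parameters such a file carries (the field names below are hypothetical; consult the
existing files under src/model_optimizer/pruner/scheduler/distill for the real schema):

```yaml
# Hypothetical sketch only -- not the actual cnn1d_tiny_0.3.yaml.
distill:
  teacher_model: cnn1d
  teacher_checkpoint: checkpoint-60-logits.h5
  alpha: 0.3       # presumably the "0.3" in the file name
  temperature: 4
```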
Enter the examples directory and create cnn1d_tiny_iscx_session_all_distill.py for the cnn1d_tiny model. In the
distilling process, the teacher is cnn1d and the student is cnn1d_tiny.

> Note
>
> The "model_name" and "dataset" in the request must be the same as you defined before.

```shell
python3 cnn1d_tiny_iscx_session_all_distill.py
```

After execution, the default checkpoint files will be generated in ./models_ckpt/cnn1d_tiny_distill, and the inference
checkpoint file will be generated in ./models_eval_ckpt/cnn1d_tiny_distill. You can also modify the checkpoint_path and
checkpoint_eval_path of the cnn1d_tiny_iscx_session_all_distill.py file to change the generated file paths.

> Note
>
> i. The model in the checkpoint_path is not the pure cnn1d_tiny model; it is a hybrid of cnn1d_tiny (student) and
> cnn1d (teacher).
>
> ii. The model in the checkpoint_eval_path is the distilled model, i.e. the pure cnn1d_tiny model.
examples/cnn1d_iscx_session_all_train.py

Lines changed: 36 additions & 0 deletions (new file)

```python
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a cnn1d model on the iscx_session_all dataset
"""
import os
# If you did not execute setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model  # noqa: E402


def _main():
    base_dir = os.path.dirname(__file__)
    request = {
        "dataset": "iscx_session_all",
        "model_name": "cnn1d",
        "data_dir": "/data/12class/SessionAllLayers",
        "batch_size": 500,
        "batch_size_val": 100,
        "learning_rate": 1e-3,
        "epochs": 60,
        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d"),
        "checkpoint_save_period": 1,  # save a checkpoint every epoch
        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d"),
        "scheduler": "train"
    }
    prune_model(request)


if __name__ == "__main__":
    _main()
```
examples/cnn1d_tiny_iscx_session_all_distill.py

Lines changed: 37 additions & 0 deletions (new file)

```python
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Distill a cnn1d_tiny model from a trained cnn1d model on the iscx_session_all dataset
"""
import os
# If you did not execute setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model  # noqa: E402


def _main():
    base_dir = os.path.dirname(__file__)
    request = {
        "dataset": "iscx_session_all",
        "model_name": "cnn1d_tiny",
        "data_dir": "/data/12class/SessionAllLayers",
        "batch_size": 500,
        "batch_size_val": 100,
        "learning_rate": 1e-3,
        "epochs": 200,
        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny_distill"),
        "checkpoint_save_period": 10,  # save a checkpoint every 10 epochs
        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny_distill"),
        "scheduler": "distill",
        "scheduler_file_name": "cnn1d_tiny_0.3.yaml"
    }
    prune_model(request)


if __name__ == "__main__":
    _main()
```
examples/cnn1d_tiny_iscx_session_all_train.py

Lines changed: 36 additions & 0 deletions (new file)

```python
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a cnn1d_tiny model on the iscx_session_all dataset
"""
import os
# If you did not execute setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model  # noqa: E402


def _main():
    base_dir = os.path.dirname(__file__)
    request = {
        "dataset": "iscx_session_all",
        "model_name": "cnn1d_tiny",
        "data_dir": "/data/12class/SessionAllLayers",
        "batch_size": 500,
        "batch_size_val": 100,
        "learning_rate": 1e-3,
        "epochs": 60,
        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/cnn1d_tiny"),
        "checkpoint_save_period": 10,  # save a checkpoint every 10 epochs
        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/cnn1d_tiny"),
        "scheduler": "train"
    }
    prune_model(request)


if __name__ == "__main__":
    _main()
```

src/model_optimizer/pruner/config_schema.json

Lines changed: 2 additions & 1 deletion

```diff
@@ -6,7 +6,8 @@
       "enum": [
         "mnist",
         "cifar10",
-        "imagenet"
+        "imagenet",
+        "iscx_session_all"
       ],
       "description": "dataset name"
     },
```

src/model_optimizer/pruner/dataset/__init__.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -16,7 +16,7 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
     :return: class of Dataset
     """
     dataset_name = config.get_attribute('dataset')
-    if dataset_name not in ['mnist', 'cifar10', 'imagenet']:
+    if dataset_name not in ['mnist', 'cifar10', 'imagenet', 'iscx_session_all']:
         raise Exception('Not support dataset %s' % dataset_name)
     if dataset_name == 'mnist':
         from .mnist import MnistDataset
@@ -27,5 +27,8 @@ def get_dataset(config, is_training, num_shards=1, shard_index=0):
     elif dataset_name == 'imagenet':
         from .imagenet import ImagenetDataset
         return ImagenetDataset(config, is_training, num_shards, shard_index)
+    elif dataset_name == 'iscx_session_all':
+        from .iscx_session_all import ISCXDataset
+        return ISCXDataset(config, is_training, num_shards, shard_index)
     else:
         raise Exception('Not support dataset {}'.format(dataset_name))
```

src/model_optimizer/pruner/dataset/cifar10.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -31,6 +31,7 @@ def __init__(self, config, is_training):
         self.buffer_size = 10000
         self.num_samples_of_train = 50000
         self.num_samples_of_val = 10000
+        self.data_shape = (32, 32, 3)
 
     # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
     def parse_fn(self, example_serialized):
```

src/model_optimizer/pruner/dataset/dataset_base.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -78,9 +78,9 @@ def build(self, is_distill=False):
             dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
         else:
             dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-        return self.__build_batch(dataset)
+        return self.build_batch(dataset)
 
-    def __build_batch(self, dataset):
+    def build_batch(self, dataset):
         """
         Make an batch from tf.data.Dataset.
         :param dataset: tf.data.Dataset object
```
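This rename is what lets a subclass override build and still call the batching helper: double-underscore method names
are mangled per class, so self.__build_batch inside the new ISCXDataset.build (below) would resolve to
_ISCXDataset__build_batch and fail. A minimal demonstration of the pitfall:

```python
class Base:
    def __helper(self):               # mangled to _Base__helper
        return 'batched'


class Child(Base):
    def build(self):
        return self.__helper()        # looks up _Child__helper, which does not exist


Child().build()  # AttributeError: 'Child' object has no attribute '_Child__helper'
```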

src/model_optimizer/pruner/dataset/imagenet.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -32,6 +32,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0):
         self.buffer_size = 10000
         self.num_samples_of_train = 1281167
         self.num_samples_of_val = 50000
+        self.data_shape = (224, 224, 3)
 
     # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
     def parse_fn(self, example_serialized):
```
src/model_optimizer/pruner/dataset/iscx_session_all.py

Lines changed: 85 additions & 0 deletions (new file)

```python
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
ISCX 12-class session all dataset
https://github.com/echowei/DeepTraffic/blob/master/2.encrypted_traffic_classification/3.PerprocessResults/12class.zip
"""
import os
import gzip
import tensorflow as tf
import numpy as np
from .dataset_base import DatasetBase


class ISCXDataset(DatasetBase):
    """
    ISCX session all layer dataset
    """
    def __init__(self, config, is_training, num_shards=1, shard_index=0):
        """
        Constructor function.
        :param config: Config object
        :param is_training: whether to construct the training subset
        :return:
        """
        super().__init__(config, is_training, num_shards, shard_index)
        if not is_training:
            self.batch_size = self.batch_size_val
        self.buffer_size = 5000
        self.num_samples_of_train = 35501
        self.num_samples_of_val = 3945
        self.data_shape = (1, 784, 1)

    # pylint: disable=R0201
    # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
    def parse_fn(self, *content):
        data, label = content
        return data, label

    def parse_fn_distill(self, *content):
        """
        Parse dataset for distillation
        :param content: item content of the dataset
        :return: {image, label}, {}
        """
        image, label = self.parse_fn(*content)
        inputs = {"image": image, "label": label}
        targets = {}
        return inputs, targets

    def build(self, is_distill=False):
        """
        Build dataset
        :param is_distill: is distilling or not
        :return: batch of the dataset
        """
        # The dataset ships in the gzip-compressed MNIST idx format.
        if self.is_training:
            x_path = os.path.join(self.data_dir, 'train-images-idx3-ubyte.gz')
            y_path = os.path.join(self.data_dir, 'train-labels-idx1-ubyte.gz')
        else:
            x_path = os.path.join(self.data_dir, 't10k-images-idx3-ubyte.gz')
            y_path = os.path.join(self.data_dir, 't10k-labels-idx1-ubyte.gz')

        with gzip.open(y_path, 'rb') as lbpath:
            y_data = np.frombuffer(lbpath.read(), np.uint8, offset=8)

        with gzip.open(x_path, 'rb') as imgpath:
            x_data = np.frombuffer(
                imgpath.read(), np.uint8, offset=16).reshape(len(y_data), 1, 784)

        dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))

        if self.num_shards != 1:
            dataset = dataset.shard(num_shards=self.num_shards, index=self.shard_index)
        if self.is_training:
            dataset = dataset.shuffle(buffer_size=self.buffer_size).repeat()
        if is_distill:
            dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        else:
            dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        return self.build_batch(dataset)
```

src/model_optimizer/pruner/dataset/mnist.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -31,6 +31,7 @@ def __init__(self, config, is_training):
         self.buffer_size = 10000
         self.num_samples_of_train = 60000
         self.num_samples_of_val = 10000
+        self.data_shape = (28, 28, 1)
 
     # pylint: disable=R0201
     # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
```
