Skip to content

Commit b974acc

Browse files
committed
Resolved DALI Bug: Fixed issue #2235
DALI's inability to read InsightFace style rec by implementing the script 'scripts/shuffle_rec.py' to generate shuffled recs.
1 parent cff32e0 commit b974acc

File tree

2 files changed

+92
-1
lines changed

2 files changed

+92
-1
lines changed

recognition/arcface_torch/README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 -
3434
```
3535

3636
Node 1:
37-
37+
3838
```shell
3939
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
4040
```
@@ -52,6 +52,16 @@ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 -
5252
- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
5353
- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
5454

55+
56+
Note:
57+
If you want to use DALI for data reading, please use the script 'scripts/shuffle_rec.py' to shuffle the InsightFace style rec before using it.
58+
Example:
59+
60+
`python scripts/shuffle_rec.py ms1m-retinaface-t1`
61+
62+
You will get the "shuffled_ms1m-retinaface-t1" folder, where the samples in the "train.rec" file are shuffled.
63+
64+
5565
## Model Zoo
5666

5767
- The models are available for non-commercial research purposes only.
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import argparse
2+
import multiprocessing
3+
import os
4+
import time
5+
6+
import mxnet as mx
7+
import numpy as np
8+
9+
10+
def read_worker(args, q_in):
11+
path_imgidx = os.path.join(args.input, "train.idx")
12+
path_imgrec = os.path.join(args.input, "train.rec")
13+
imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
14+
15+
s = imgrec.read_idx(0)
16+
header, _ = mx.recordio.unpack(s)
17+
assert header.flag > 0
18+
19+
imgidx = np.array(range(1, int(header.label[0])))
20+
np.random.shuffle(imgidx)
21+
22+
for idx in imgidx:
23+
item = imgrec.read_idx(idx)
24+
q_in.put(item)
25+
26+
q_in.put(None)
27+
imgrec.close()
28+
29+
30+
def write_worker(args, q_out):
31+
pre_time = time.time()
32+
33+
if args.input[-1] == '/':
34+
args.input = args.input[:-1]
35+
dirname = os.path.dirname(args.input)
36+
basename = os.path.basename(args.input)
37+
output = os.path.join(dirname, f"shuffled_{basename}")
38+
os.makedirs(output, exist_ok=True)
39+
40+
path_imgidx = os.path.join(output, "train.idx")
41+
path_imgrec = os.path.join(output, "train.rec")
42+
save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w")
43+
more = True
44+
count = 0
45+
while more:
46+
deq = q_out.get()
47+
if deq is None:
48+
more = False
49+
else:
50+
header, jpeg = mx.recordio.unpack(deq)
51+
# TODO it is currently not fully developed
52+
if isinstance(header.label, float):
53+
label = header.label
54+
else:
55+
label = header.label[0]
56+
57+
header = mx.recordio.IRHeader(flag=header.flag, label=label, id=header.id, id2=header.id2)
58+
save_record.write_idx(count, mx.recordio.pack(header, jpeg))
59+
count += 1
60+
if count % 10000 == 0:
61+
cur_time = time.time()
62+
print('save time:', cur_time - pre_time, ' count:', count)
63+
pre_time = cur_time
64+
print(count)
65+
save_record.close()
66+
67+
68+
def main(args):
69+
queue = multiprocessing.Queue(10240)
70+
read_process = multiprocessing.Process(target=read_worker, args=(args, queue))
71+
read_process.daemon = True
72+
read_process.start()
73+
write_process = multiprocessing.Process(target=write_worker, args=(args, queue))
74+
write_process.start()
75+
write_process.join()
76+
77+
78+
if __name__ == '__main__':
79+
parser = argparse.ArgumentParser()
80+
parser.add_argument('input', help='path to source rec.')
81+
main(parser.parse_args())

0 commit comments

Comments
 (0)