Skip to content

Commit ecb675b

Browse files
authored
Client transpose (#44)
add client option to set transposition direction
1 parent a015cbe commit ecb675b

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

example_client/grpc_serving_client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,14 @@
3131
parser.add_argument('--grpc_address',required=False, default='localhost', help='Specify url to grpc service. default:localhost')
3232
parser.add_argument('--grpc_port',required=False, default=9000, help='Specify port to grpc service. default: 9000')
3333
parser.add_argument('--input_name',required=False, default='input', help='Specify input tensor name. default: input')
34-
parser.add_argument('--output_name',required=False, default='resnet_v1_50/predictions/Reshape_1', help='Specify output name. default: resnet_v1_50/predictions/Reshape_1')
34+
parser.add_argument('--output_name',required=False, default='resnet_v1_50/predictions/Reshape_1',
35+
help='Specify output name. default: resnet_v1_50/predictions/Reshape_1')
3536
parser.add_argument('--transpose_input', choices=["False", "True"], default="True",
36-
help='Set to False to skip NHWC->NCHW input transposing. default: True',
37+
help='Set to False to skip NHWC>NCHW or NCHW>NHWC input transposing. default: True',
3738
dest="transpose_input")
39+
parser.add_argument('--transpose_method', choices=["nchw2nhwc","nhwc2nchw"], default="nhwc2nchw",
40+
help="How the input transposition should be executed: nhwc2nchw or nhwc2nchw",
41+
dest="transpose_method")
3842
parser.add_argument('--iterations', default=0,
3943
help='Number of requests iterations, as default use number of images in numpy memmap. default: 0 (consume all frames)',
4044
dest='iterations', type=int)
@@ -78,10 +82,13 @@
7882
print('\tModel name: {}'.format(args.get('model_name')))
7983
print('\tIterations: {}'.format(iterations))
8084
print('\tImages numpy path: {}'.format(args.get('images_numpy_path')))
85+
if args.get('transpose_input') == "True":
86+
if args.get('transpose_method') == "nhwc2nchw":
87+
imgs = imgs.transpose((0,3,1,2))
88+
if args.get('transpose_method') == "nchw2nhwc":
89+
imgs = imgs.transpose((0,2,3,1))
8190
print('\tImages in shape: {}\n'.format(imgs.shape))
8291

83-
if args.get('transpose_input') == "True":
84-
imgs = imgs.transpose((0,3,1,2))
8592
iteration = 0
8693

8794
while iteration <= iterations:

0 commit comments

Comments
 (0)