Skip to content

Commit 9229209

Browse files
authored
fix config for fluid.data (#37)
1 parent 3af2e21 commit 9229209

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

ppdet/data/data_feed.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def __init__(self,
453453
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
454454
'is_crowd'
455455
],
456-
image_shape=[None, 3, None, None],
456+
image_shape=[3, None, None],
457457
sample_transforms=[
458458
DecodeImage(to_rgb=True),
459459
RandomFlipImage(prob=0.5),
@@ -505,7 +505,7 @@ def __init__(self,
505505
COCO_VAL_IMAGE_DIR).__dict__,
506506
fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
507507
'gt_label', 'is_difficult'],
508-
image_shape=[None, 3, None, None],
508+
image_shape=[3, None, None],
509509
sample_transforms=[
510510
DecodeImage(to_rgb=True),
511511
NormalizeImage(mean=[0.485, 0.456, 0.406],
@@ -552,7 +552,7 @@ def __init__(self,
552552
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
553553
COCO_VAL_IMAGE_DIR).__dict__,
554554
fields=['image', 'im_info', 'im_id', 'im_shape'],
555-
image_shape=[None, 3, None, None],
555+
image_shape=[3, None, None],
556556
sample_transforms=[
557557
DecodeImage(to_rgb=True),
558558
NormalizeImage(mean=[0.485, 0.456, 0.406],
@@ -600,7 +600,7 @@ def __init__(self,
600600
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
601601
'is_crowd', 'gt_mask'
602602
],
603-
image_shape=[None, 3, None, None],
603+
image_shape=[3, None, None],
604604
sample_transforms=[
605605
DecodeImage(to_rgb=True),
606606
RandomFlipImage(prob=0.5, is_mask_flip=True),
@@ -646,7 +646,7 @@ def __init__(self,
646646
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
647647
COCO_VAL_IMAGE_DIR).__dict__,
648648
fields=['image', 'im_info', 'im_id', 'im_shape'],
649-
image_shape=[None, 3, None, None],
649+
image_shape=[3, None, None],
650650
sample_transforms=[
651651
DecodeImage(to_rgb=True),
652652
NormalizeImage(mean=[0.485, 0.456, 0.406],
@@ -698,7 +698,7 @@ def __init__(self,
698698
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
699699
COCO_VAL_IMAGE_DIR).__dict__,
700700
fields=['image', 'im_info', 'im_id', 'im_shape'],
701-
image_shape=[None, 3, None, None],
701+
image_shape=[3, None, None],
702702
sample_transforms=[
703703
DecodeImage(to_rgb=True),
704704
NormalizeImage(
@@ -743,7 +743,7 @@ class SSDTrainFeed(DataFeed):
743743
def __init__(self,
744744
dataset=VocDataSet().__dict__,
745745
fields=['image', 'gt_box', 'gt_label'],
746-
image_shape=[None, 3, 300, 300],
746+
image_shape=[3, 300, 300],
747747
sample_transforms=[
748748
DecodeImage(to_rgb=True, with_mixup=False),
749749
NormalizeBox(),
@@ -802,7 +802,7 @@ def __init__(
802802
dataset=VocDataSet(VOC_VAL_ANNOTATION).__dict__,
803803
fields=['image', 'im_shape', 'im_id', 'gt_box',
804804
'gt_label', 'is_difficult'],
805-
image_shape=[None, 3, 300, 300],
805+
image_shape=[3, 300, 300],
806806
sample_transforms=[
807807
DecodeImage(to_rgb=True, with_mixup=False),
808808
NormalizeBox(),
@@ -847,7 +847,7 @@ class SSDTestFeed(DataFeed):
847847
def __init__(self,
848848
dataset=SimpleDataSet(VOC_VAL_ANNOTATION).__dict__,
849849
fields=['image', 'im_id', 'im_shape'],
850-
image_shape=[None, 3, 300, 300],
850+
image_shape=[3, 300, 300],
851851
sample_transforms=[
852852
DecodeImage(to_rgb=True),
853853
ResizeImage(target_size=300, use_cv2=False, interp=1),
@@ -893,7 +893,7 @@ class YoloTrainFeed(DataFeed):
893893
def __init__(self,
894894
dataset=CocoDataSet().__dict__,
895895
fields=['image', 'gt_box', 'gt_label', 'gt_score'],
896-
image_shape=[None, 3, 608, 608],
896+
image_shape=[3, 608, 608],
897897
sample_transforms=[
898898
DecodeImage(to_rgb=True, with_mixup=True),
899899
MixupImage(alpha=1.5, beta=1.5),
@@ -955,7 +955,7 @@ def __init__(self,
955955
COCO_VAL_IMAGE_DIR).__dict__,
956956
fields=['image', 'im_size', 'im_id', 'gt_box',
957957
'gt_label', 'is_difficult'],
958-
image_shape=[None, 3, 608, 608],
958+
image_shape=[3, 608, 608],
959959
sample_transforms=[
960960
DecodeImage(to_rgb=True),
961961
ResizeImage(target_size=608, interp=2),
@@ -1013,7 +1013,7 @@ def __init__(self,
10131013
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
10141014
COCO_VAL_IMAGE_DIR).__dict__,
10151015
fields=['image', 'im_size', 'im_id'],
1016-
image_shape=[None, 3, 608, 608],
1016+
image_shape=[3, 608, 608],
10171017
sample_transforms=[
10181018
DecodeImage(to_rgb=True),
10191019
ResizeImage(target_size=608, interp=2),

ppdet/modeling/model_input.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141

4242
def create_feed(feed, iterable=False, sub_prog_feed=False):
43-
image_shape = feed.image_shape
43+
image_shape = [None] + feed.image_shape
4444
feed_var_map = {var['name']: var for var in feed_var_def}
4545
feed_var_map['image'] = {
4646
'name': 'image',
@@ -98,14 +98,14 @@ def create_feed(feed, iterable=False, sub_prog_feed=False):
9898
'lod_level': 0
9999
}
100100
image_name_list.append(name)
101-
feed_var_map['im_info']['shape'] = [feed.num_scale * 3]
101+
feed_var_map['im_info']['shape'] = [None, feed.num_scale * 3]
102102
feed.fields = image_name_list + feed.fields[1:]
103103
if sub_prog_feed:
104104
box_names = ['bbox', 'bbox_flip']
105105
for box_name in box_names:
106106
sub_prog_feed = {
107107
'name': box_name,
108-
'shape': [6],
108+
'shape': [None, 6],
109109
'dtype': 'float32',
110110
'lod_level': 1
111111
}

0 commit comments

Comments
 (0)