@@ -464,7 +464,29 @@ def get_sub_ops(op, op_name, all_ops, all_vars):
464
464
all_ops [sub_op_name ]['is_leaf_node' ] = True
465
465
now_var = utils .gen_var_name (sub_op .results ())
466
466
for source in sub_op .operands_source ():
467
+ if str (source .type ()) == '<<NULL TYPE>>' :
468
+ continue
467
469
input_name = utils .gen_var_name (source )
470
+ if input_name not in all_vars .keys ():
471
+ all_vars [input_name ] = {}
472
+ all_vars [input_name ]['name' ] = input_name
473
+ try :
474
+ attrs = source .results ()[0 ].get_defining_op ().attrs ()
475
+ if 'place' in attrs :
476
+ attrs ['place' ] = str (attrs ['place' ])
477
+ attrs ['dtype' ] = safe_get_dtype (source )
478
+ except Exception :
479
+ attrs = {}
480
+
481
+ all_vars [input_name ]['shape' ] = safe_get_shape (source )
482
+ all_vars [input_name ]['type' ] = safe_get_type (source )
483
+ all_vars [input_name ]['dtype' ] = safe_get_dtype (source )
484
+ all_vars [input_name ]['value' ] = []
485
+ all_vars [input_name ]['persistable' ] = safe_get_persistable (source )
486
+ all_vars [input_name ]['attrs' ] = attrs
487
+ all_vars [input_name ]['from_node' ] = ''
488
+ all_vars [input_name ]['to_nodes' ] = []
489
+
468
490
if sub_op .name () == "pd_op.increment_" :
469
491
all_vars [now_var ]['to_nodes' ].append (all_vars [input_name ]['from_node' ])
470
492
all_ops [all_vars [input_name ]['from_node' ]]['input_vars' ][now_var ] = [now_var ]
@@ -633,7 +655,29 @@ def analyse_pir(program):
633
655
all_ops [op_name ]['is_leaf_node' ] = True
634
656
now_var = utils .gen_var_name (op .results ())
635
657
for source in op .operands_source ():
658
+ if str (source .type ()) == '<<NULL TYPE>>' :
659
+ continue
636
660
input_name = utils .gen_var_name (source )
661
+ if input_name not in all_vars .keys ():
662
+ all_vars [input_name ] = {}
663
+ all_vars [input_name ]['name' ] = input_name
664
+ try :
665
+ attrs = source .results ()[0 ].get_defining_op ().attrs ()
666
+ if 'place' in attrs :
667
+ attrs ['place' ] = str (attrs ['place' ])
668
+ attrs ['dtype' ] = safe_get_dtype (source )
669
+ except Exception :
670
+ attrs = {}
671
+
672
+ all_vars [input_name ]['shape' ] = safe_get_shape (source )
673
+ all_vars [input_name ]['type' ] = safe_get_type (source )
674
+ all_vars [input_name ]['dtype' ] = safe_get_dtype (source )
675
+ all_vars [input_name ]['value' ] = []
676
+ all_vars [input_name ]['persistable' ] = safe_get_persistable (source )
677
+ all_vars [input_name ]['attrs' ] = attrs
678
+ all_vars [input_name ]['from_node' ] = ''
679
+ all_vars [input_name ]['to_nodes' ] = []
680
+
637
681
if op .name () == "pd_op.increment_" :
638
682
all_vars [now_var ]['to_nodes' ].append (all_vars [input_name ]['from_node' ])
639
683
all_ops [all_vars [input_name ]['from_node' ]]['input_vars' ][now_var ] = [now_var ]
0 commit comments