Skip to content

Commit 9e5c2e2

Browse files
committed
Fix var bugs
1 parent 5fa3a02 commit 9e5c2e2

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

visualdl/component/graph/graph_component.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,29 @@ def get_sub_ops(op, op_name, all_ops, all_vars):
464464
all_ops[sub_op_name]['is_leaf_node'] = True
465465
now_var = utils.gen_var_name(sub_op.results())
466466
for source in sub_op.operands_source():
467+
if str(source.type()) == '<<NULL TYPE>>':
468+
continue
467469
input_name = utils.gen_var_name(source)
470+
if input_name not in all_vars.keys():
471+
all_vars[input_name] = {}
472+
all_vars[input_name]['name'] = input_name
473+
try:
474+
attrs = source.results()[0].get_defining_op().attrs()
475+
if 'place' in attrs:
476+
attrs['place'] = str(attrs['place'])
477+
attrs['dtype'] = safe_get_dtype(source)
478+
except Exception:
479+
attrs = {}
480+
481+
all_vars[input_name]['shape'] = safe_get_shape(source)
482+
all_vars[input_name]['type'] = safe_get_type(source)
483+
all_vars[input_name]['dtype'] = safe_get_dtype(source)
484+
all_vars[input_name]['value'] = []
485+
all_vars[input_name]['persistable'] = safe_get_persistable(source)
486+
all_vars[input_name]['attrs'] = attrs
487+
all_vars[input_name]['from_node'] = ''
488+
all_vars[input_name]['to_nodes'] = []
489+
468490
if sub_op.name() == "pd_op.increment_":
469491
all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node'])
470492
all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var]
@@ -633,7 +655,29 @@ def analyse_pir(program):
633655
all_ops[op_name]['is_leaf_node'] = True
634656
now_var = utils.gen_var_name(op.results())
635657
for source in op.operands_source():
658+
if str(source.type()) == '<<NULL TYPE>>':
659+
continue
636660
input_name = utils.gen_var_name(source)
661+
if input_name not in all_vars.keys():
662+
all_vars[input_name] = {}
663+
all_vars[input_name]['name'] = input_name
664+
try:
665+
attrs = source.results()[0].get_defining_op().attrs()
666+
if 'place' in attrs:
667+
attrs['place'] = str(attrs['place'])
668+
attrs['dtype'] = safe_get_dtype(source)
669+
except Exception:
670+
attrs = {}
671+
672+
all_vars[input_name]['shape'] = safe_get_shape(source)
673+
all_vars[input_name]['type'] = safe_get_type(source)
674+
all_vars[input_name]['dtype'] = safe_get_dtype(source)
675+
all_vars[input_name]['value'] = []
676+
all_vars[input_name]['persistable'] = safe_get_persistable(source)
677+
all_vars[input_name]['attrs'] = attrs
678+
all_vars[input_name]['from_node'] = ''
679+
all_vars[input_name]['to_nodes'] = []
680+
637681
if op.name() == "pd_op.increment_":
638682
all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node'])
639683
all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var]

0 commit comments

Comments
 (0)