Skip to content

Commit 6c9757b

Browse files
authored
[New Feature] Support Paddle PIR Visualization 🚀 (#1263)
* add new_ir visualization * refine some interface * fix pir to kwargs
1 parent e420b8c commit 6c9757b

File tree

7 files changed

+291
-157
lines changed

7 files changed

+291
-157
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
- id: check-yaml
1212
- id: check-symlinks
1313
- id: destroyed-symlinks
14-
- repo: https://gitlab.com/pycqa/flake8
14+
- repo: https://github.com/pycqa/flake8
1515
rev: 3.8.4
1616
hooks:
1717
- id: flake8

demo/components/pir_translate.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import paddle
2+
from paddle import ir
3+
4+
from visualdl import LogWriter
5+
6+
paddle.enable_static()
7+
8+
main_program, start_program = (
9+
paddle.static.Program(),
10+
paddle.static.Program(),
11+
)
12+
with paddle.static.program_guard(main_program, start_program):
13+
x = paddle.static.data("x", [1, 64, 64, 8], dtype="float32")
14+
y = paddle.static.data("y", [1, 64, 64, 8], dtype="float32")
15+
divide_out = paddle.divide(x, y)
16+
tanh_out = paddle.tanh(divide_out)
17+
conv2d = paddle.nn.Conv2D(8, 32, 1, bias_attr=False, data_format='NHWC')
18+
batch_norm = paddle.nn.BatchNorm(32, act='relu', data_layout='NHWC')
19+
out = batch_norm(conv2d(tanh_out))
20+
21+
newir_program = ir.translate_to_new_ir(main_program.desc)
22+
23+
with LogWriter(logdir="./log/program_test/") as writer:
24+
writer.add_graph(
25+
model=newir_program,
26+
input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')],
27+
verbose=True,
28+
is_pir=True)

visualdl/component/graph/exporter.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,24 @@
1717
import tempfile
1818

1919
from .graph_component import analyse_model
20+
from .graph_component import analyse_pir
2021
from .utils import create_opname_scope
2122
from .utils import print_model
2223

2324

24-
def translate_graph(model, input_spec, verbose=True):
25-
import paddle
25+
def translate_graph(model, input_spec, verbose=True, **kwargs):
26+
is_pir = kwargs.get('is_pir', False)
2627
with tempfile.TemporaryDirectory() as tmp:
27-
model._full_name = '{}[{}]'.format(model.__class__.__name__, "model")
28-
create_opname_scope(model)
29-
model = paddle.jit.to_static(model, input_spec)
30-
paddle.jit.save(model, os.path.join(tmp, 'temp'))
31-
model_data = open(os.path.join(tmp, 'temp.pdmodel'), 'rb').read()
32-
result = analyse_model(model_data)
28+
if (not is_pir):
29+
model._full_name = '{}[{}]'.format(model.__class__.__name__,
30+
"model")
31+
create_opname_scope(model)
32+
model = paddle.jit.to_static(model, input_spec)
33+
paddle.jit.save(model, os.path.join(tmp, 'temp'))
34+
model_data = open(os.path.join(tmp, 'temp.pdmodel'), 'rb').read()
35+
result = analyse_model(model_data)
36+
else:
37+
result = analyse_pir(model)
3338
if verbose:
3439
print_model(result)
3540
result = json.dumps(result, indent=2)

visualdl/component/graph/graph_component.py

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import pathlib
1818
import re
1919

20+
from . import utils
21+
2022
_graph_version = '1.0.0'
2123

2224

@@ -73,9 +75,8 @@ def create_non_leaf_nodes(parent_node_name, child_node_name, all_ops,
7375
if parent_node_name == '/': # root node
7476
return
7577
else:
76-
create_non_leaf_nodes(
77-
os.path.dirname(parent_node_name), parent_node_name, all_ops,
78-
general_children_dict)
78+
create_non_leaf_nodes(os.path.dirname(parent_node_name),
79+
parent_node_name, all_ops, general_children_dict)
7980

8081

8182
def construct_edges(var_name, all_ops, all_vars, all_edges):
@@ -298,8 +299,8 @@ def analyse_model(model_pb): # noqa: C901
298299

299300
all_op_names = list(all_ops.keys())
300301
for op_name in all_op_names:
301-
create_non_leaf_nodes(
302-
os.path.dirname(op_name), op_name, all_ops, general_children_dict)
302+
create_non_leaf_nodes(os.path.dirname(op_name), op_name, all_ops,
303+
general_children_dict)
303304

304305
# fill all non-leaf node's 'output_nodes' 'input_nodes' 'output_vars' 'input_vars'
305306
# post-order traverse tree
@@ -345,8 +346,9 @@ def analyse_model(model_pb): # noqa: C901
345346
for src_node, to_node in all_edges.keys():
346347
all_ops[src_node]['edge_output_nodes'].append(to_node)
347348
all_ops[to_node]['edge_input_nodes'].append(src_node)
348-
all_edges[(src_node, to_node)]['vars'] = list(
349-
all_edges[(src_node, to_node)]['vars'])
349+
all_edges[(src_node,
350+
to_node)]['vars'] = list(all_edges[(src_node,
351+
to_node)]['vars'])
350352
if len(all_edges[(src_node, to_node)]['vars']) > 1:
351353
all_edges[(src_node, to_node)]['label'] = str(
352354
len(all_edges[(src_node, to_node)]['vars'])) + ' tensors'
@@ -361,3 +363,96 @@ def analyse_model(model_pb): # noqa: C901
361363
'edges': list(all_edges.values())
362364
}
363365
return final_data
366+
367+
368+
def analyse_pir(program):
369+
from paddle.utils.unique_name import generate
370+
371+
all_ops = {}
372+
all_vars = {}
373+
all_edges = {}
374+
# vars info
375+
for op in (program.global_block().ops):
376+
var_name = utils.gen_var_name(op.results())
377+
all_vars[var_name] = {}
378+
all_vars[var_name]['name'] = var_name
379+
attrs = op.results()[0].get_defining_op().attrs()
380+
381+
if 'place' in attrs:
382+
attrs['place'] = str(attrs['place'])
383+
attrs['dtype'] = op.result(0).dtype.name
384+
385+
all_vars[var_name]['shape'] = op.result(0).shape
386+
all_vars[var_name]['type'] = op.result(0).dtype.name
387+
all_vars[var_name]['dtype'] = op.result(0).dtype.name
388+
389+
all_vars[var_name]['value'] = []
390+
all_vars[var_name]['persistable'] = op.result(0).is_persistable
391+
all_vars[var_name]['attrs'] = attrs
392+
all_vars[var_name]['from_node'] = ''
393+
all_vars[var_name]['to_nodes'] = []
394+
395+
# ops info
396+
for op in (program.global_block().ops):
397+
op_name = generate(op.name())
398+
399+
if op.num_operands() > 0:
400+
all_ops[op_name] = {}
401+
all_ops[op_name]['name'] = op_name
402+
all_ops[op_name]['show_name'] = op_name
403+
all_ops[op_name]['type'] = op.result(0).dtype.name
404+
all_ops[op_name]['dtype'] = op.result(0).dtype.name
405+
406+
all_ops[op_name]['input_vars'] = {}
407+
all_ops[op_name]['output_vars'] = {}
408+
409+
all_ops[op_name]['is_leaf_node'] = True
410+
now_var = utils.gen_var_name(op.results())
411+
for source in op.operands_source():
412+
input_name = utils.gen_var_name(source)
413+
all_ops[op_name]['input_vars'][input_name] = [input_name]
414+
all_vars[input_name]['to_nodes'].append(op_name)
415+
all_vars[now_var]['from_node'] = op_name
416+
all_ops[op_name]['output_vars'][now_var] = [now_var]
417+
418+
all_ops[op_name]['attrs'] = attrs
419+
all_ops[op_name]['attr_types'] = attrs
420+
all_ops[op_name]['children_node'] = []
421+
all_ops[op_name]['input_nodes'] = []
422+
all_ops[op_name]['output_nodes'] = []
423+
all_ops[op_name]['edge_input_nodes'] = []
424+
all_ops[op_name]['edge_output_nodes'] = []
425+
426+
# create '/' op
427+
all_ops['/'] = {}
428+
all_ops['/']['name'] = '/'
429+
all_ops['/']['show_name'] = '/'
430+
all_ops['/']['type'] = ''
431+
all_ops['/']['attrs'] = {}
432+
all_ops['/']['input_vars'] = {}
433+
all_ops['/']['output_vars'] = {}
434+
all_ops['/']['is_leaf_node'] = False
435+
all_ops['/']['children_node'] = []
436+
for node in all_ops:
437+
if node != '/':
438+
all_ops['/']['children_node'].append(node)
439+
440+
for variable_name in all_vars:
441+
if all_vars[variable_name]['from_node'] == '':
442+
continue
443+
from_node_name = all_vars[variable_name]['from_node']
444+
for to_node_name in all_vars[variable_name]['to_nodes']:
445+
if to_node_name != from_node_name:
446+
all_ops[from_node_name]['output_nodes'].append(to_node_name)
447+
all_ops[to_node_name]['input_nodes'].append(from_node_name)
448+
449+
# edge info
450+
# TODO(Difers):add edge info in future
451+
452+
final_data = {
453+
'version': _graph_version,
454+
'nodes': list(all_ops.values()),
455+
'vars': list(all_vars.values()),
456+
'edges': list(all_edges.values())
457+
}
458+
return final_data

visualdl/component/graph/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,27 @@
1717

1818
_name_scope_stack = deque()
1919

20+
# TODO(Difers): remove it when the new IR's "name" interface is available.
21+
var_name = {}
22+
var_idx = [0]
23+
24+
25+
def gen_var_name(ops):
26+
if not isinstance(ops, list):
27+
ops = [ops]
28+
for op in ops:
29+
var = op.get_defining_op()
30+
if var in var_name:
31+
return var_name[var]
32+
else:
33+
try:
34+
name = op.name
35+
except ValueError:
36+
name = "tmp_var_" + str(var_idx[0])
37+
var_idx[0] += 1
38+
var_name[var] = name
39+
return var_name[var]
40+
2041

2142
def _opname_creation_prehook(layer, inputs):
2243
from paddle.static import name_scope

0 commit comments

Comments
 (0)