17
17
import pathlib
18
18
import re
19
19
20
+ from . import utils
21
+
20
22
_graph_version = '1.0.0'
21
23
22
24
@@ -73,9 +75,8 @@ def create_non_leaf_nodes(parent_node_name, child_node_name, all_ops,
73
75
if parent_node_name == '/' : # root node
74
76
return
75
77
else :
76
- create_non_leaf_nodes (
77
- os .path .dirname (parent_node_name ), parent_node_name , all_ops ,
78
- general_children_dict )
78
+ create_non_leaf_nodes (os .path .dirname (parent_node_name ),
79
+ parent_node_name , all_ops , general_children_dict )
79
80
80
81
81
82
def construct_edges (var_name , all_ops , all_vars , all_edges ):
@@ -298,8 +299,8 @@ def analyse_model(model_pb): # noqa: C901
298
299
299
300
all_op_names = list (all_ops .keys ())
300
301
for op_name in all_op_names :
301
- create_non_leaf_nodes (
302
- os . path . dirname ( op_name ), op_name , all_ops , general_children_dict )
302
+ create_non_leaf_nodes (os . path . dirname ( op_name ), op_name , all_ops ,
303
+ general_children_dict )
303
304
304
305
# fill all non-leaf node's 'output_nodes' 'input_nodes' 'output_vars' 'input_vars'
305
306
# post-order traverse tree
@@ -345,8 +346,9 @@ def analyse_model(model_pb): # noqa: C901
345
346
for src_node , to_node in all_edges .keys ():
346
347
all_ops [src_node ]['edge_output_nodes' ].append (to_node )
347
348
all_ops [to_node ]['edge_input_nodes' ].append (src_node )
348
- all_edges [(src_node , to_node )]['vars' ] = list (
349
- all_edges [(src_node , to_node )]['vars' ])
349
+ all_edges [(src_node ,
350
+ to_node )]['vars' ] = list (all_edges [(src_node ,
351
+ to_node )]['vars' ])
350
352
if len (all_edges [(src_node , to_node )]['vars' ]) > 1 :
351
353
all_edges [(src_node , to_node )]['label' ] = str (
352
354
len (all_edges [(src_node , to_node )]['vars' ])) + ' tensors'
@@ -361,3 +363,96 @@ def analyse_model(model_pb): # noqa: C901
361
363
'edges' : list (all_edges .values ())
362
364
}
363
365
return final_data
366
+
367
+
368
+ def analyse_pir (program ):
369
+ from paddle .utils .unique_name import generate
370
+
371
+ all_ops = {}
372
+ all_vars = {}
373
+ all_edges = {}
374
+ # vars info
375
+ for op in (program .global_block ().ops ):
376
+ var_name = utils .gen_var_name (op .results ())
377
+ all_vars [var_name ] = {}
378
+ all_vars [var_name ]['name' ] = var_name
379
+ attrs = op .results ()[0 ].get_defining_op ().attrs ()
380
+
381
+ if 'place' in attrs :
382
+ attrs ['place' ] = str (attrs ['place' ])
383
+ attrs ['dtype' ] = op .result (0 ).dtype .name
384
+
385
+ all_vars [var_name ]['shape' ] = op .result (0 ).shape
386
+ all_vars [var_name ]['type' ] = op .result (0 ).dtype .name
387
+ all_vars [var_name ]['dtype' ] = op .result (0 ).dtype .name
388
+
389
+ all_vars [var_name ]['value' ] = []
390
+ all_vars [var_name ]['persistable' ] = op .result (0 ).is_persistable
391
+ all_vars [var_name ]['attrs' ] = attrs
392
+ all_vars [var_name ]['from_node' ] = ''
393
+ all_vars [var_name ]['to_nodes' ] = []
394
+
395
+ # ops info
396
+ for op in (program .global_block ().ops ):
397
+ op_name = generate (op .name ())
398
+
399
+ if op .num_operands () > 0 :
400
+ all_ops [op_name ] = {}
401
+ all_ops [op_name ]['name' ] = op_name
402
+ all_ops [op_name ]['show_name' ] = op_name
403
+ all_ops [op_name ]['type' ] = op .result (0 ).dtype .name
404
+ all_ops [op_name ]['dtype' ] = op .result (0 ).dtype .name
405
+
406
+ all_ops [op_name ]['input_vars' ] = {}
407
+ all_ops [op_name ]['output_vars' ] = {}
408
+
409
+ all_ops [op_name ]['is_leaf_node' ] = True
410
+ now_var = utils .gen_var_name (op .results ())
411
+ for source in op .operands_source ():
412
+ input_name = utils .gen_var_name (source )
413
+ all_ops [op_name ]['input_vars' ][input_name ] = [input_name ]
414
+ all_vars [input_name ]['to_nodes' ].append (op_name )
415
+ all_vars [now_var ]['from_node' ] = op_name
416
+ all_ops [op_name ]['output_vars' ][now_var ] = [now_var ]
417
+
418
+ all_ops [op_name ]['attrs' ] = attrs
419
+ all_ops [op_name ]['attr_types' ] = attrs
420
+ all_ops [op_name ]['children_node' ] = []
421
+ all_ops [op_name ]['input_nodes' ] = []
422
+ all_ops [op_name ]['output_nodes' ] = []
423
+ all_ops [op_name ]['edge_input_nodes' ] = []
424
+ all_ops [op_name ]['edge_output_nodes' ] = []
425
+
426
+ # create '/' op
427
+ all_ops ['/' ] = {}
428
+ all_ops ['/' ]['name' ] = '/'
429
+ all_ops ['/' ]['show_name' ] = '/'
430
+ all_ops ['/' ]['type' ] = ''
431
+ all_ops ['/' ]['attrs' ] = {}
432
+ all_ops ['/' ]['input_vars' ] = {}
433
+ all_ops ['/' ]['output_vars' ] = {}
434
+ all_ops ['/' ]['is_leaf_node' ] = False
435
+ all_ops ['/' ]['children_node' ] = []
436
+ for node in all_ops :
437
+ if node != '/' :
438
+ all_ops ['/' ]['children_node' ].append (node )
439
+
440
+ for variable_name in all_vars :
441
+ if all_vars [variable_name ]['from_node' ] == '' :
442
+ continue
443
+ from_node_name = all_vars [variable_name ]['from_node' ]
444
+ for to_node_name in all_vars [variable_name ]['to_nodes' ]:
445
+ if to_node_name != from_node_name :
446
+ all_ops [from_node_name ]['output_nodes' ].append (to_node_name )
447
+ all_ops [to_node_name ]['input_nodes' ].append (from_node_name )
448
+
449
+ # edge info
450
+ # TODO(Difers):add edge info in future
451
+
452
+ final_data = {
453
+ 'version' : _graph_version ,
454
+ 'nodes' : list (all_ops .values ()),
455
+ 'vars' : list (all_vars .values ()),
456
+ 'edges' : list (all_edges .values ())
457
+ }
458
+ return final_data
0 commit comments