@@ -57,8 +57,8 @@ def plot(model, featnames=None, num_trees=None, plottype='horizontal', figsize=(
5757 elif ('tree' in modelname ) or ('forest' in modelname ) or ('gradientboosting' in modelname ):
5858 if verbose >= 4 : print ('tree plotting pipeline.' )
5959 ax = randomforest (model , featnames = featnames , num_trees = num_trees , figsize = figsize , verbose = verbose )
60- if ('lgb' in modelname ):
61- ax = plot_lgb (model , featnames = featnames , num_trees = num_trees , figsize = figsize , verbose = verbose )
60+ elif ('lgb' in modelname ):
61+ ax = lgbm (model , featnames = featnames , num_trees = num_trees , figsize = figsize , verbose = verbose )
6262 else :
6363 print ('[treeplot] >Model not recognized: %s' % (modelname ))
6464 ax = None
@@ -67,7 +67,7 @@ def plot(model, featnames=None, num_trees=None, plottype='horizontal', figsize=(
6767
6868
6969# %% Plot tree
70- def plot_lgb (model , featnames = None , num_trees = None , figsize = (25 ,25 ), verbose = 3 ):
70+ def lgbm (model , featnames = None , num_trees = None , figsize = (25 ,25 ), verbose = 3 ):
7171 try :
7272 from lightgbm import plot_tree , plot_importance
7373 except :
0 commit comments