1919import matplotlib .image as mpimg
2020import matplotlib .pyplot as plt
2121from graphviz import Source
22- import wget
22+ from setgraphviz import setgraphviz
23+
2324URL = 'https://erdogant.github.io/datasets/graphviz-2.38.zip'
2425
2526
@@ -78,7 +79,8 @@ def lgbm(model, featnames=None, num_trees=None, figsize=(25,25), verbose=3):
7879 # Check model
7980 _check_model (model , 'lgb' )
8081 # Set env
81- _set_graphviz_path ()
82+ # _set_graphviz_path()
83+ setgraphviz ()
8284
8385 if (num_trees is None ) and hasattr (model , 'best_iteration_' ):
8486 num_trees = model .best_iteration_
@@ -137,7 +139,8 @@ def xgboost(model, featnames=None, num_trees=None, plottype='horizontal', figsiz
137139
138140 _check_model (model , 'xgb' )
139141 # Set env
140- _set_graphviz_path ()
142+ # _set_graphviz_path()
143+ setgraphviz ()
141144
142145 if plottype == 'horizontal' : plottype = 'UD'
143146 if plottype == 'vertical' : plottype = 'LR'
@@ -206,7 +209,8 @@ def randomforest(model, featnames=None, num_trees=None, filepath='tree', export=
206209 # Check model
207210 _check_model (model , 'randomforest' )
208211 # Set env
209- _set_graphviz_path ()
212+ # _set_graphviz_path()
213+ setgraphviz ()
210214
211215 if export is not None :
212216 dotfile = filepath + '.dot'
@@ -300,56 +304,56 @@ def import_example(data='random', n_samples=1000, n_feat=10):
300304
301305
302306# %% Get graphiz path and include into local PATH
303- def _set_graphviz_path (verbose = 3 ):
304- finPath = ''
305- if _get_platform ()== "windows" :
306- # Download from github
307- [gfile , curpath ] = _download_graphviz (URL , verbose = verbose )
308-
309- # curpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'RESOURCES')
310- # filesindir = os.listdir(curpath)[0]
311- idx = gfile [::- 1 ].find ('.' ) + 1
312- dirname = gfile [:- idx ]
313- getPath = os .path .abspath (os .path .join (curpath , dirname ))
314- getZip = os .path .abspath (os .path .join (curpath , gfile ))
315- # Unzip if path does not exists
316- if not os .path .isdir (getPath ):
317- if verbose >= 3 : print ('[treeplot] >Extracting graphviz files..' )
318- [pathname , _ ] = os .path .split (getZip )
319- # Unzip
320- zip_ref = zipfile .ZipFile (getZip , 'r' )
321- zip_ref .extractall (pathname )
322- zip_ref .close ()
323- getPath = os .path .join (pathname , dirname )
324-
325- # Point directly to the bin
326- finPath = os .path .abspath (os .path .join (getPath , 'release' , 'bin' ))
327- else :
328- pass
329- # sudo apt install python-pydot python-pydot-ng graphviz
330- # dpkg -l | grep graphviz
331- # call(['dpkg', '-l', 'grep', 'graphviz'])
332- # call(['dpkg', '-s', 'graphviz'])
333-
334- # Add to system
335- if finPath not in os .environ ["PATH" ]:
336- if verbose >= 3 : print ('[treeplot] >Set path in environment.' )
337- os .environ ["PATH" ] += os .pathsep + finPath
338-
339- return (finPath )
307+ # def _set_graphviz_path(verbose=3):
308+ # finPath=''
309+ # if _get_platform()=="windows":
310+ # # Download from github
311+ # [gfile, curpath] = _download_graphviz(URL, verbose=verbose)
312+
313+ # # curpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'RESOURCES')
314+ # # filesindir = os.listdir(curpath)[0]
315+ # idx = gfile[::-1].find('.') + 1
316+ # dirname = gfile[:-idx]
317+ # getPath = os.path.abspath(os.path.join(curpath, dirname))
318+ # getZip = os.path.abspath(os.path.join(curpath, gfile))
319+ # # Unzip if path does not exists
320+ # if not os.path.isdir(getPath):
321+ # if verbose>=3: print('[treeplot] >Extracting graphviz files..')
322+ # [pathname, _] = os.path.split(getZip)
323+ # # Unzip
324+ # zip_ref = zipfile.ZipFile(getZip, 'r')
325+ # zip_ref.extractall(pathname)
326+ # zip_ref.close()
327+ # getPath = os.path.join(pathname, dirname)
328+
329+ # # Point directly to the bin
330+ # finPath = os.path.abspath(os.path.join(getPath, 'release', 'bin'))
331+ # else:
332+ # pass
333+ # # sudo apt install python-pydot python-pydot-ng graphviz
334+ # # dpkg -l | grep graphviz
335+ # # call(['dpkg', '-l', 'grep', 'graphviz'])
336+ # # call(['dpkg', '-s', 'graphviz'])
337+
338+ # # Add to system
339+ # if finPath not in os.environ["PATH"]:
340+ # if verbose>=3: print('[treeplot] >Set path in environment.')
341+ # os.environ["PATH"] += os.pathsep + finPath
342+
343+ # return(finPath)
340344
341345
342346# %%
343- def _get_platform ():
344- platforms = {
345- 'linux1' :'linux' ,
346- 'linux2' :'linux' ,
347- 'darwin' :'osx' ,
348- 'win32' :'windows'
349- }
350- if sys .platform not in platforms :
351- return sys .platform
352- return platforms [sys .platform ]
347+ # def _get_platform():
348+ # platforms = {
349+ # 'linux1':'linux',
350+ # 'linux2':'linux',
351+ # 'darwin':'osx',
352+ # 'win32':'windows'
353+ # }
354+ # if sys.platform not in platforms:
355+ # return sys.platform
356+ # return platforms[sys.platform]
353357
354358
355359# %% Check input model
@@ -368,34 +372,34 @@ def _check_model(model, expected):
368372 print ('[treeplot] >Warning: The input model seems not to be a lightgbm model?' )
369373
370374# %% Import example dataset from github.
371- def _download_graphviz (url , verbose = 3 ):
372- """Import example dataset from github.
373-
374- Parameters
375- ----------
376- url : str, optional
377- url-Link to graphviz. The default is 'https://erdogant.github.io/datasets/graphviz-2.38.zip'.
378- verbose : int, optional
379- Print message to screen. The default is 3.
380-
381- Returns
382- -------
383- tuple : (gfile, curpath).
384- gfile : filename
385- curpath : currentpath
386-
387- """
388- curpath = os .path .join (os .path .dirname (os .path .abspath (__file__ )), 'RESOURCES' )
389- gfile = wget .filename_from_url (url )
390- PATH_TO_DATA = os .path .join (curpath , gfile )
391- if not os .path .isdir (curpath ):
392- if verbose >= 3 : print ('[treeplot] >Downloading graphviz..' )
393- os .makedirs (curpath , exist_ok = True )
394-
395- # Check file exists.
396- if not os .path .isfile (PATH_TO_DATA ):
397- # Download data from URL
398- if verbose >= 3 : print ('[treeplot] >Downloading graphviz..' )
399- wget .download (url , curpath )
400-
401- return (gfile , curpath )
375+ # def _download_graphviz(url, verbose=3):
376+ # """Import example dataset from github.
377+
378+ # Parameters
379+ # ----------
380+ # url : str, optional
381+ # url-Link to graphviz. The default is 'https://erdogant.github.io/datasets/graphviz-2.38.zip'.
382+ # verbose : int, optional
383+ # Print message to screen. The default is 3.
384+
385+ # Returns
386+ # -------
387+ # tuple : (gfile, curpath).
388+ # gfile : filename
389+ # curpath : currentpath
390+
391+ # """
392+ # curpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'RESOURCES')
393+ # gfile = wget.filename_from_url(url)
394+ # PATH_TO_DATA = os.path.join(curpath, gfile)
395+ # if not os.path.isdir(curpath):
396+ # if verbose>=3: print('[treeplot] >Downloading graphviz..')
397+ # os.makedirs(curpath, exist_ok=True)
398+
399+ # # Check file exists.
400+ # if not os.path.isfile(PATH_TO_DATA):
401+ # # Download data from URL
402+ # if verbose>=3: print('[treeplot] >Downloading graphviz..')
403+ # wget.download(url, curpath)
404+
405+ # return(gfile, curpath)
0 commit comments