Skip to content

Commit 5e02c46

Browse files
committed
adding setgraphviz and removing redundant code.
1 parent db0dc95 commit 5e02c46

File tree

3 files changed

+88
-84
lines changed

3 files changed

+88
-84
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
with open("README.md", "r", encoding='utf8') as fh:
1414
long_description = fh.read()
1515
setuptools.setup(
16-
install_requires=['scikit-learn','numpy','graphviz>=0.20.1','matplotlib','wget','funcsigs'],
16+
install_requires=['scikit-learn','numpy','graphviz>=0.20.1','matplotlib', 'funcsigs'],
1717
python_requires='>=3',
1818
name='treeplot',
1919
version=new_version,

treeplot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
__author__ = 'Erdogan Tasksen'
1010
__email__ = 'erdogant@gmail.com'
11-
__version__ = '0.1.17'
11+
__version__ = '0.1.18'
1212

1313
# module level doc-string
1414
__doc__ = """

treeplot/treeplot.py

Lines changed: 86 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import matplotlib.image as mpimg
2020
import matplotlib.pyplot as plt
2121
from graphviz import Source
22-
import wget
22+
from setgraphviz import setgraphviz
23+
2324
URL = 'https://erdogant.github.io/datasets/graphviz-2.38.zip'
2425

2526

@@ -78,7 +79,8 @@ def lgbm(model, featnames=None, num_trees=None, figsize=(25,25), verbose=3):
7879
# Check model
7980
_check_model(model, 'lgb')
8081
# Set env
81-
_set_graphviz_path()
82+
# _set_graphviz_path()
83+
setgraphviz()
8284

8385
if (num_trees is None) and hasattr(model, 'best_iteration_'):
8486
num_trees = model.best_iteration_
@@ -137,7 +139,8 @@ def xgboost(model, featnames=None, num_trees=None, plottype='horizontal', figsiz
137139

138140
_check_model(model, 'xgb')
139141
# Set env
140-
_set_graphviz_path()
142+
# _set_graphviz_path()
143+
setgraphviz()
141144

142145
if plottype=='horizontal': plottype='UD'
143146
if plottype=='vertical': plottype='LR'
@@ -206,7 +209,8 @@ def randomforest(model, featnames=None, num_trees=None, filepath='tree', export=
206209
# Check model
207210
_check_model(model, 'randomforest')
208211
# Set env
209-
_set_graphviz_path()
212+
# _set_graphviz_path()
213+
setgraphviz()
210214

211215
if export is not None:
212216
dotfile = filepath + '.dot'
@@ -300,56 +304,56 @@ def import_example(data='random', n_samples=1000, n_feat=10):
300304

301305

302306
# %% Get graphiz path and include into local PATH
303-
def _set_graphviz_path(verbose=3):
304-
finPath=''
305-
if _get_platform()=="windows":
306-
# Download from github
307-
[gfile, curpath] = _download_graphviz(URL, verbose=verbose)
308-
309-
# curpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'RESOURCES')
310-
# filesindir = os.listdir(curpath)[0]
311-
idx = gfile[::-1].find('.') + 1
312-
dirname = gfile[:-idx]
313-
getPath = os.path.abspath(os.path.join(curpath, dirname))
314-
getZip = os.path.abspath(os.path.join(curpath, gfile))
315-
# Unzip if path does not exists
316-
if not os.path.isdir(getPath):
317-
if verbose>=3: print('[treeplot] >Extracting graphviz files..')
318-
[pathname, _] = os.path.split(getZip)
319-
# Unzip
320-
zip_ref = zipfile.ZipFile(getZip, 'r')
321-
zip_ref.extractall(pathname)
322-
zip_ref.close()
323-
getPath = os.path.join(pathname, dirname)
324-
325-
# Point directly to the bin
326-
finPath = os.path.abspath(os.path.join(getPath, 'release', 'bin'))
327-
else:
328-
pass
329-
# sudo apt install python-pydot python-pydot-ng graphviz
330-
# dpkg -l | grep graphviz
331-
# call(['dpkg', '-l', 'grep', 'graphviz'])
332-
# call(['dpkg', '-s', 'graphviz'])
333-
334-
# Add to system
335-
if finPath not in os.environ["PATH"]:
336-
if verbose>=3: print('[treeplot] >Set path in environment.')
337-
os.environ["PATH"] += os.pathsep + finPath
338-
339-
return(finPath)
307+
# def _set_graphviz_path(verbose=3):
308+
# finPath=''
309+
# if _get_platform()=="windows":
310+
# # Download from github
311+
# [gfile, curpath] = _download_graphviz(URL, verbose=verbose)
312+
313+
# # curpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'RESOURCES')
314+
# # filesindir = os.listdir(curpath)[0]
315+
# idx = gfile[::-1].find('.') + 1
316+
# dirname = gfile[:-idx]
317+
# getPath = os.path.abspath(os.path.join(curpath, dirname))
318+
# getZip = os.path.abspath(os.path.join(curpath, gfile))
319+
# # Unzip if path does not exists
320+
# if not os.path.isdir(getPath):
321+
# if verbose>=3: print('[treeplot] >Extracting graphviz files..')
322+
# [pathname, _] = os.path.split(getZip)
323+
# # Unzip
324+
# zip_ref = zipfile.ZipFile(getZip, 'r')
325+
# zip_ref.extractall(pathname)
326+
# zip_ref.close()
327+
# getPath = os.path.join(pathname, dirname)
328+
329+
# # Point directly to the bin
330+
# finPath = os.path.abspath(os.path.join(getPath, 'release', 'bin'))
331+
# else:
332+
# pass
333+
# # sudo apt install python-pydot python-pydot-ng graphviz
334+
# # dpkg -l | grep graphviz
335+
# # call(['dpkg', '-l', 'grep', 'graphviz'])
336+
# # call(['dpkg', '-s', 'graphviz'])
337+
338+
# # Add to system
339+
# if finPath not in os.environ["PATH"]:
340+
# if verbose>=3: print('[treeplot] >Set path in environment.')
341+
# os.environ["PATH"] += os.pathsep + finPath
342+
343+
# return(finPath)
340344

341345

342346
# %%
343-
def _get_platform():
344-
platforms = {
345-
'linux1':'linux',
346-
'linux2':'linux',
347-
'darwin':'osx',
348-
'win32':'windows'
349-
}
350-
if sys.platform not in platforms:
351-
return sys.platform
352-
return platforms[sys.platform]
347+
# def _get_platform():
348+
# platforms = {
349+
# 'linux1':'linux',
350+
# 'linux2':'linux',
351+
# 'darwin':'osx',
352+
# 'win32':'windows'
353+
# }
354+
# if sys.platform not in platforms:
355+
# return sys.platform
356+
# return platforms[sys.platform]
353357

354358

355359
# %% Check input model
@@ -368,34 +372,34 @@ def _check_model(model, expected):
368372
print('[treeplot] >Warning: The input model seems not to be a lightgbm model?')
369373

370374
# %% Import example dataset from github.
371-
def _download_graphviz(url, verbose=3):
372-
"""Import example dataset from github.
373-
374-
Parameters
375-
----------
376-
url : str, optional
377-
url-Link to graphviz. The default is 'https://erdogant.github.io/datasets/graphviz-2.38.zip'.
378-
verbose : int, optional
379-
Print message to screen. The default is 3.
380-
381-
Returns
382-
-------
383-
tuple : (gfile, curpath).
384-
gfile : filename
385-
curpath : currentpath
386-
387-
"""
388-
curpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'RESOURCES')
389-
gfile = wget.filename_from_url(url)
390-
PATH_TO_DATA = os.path.join(curpath, gfile)
391-
if not os.path.isdir(curpath):
392-
if verbose>=3: print('[treeplot] >Downloading graphviz..')
393-
os.makedirs(curpath, exist_ok=True)
394-
395-
# Check file exists.
396-
if not os.path.isfile(PATH_TO_DATA):
397-
# Download data from URL
398-
if verbose>=3: print('[treeplot] >Downloading graphviz..')
399-
wget.download(url, curpath)
400-
401-
return(gfile, curpath)
375+
# def _download_graphviz(url, verbose=3):
376+
# """Import example dataset from github.
377+
378+
# Parameters
379+
# ----------
380+
# url : str, optional
381+
# url-Link to graphviz. The default is 'https://erdogant.github.io/datasets/graphviz-2.38.zip'.
382+
# verbose : int, optional
383+
# Print message to screen. The default is 3.
384+
385+
# Returns
386+
# -------
387+
# tuple : (gfile, curpath).
388+
# gfile : filename
389+
# curpath : currentpath
390+
391+
# """
392+
# curpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'RESOURCES')
393+
# gfile = wget.filename_from_url(url)
394+
# PATH_TO_DATA = os.path.join(curpath, gfile)
395+
# if not os.path.isdir(curpath):
396+
# if verbose>=3: print('[treeplot] >Downloading graphviz..')
397+
# os.makedirs(curpath, exist_ok=True)
398+
399+
# # Check file exists.
400+
# if not os.path.isfile(PATH_TO_DATA):
401+
# # Download data from URL
402+
# if verbose>=3: print('[treeplot] >Downloading graphviz..')
403+
# wget.download(url, curpath)
404+
405+
# return(gfile, curpath)

0 commit comments

Comments
 (0)