Skip to content

Commit b05d667

Browse files
authored
Merge pull request #211 from sassoftware/model_upload
feat: upload local model without any extra file generation (PMMODEL-682)
2 parents adffc94 + 6ac153b commit b05d667

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

src/sasctl/tasks.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pathlib import Path
1616
from typing import Union
1717
from warnings import warn
18+
import zipfile
1819

1920
import pandas as pd
2021

@@ -998,3 +999,55 @@ def score_model_with_cas(
998999
print(score_execution_poll)
9991000
score_results = se.get_score_execution_results(score_execution, use_cas_gateway)
10001001
return score_results
1002+
1003+
1004+
def upload_local_model(
1005+
path: Union[str, Path],
1006+
model_name: str,
1007+
project_name: str,
1008+
repo_name: Union[str, dict] = None,
1009+
version: str = "latest",
1010+
):
1011+
"""A barebones function to upload a model and any associated files to the model repository.
1012+
Parameters
1013+
----------
1014+
path : Union[str, Path]
1015+
The path to the model and any associated files.
1016+
model_name : str
1017+
The name of the model.
1018+
project_name : str
1019+
The name of the project to which the model will be uploaded.
1020+
"""
1021+
# Use default repository if not specified
1022+
try:
1023+
if repo_name is None:
1024+
repository = mr.default_repository()
1025+
else:
1026+
repository = mr.get_repository(repo_name)
1027+
except HTTPError as e:
1028+
if e.code == 403:
1029+
raise AuthorizationError(
1030+
"Unable to register model. User account does not have read permissions "
1031+
"for the /modelRepository/repositories/ URL. Please contact your SAS "
1032+
"Viya administrator."
1033+
)
1034+
raise e
1035+
1036+
# Unable to find or create the repo.
1037+
if not repository and not repo_name:
1038+
raise ValueError("Unable to find a default repository")
1039+
elif not repository:
1040+
raise ValueError(f"Unable to find repository '{repo_name}'")
1041+
p = mr.get_project(project_name)
1042+
if p is None:
1043+
mr.create_project(project_name, repository)
1044+
zip_name = str(Path(path) / (model_name + ".zip"))
1045+
file_names = sorted(Path(path).glob("*[!zip]"))
1046+
with zipfile.ZipFile(str(zip_name), mode="w") as zFile:
1047+
for file in file_names:
1048+
zFile.write(str(file), arcname=file.name)
1049+
with open(zip_name, "rb") as zip_file:
1050+
model = mr.import_model_from_zip(
1051+
model_name, project_name, zip_file, version=version
1052+
)
1053+
return model

0 commit comments

Comments
 (0)