Skip to content

Commit 8af5b8f

Browse files
passing tests and example notebook
1 parent 742f730 commit 8af5b8f

File tree

2 files changed

+329
-198
lines changed

2 files changed

+329
-198
lines changed

src/sasctl/_services/model_repository.py

Lines changed: 128 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
from warnings import warn
1111
import requests
1212
from requests.exceptions import HTTPError
13-
import traceback
13+
import urllib
1414

15-
from ..core import current_session, delete, get, sasctl_command
15+
# import traceback
16+
# import sys
17+
18+
from ..core import current_session, delete, get, sasctl_command, RestObj
1619
from .service import Service
1720

21+
1822
FUNCTIONS = {
1923
"Analytical",
2024
"Classification",
@@ -615,7 +619,7 @@ def list_model_versions(cls, model):
615619
616620
Returns
617621
-------
618-
RestObj
622+
list
619623
620624
"""
621625

@@ -625,41 +629,60 @@ def list_model_versions(cls, model):
625629
"Cannot find link for version history for model '%s'" % model
626630
)
627631

628-
629-
modelHistory = cls.request_link(
632+
modelHistory = cls.request_link(
630633
link,
631634
"modelHistory",
632635
headers={"Accept": "application/vnd.sas.collection+json"},
633636
)
634637

638+
if isinstance(modelHistory, RestObj):
639+
return [modelHistory]
635640
return modelHistory
636641

637642
@classmethod
638-
def get_model_version(cls, model, version_id): #check if this now handles a return 1 case
643+
def get_model_version(cls, model, version_id):
644+
"""Get a specific version of a model.
645+
646+
Parameters
647+
----------
648+
model : str or dict
649+
The name, id, or dictionary representation of a model.
650+
version_id: str
651+
The id of a model version.
652+
653+
Returns
654+
-------
655+
RestObj
656+
657+
"""
639658

640659
model_history = cls.list_model_versions(model)
641660

642661
for item in model_history:
643-
if isinstance(item, str):
644-
if item == 'id' and dict(model_history)[item] == version_id:
645-
return cls.request_link(
646-
model_history,
647-
"self",
648-
headers={"Accept": "application/vnd.sas.models.model.version"},
649-
)
650-
continue
651-
652662
if item["id"] == version_id:
653663
return cls.request_link(
654664
item,
655665
"self",
656-
headers={"Accept": "application/vnd.sas.models.model.version"},
666+
headers={"Accept": "application/vnd.sas.models.model.version+json"},
657667
)
658668

659669
raise ValueError("The version id specified could not be found.")
660670

661671
@classmethod
662672
def get_model_with_versions(cls, model):
673+
"""Get the current model with its version history.
674+
675+
Parameters
676+
----------
677+
model : str or dict
678+
The name, id, or dictionary representation of a model.
679+
680+
Returns
681+
-------
682+
list
683+
684+
"""
685+
663686
if cls.is_uuid(model):
664687
model_id = model
665688
elif isinstance(model, dict) and "id" in model:
@@ -672,75 +695,130 @@ def get_model_with_versions(cls, model):
672695
)
673696
model_id = model["id"]
674697

675-
versions_uri = f"/modelRepository/models/{model_id}/versions"
676-
version_history = cls.get(
677-
versions_uri,
678-
headers={"Accept": "application/vnd.sas.models.model.version"},
679-
)
680-
if version_history is None:
681-
return {}
698+
versions_uri = f"/models/{model_id}/versions"
699+
try:
700+
version_history = cls.request(
701+
"GET",
702+
versions_uri,
703+
headers={"Accept": "application/vnd.sas.collection+json"},
704+
)
705+
except urllib.error.HTTPError as e:
706+
raise HTTPError(
707+
f"Request failed: Model id may be referencing a non-existing model."
708+
) from None
709+
710+
if isinstance(version_history, RestObj):
711+
return [version_history]
712+
682713
return version_history
683714

684715
@classmethod
685716
def get_model_or_version(cls, model, version_id):
717+
"""Get a specific version of a model but if model id and version id are the same, the current model is returned.
686718
687-
if cls.is_uuid(model):
688-
model_id = model
689-
elif isinstance(model, dict) and "id" in model:
690-
model_id = model["id"]
691-
else:
692-
model = cls.get_model(model)
693-
if not model:
694-
raise HTTPError(
695-
"This model may not exist in a project or the model may not exist at all."
696-
)
697-
model_id = model["id"]
719+
Parameters
720+
----------
721+
model : str or dict
722+
The name, id, or dictionary representation of a model.
723+
version_id: str
724+
The id of a model version.
725+
726+
Returns
727+
-------
728+
RestObj
698729
699-
if model_id == version_id:
700-
return cls.get_model(model)
730+
"""
701731

702732
version_history = cls.get_model_with_versions(model)
703-
model_versions = version_history.get("modelVersions")
704-
for i, item in enumerate(model_versions):
705-
if item.get("id") == version_id:
733+
734+
for item in version_history:
735+
if item["id"] == version_id:
706736
return cls.request_link(
707737
item,
708738
"self",
709-
headers={"Accept": "application/vnd.sas.models.model.version"},
739+
headers={
740+
"Accept": "application/vnd.sas.models.model.version+json, application/vnd.sas.models.model+json"
741+
},
710742
)
711743

712744
raise ValueError("The version id specified could not be found.")
713745

714746
@classmethod
715747
def get_model_version_contents(cls, model, version_id):
748+
"""Get the contents of a model version.
749+
750+
Parameters
751+
----------
752+
model : str or dict
753+
The name, id, or dictionary representation of a model.
754+
version_id: str
755+
The id of a model version.
756+
757+
Returns
758+
-------
759+
list
760+
761+
"""
716762
model_version = cls.get_model_version(model, version_id)
717763
version_contents = cls.request_link(
718764
model_version,
719765
"contents",
720-
headers={"Accept": "application/vnd.sas.models.model.content"},
766+
headers={"Accept": "application/vnd.sas.collection+json"},
721767
)
722768

723-
if version_contents is None:
724-
return {}
769+
if isinstance(version_contents, RestObj):
770+
return [version_contents]
771+
725772
return version_contents
726773

727774
@classmethod
728775
def get_model_version_content_metadata(cls, model, version_id, content_id):
776+
"""Get the content metadata header information for a model version.
777+
778+
Parameters
779+
----------
780+
model : str or dict
781+
The name, id, or dictionary representation of a model.
782+
version_id: str
783+
The id of a model version.
784+
content_id: str
785+
The id of the content file.
786+
787+
Returns
788+
-------
789+
RestObj
790+
791+
"""
729792
model_version_contents = cls.get_model_version_contents(model, version_id)
730793

731-
model_version_contents_items = model_version_contents.get("items")
732-
for i, item in enumerate(model_version_contents_items):
733-
if item.get("id") == content_id:
794+
for item in model_version_contents:
795+
if item["id"] == content_id:
734796
return cls.request_link(
735797
item,
736798
"self",
737-
headers={"Accept": "application/vnd.sas.models.model.content"},
799+
headers={"Accept": "application/vnd.sas.models.model.content+json"},
738800
)
739801

740802
raise ValueError("The content id specified could not be found.")
741803

742804
@classmethod
743805
def get_model_version_content(cls, model, version_id, content_id):
806+
"""Get the specific content inside the content file for a model version.
807+
808+
Parameters
809+
----------
810+
model : str or dict
811+
The name, id, or dictionary representation of a model.
812+
version_id: str
813+
The id of a model version.
814+
content_id: str
815+
The id of the specific content file.
816+
817+
Returns
818+
-------
819+
list
820+
821+
"""
744822

745823
metadata = cls.get_model_version_content_metadata(model, version_id, content_id)
746824
version_content_file = cls.request_link(
@@ -749,6 +827,9 @@ def get_model_version_content(cls, model, version_id, content_id):
749827

750828
if version_content_file is None:
751829
raise HTTPError("Something went wrong while accessing the metadata file.")
830+
831+
if isinstance(version_content_file, RestObj):
832+
return [version_content_file]
752833
return version_content_file
753834

754835
@classmethod

0 commit comments

Comments
 (0)