Skip to content

Commit 4f32151

Browse files
mogith-pnsrikanthbachala20
authored andcommitted
Added local-runner requirements validation step (#712)
* added local-runner requirements validation step * added post check for ollama toolkit * fixed tests
1 parent 219ce66 commit 4f32151

File tree

3 files changed

+151
-45
lines changed

3 files changed

+151
-45
lines changed

clarifai/cli/model.py

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import click
66

77
from clarifai.cli.base import cli, pat_display
8-
from clarifai.utils.cli import validate_context
8+
from clarifai.utils.cli import (
9+
check_ollama_installed,
10+
check_requirements_installed,
11+
customize_ollama_model,
12+
validate_context,
13+
)
914
from clarifai.utils.constants import (
1015
DEFAULT_LOCAL_RUNNER_APP_ID,
1116
DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_CONFIG,
@@ -22,49 +27,6 @@
2227
from clarifai.utils.misc import GitHubDownloader, clone_github_repo, format_github_repo_url
2328

2429

25-
def customize_ollama_model(model_path, model_name, port, context_length):
26-
"""Customize the Ollama model name in the cloned template files.
27-
Args:
28-
model_path: Path to the cloned model directory
29-
model_name: The model name to set (e.g., 'llama3.1', 'mistral')
30-
31-
"""
32-
model_py_path = os.path.join(model_path, "1", "model.py")
33-
34-
if not os.path.exists(model_py_path):
35-
logger.warning(f"Model file {model_py_path} not found, skipping model name customization")
36-
return
37-
38-
try:
39-
# Read the model.py file
40-
with open(model_py_path, 'r') as file:
41-
content = file.read()
42-
if model_name:
43-
# Replace the default model name in the load_model method
44-
content = content.replace(
45-
'self.model = os.environ.get("OLLAMA_MODEL_NAME", \'llama3.2\')',
46-
f'self.model = os.environ.get("OLLAMA_MODEL_NAME", \'{model_name}\')',
47-
)
48-
49-
if port:
50-
# Replace the default port variable in the model.py file
51-
content = content.replace("PORT = '23333'", f"PORT = '{port}'")
52-
53-
if context_length:
54-
# Replace the default context length variable in the model.py file
55-
content = content.replace(
56-
"context_length = '8192'", f"context_length = '{context_length}'"
57-
)
58-
59-
# Write the modified content back to model.py
60-
with open(model_py_path, 'w') as file:
61-
file.write(content)
62-
63-
except Exception as e:
64-
logger.error(f"Failed to customize Ollama model name in {model_py_path}: {e}")
65-
raise
66-
67-
6830
@cli.group(
6931
['model'], context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}
7032
)
@@ -164,6 +126,11 @@ def init(
164126

165127
# --toolkit option
166128
if toolkit == 'ollama':
129+
if not check_ollama_installed():
130+
logger.error(
131+
"Ollama is not installed. Please install it from `https://ollama.com/` to use the Ollama toolkit."
132+
)
133+
raise click.Abort()
167134
github_url = DEFAULT_OLLAMA_MODEL_REPO
168135
branch = DEFAULT_OLLAMA_MODEL_REPO_BRANCH
169136

@@ -858,6 +825,18 @@ def local_runner(ctx, model_path, pool_size):
858825
ModelBuilder._save_config(config_file, config)
859826

860827
builder = ModelBuilder(model_path, download_validation_only=True)
828+
if not check_requirements_installed(model_path):
829+
logger.error(f"Requirements not installed for model at {model_path}.")
830+
raise click.Abort()
831+
832+
# Post check while running `clarifai model local-runner` we check if the toolkit is ollama
833+
if builder.config.get('toolkit', {}).get('provider') == 'ollama':
834+
if not check_ollama_installed():
835+
logger.error(
836+
"Ollama is not installed. Please install it from `https://ollama.com/` to use the Ollama toolkit."
837+
)
838+
raise click.Abort()
839+
861840
# don't mock for local runner since you need the dependencies to run the code anyways.
862841
method_signatures = builder.get_method_signatures(mocking=False)
863842

clarifai/utils/cli.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,130 @@ def validate_context_auth(pat: str, user_id: str, api_base: str = None):
220220
logger.error(f"❌ Validation failed: \n{error_msg}")
221221
logger.error("Please check your credentials and try again.")
222222
raise click.Abort() # Exit without saving the configuration
223+
224+
225+
def customize_ollama_model(model_path, model_name, port, context_length):
226+
"""Customize the Ollama model name in the cloned template files.
227+
Args:
228+
model_path: Path to the cloned model directory
229+
model_name: The model name to set (e.g., 'llama3.1', 'mistral')
230+
231+
"""
232+
model_py_path = os.path.join(model_path, "1", "model.py")
233+
234+
if not os.path.exists(model_py_path):
235+
logger.warning(f"Model file {model_py_path} not found, skipping model name customization")
236+
return
237+
238+
try:
239+
# Read the model.py file
240+
with open(model_py_path, 'r') as file:
241+
content = file.read()
242+
if model_name:
243+
# Replace the default model name in the load_model method
244+
content = content.replace(
245+
'self.model = os.environ.get("OLLAMA_MODEL_NAME", \'llama3.2\')',
246+
f'self.model = os.environ.get("OLLAMA_MODEL_NAME", \'{model_name}\')',
247+
)
248+
249+
if port:
250+
# Replace the default port variable in the model.py file
251+
content = content.replace("PORT = '23333'", f"PORT = '{port}'")
252+
253+
if context_length:
254+
# Replace the default context length variable in the model.py file
255+
content = content.replace(
256+
"context_length = '8192'", f"context_length = '{context_length}'"
257+
)
258+
259+
# Write the modified content back to model.py
260+
with open(model_py_path, 'w') as file:
261+
file.write(content)
262+
263+
except Exception as e:
264+
logger.error(f"Failed to customize Ollama model name in {model_py_path}: {e}")
265+
raise
266+
267+
268+
def check_ollama_installed():
269+
"""Check if the Ollama CLI is installed."""
270+
try:
271+
import subprocess
272+
273+
result = subprocess.run(
274+
['ollama', '--version'], capture_output=True, text=True, check=False
275+
)
276+
if result.returncode == 0:
277+
return True
278+
else:
279+
return False
280+
except FileNotFoundError:
281+
return False
282+
283+
284+
def _is_package_installed(package_name):
285+
"""Helper function to check if a single package in requirements.txt is installed."""
286+
import importlib.metadata
287+
288+
try:
289+
importlib.metadata.distribution(package_name)
290+
logger.debug(f"✅ {package_name} - installed")
291+
return True
292+
except importlib.metadata.PackageNotFoundError:
293+
logger.debug(f"❌ {package_name} - not installed")
294+
return False
295+
except Exception as e:
296+
logger.warning(f"Error checking {package_name}: {e}")
297+
return False
298+
299+
300+
def check_requirements_installed(model_path):
301+
"""Check if all dependencies in requirements.txt are installed."""
302+
import re
303+
from pathlib import Path
304+
305+
requirements_path = Path(model_path) / "requirements.txt"
306+
307+
if not requirements_path.exists():
308+
logger.warning(f"requirements.txt not found at {requirements_path}")
309+
return True
310+
311+
try:
312+
package_pattern = re.compile(r'^([a-zA-Z0-9_-]+)')
313+
314+
# Getting package name and version (for logging)
315+
requirements = [
316+
(match.group(1), pack)
317+
for line in requirements_path.read_text().splitlines()
318+
if (pack := line.strip())
319+
and not line.startswith('#')
320+
and (match := package_pattern.match(line))
321+
]
322+
323+
if not requirements:
324+
logger.info("No dependencies found in requirements.txt")
325+
return True
326+
327+
logger.info(f"Checking {len(requirements)} dependencies...")
328+
329+
missing = [
330+
full_req
331+
for package_name, full_req in requirements
332+
if not _is_package_installed(package_name)
333+
]
334+
335+
if not missing:
336+
logger.info(f"✅ All {len(requirements)} dependencies are installed!")
337+
return True
338+
339+
# Report missing packages
340+
logger.error(
341+
f"❌ {len(missing)} of {len(requirements)} required packages are missing in the current environment"
342+
)
343+
logger.error("\n".join(f" - {pkg}" for pkg in missing))
344+
logger.warning(f"To install: pip install -r {requirements_path}")
345+
return False
346+
347+
except Exception as e:
348+
logger.error(f"Failed to check requirements: {e}")
349+
return False

tests/workflow/fixtures/single_branch_with_custom_cropper_model.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ workflow:
1414
description: Custom crop model
1515
output_info:
1616
params:
17-
margin: 1.33
17+
margin: 1.3
1818
node_inputs:
1919
- node_id: detector

0 commit comments

Comments
 (0)