39 changes: 39 additions & 0 deletions xinference/api/restful_api.py
@@ -198,6 +198,11 @@ class RegisterModelRequest(BaseModel):
    persist: bool


class AddModelRequest(BaseModel):
    model_type: str
    model_json: Dict[str, Any]


class BuildGradioInterfaceRequest(BaseModel):
    model_type: str
    model_name: str
@@ -900,6 +905,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
                else None
            ),
        )
        self._router.add_api_route(
            "/v1/models/add",
            self.add_model,
            methods=["POST"],
            dependencies=(
                [Security(self._auth_service, scopes=["models:add"])]
                if self.is_authenticated()
                else None
            ),
        )
        self._router.add_api_route(
            "/v1/cache/models",
            self.list_cached_models,
@@ -3123,6 +3138,30 @@ async def unregister_model(self, model_type: str, model_name: str) -> JSONResponse:
            raise HTTPException(status_code=500, detail=str(e))
        return JSONResponse(content=None)

    async def add_model(self, request: Request) -> JSONResponse:
        try:
            # Parse the request body
            raw_json = await request.json()

            body = AddModelRequest.parse_obj(raw_json)
            model_type = body.model_type
            model_json = body.model_json

            # Delegate the actual registration to the supervisor
            supervisor_ref = await self._get_supervisor_ref()
            await supervisor_ref.add_model(model_type, model_json)

        except ValueError as ve:
            logger.error(f"ValueError in add_model API: {ve}", exc_info=True)
            raise HTTPException(status_code=400, detail=str(ve))
        except Exception as e:
            logger.error(f"Unexpected error in add_model API: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

        return JSONResponse(
            content={"message": f"Model added successfully for type: {model_type}"}
        )

    async def list_model_registrations(
        self, model_type: str, detailed: bool = Query(False)
    ) -> JSONResponse:
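With the route above in place, the endpoint can be exercised over plain HTTP. Below is a minimal client sketch, assuming a default local Xinference deployment at http://127.0.0.1:9997 and a hypothetical custom LLM payload; the endpoint path and the `{model_type, model_json}` request shape come from this diff, everything else is illustrative:

```python
import requests

# Hypothetical model JSON in the model-hub format that the supervisor
# flattens via _convert_model_json_format (see supervisor.py below).
model_json = {
    "model_name": "my-custom-llm",
    "model_specs": [
        {
            "model_format": "pytorch",
            "model_size_in_billions": 7,
            "model_src": {"huggingface": {"model_id": "org/my-custom-llm"}},
        }
    ],
}

resp = requests.post(
    "http://127.0.0.1:9997/v1/models/add",  # assumed default local address
    json={"model_type": "LLM", "model_json": model_json},
)
resp.raise_for_status()
print(resp.json())  # {'message': 'Model added successfully for type: LLM'}
```

A 400 response indicates a validation failure surfaced from the supervisor (unsupported type, bad model name, duplicate registration); a 500 indicates an unexpected server-side error.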
237 changes: 237 additions & 0 deletions xinference/core/supervisor.py
@@ -14,6 +14,7 @@

import asyncio
import itertools
import json
import os
import signal
import time
@@ -932,6 +933,242 @@ async def register_model(
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

    @log_async(logger=logger)
    async def add_model(self, model_type: str, model_json: Dict[str, Any]):
        """
        Add a new model by parsing the provided JSON and registering it.

        Args:
            model_type: Type of model (LLM, embedding, image, etc.)
            model_json: JSON configuration for the model
        """
        # Validate model type
        supported_types = list(self._custom_register_type_to_cls.keys())

        if model_type not in self._custom_register_type_to_cls:
            raise ValueError(
                f"Unsupported model type '{model_type}'. "
                f"Supported types are: {', '.join(supported_types)}"
            )

        # Get the appropriate model spec class and register functions
        (
            model_spec_cls,
            register_fn,
            unregister_fn,
            generate_fn,
        ) = self._custom_register_type_to_cls[model_type]

        # Validate required fields (only model_name is required)
        required_fields = ["model_name"]
        for field in required_fields:
            if field not in model_json:
                raise ValueError(f"Missing required field: {field}")

        # Validate model name format
        from ..model.utils import is_valid_model_name

        model_name = model_json["model_name"]

        if not is_valid_model_name(model_name):
            raise ValueError(f"Invalid model name format: {model_name}")

        # Convert model hub JSON format to the format Xinference expects
        try:
            converted_model_json = self._convert_model_json_format(model_json)
        except Exception as e:
            raise ValueError(f"Failed to convert model JSON format: {str(e)}")

        # Parse the JSON into the appropriate model spec
        try:
            model_spec = model_spec_cls.parse_obj(converted_model_json)
        except Exception as e:
            raise ValueError(f"Invalid model JSON format: {str(e)}")

        # Check whether the model already exists
        try:
            existing_model = await self.get_model_registration(
                model_type, model_spec.model_name
            )

            if existing_model is not None:
                raise ValueError(
                    f"Model '{model_spec.model_name}' already exists for type '{model_type}'. "
                    f"Please choose a different model name or remove the existing model first."
                )

        except ValueError as e:
            if "not found" in str(e):
                # Model doesn't exist, we can proceed
                pass
            else:
                # Re-raise validation errors
                raise e
        except Exception as ex:
            raise ValueError(f"Failed to validate model registration: {str(ex)}")

        # Register the model (persist=True for adding models)
        try:
            register_fn(model_spec, persist=True)

            # Record model version
            version_info = generate_fn(model_spec)
            await self._cache_tracker_ref.record_model_version(
                version_info, self.address
            )

            # Sync to workers if not a local deployment
            is_local = self.is_local_deployment()
            if not is_local:
                # Convert back to a JSON string for sync compatibility
                model_json_str = json.dumps(converted_model_json)
                await self._sync_register_model(
                    model_type, model_json_str, True, model_spec.model_name
                )

            logger.info(
                f"Successfully added model '{model_spec.model_name}' (type: {model_type})"
            )

        except ValueError as e:
            # Validation errors - no cleanup needed, the model wasn't registered
            raise e
        except Exception as e:
            # Unexpected errors - attempt cleanup
            try:
                unregister_fn(model_spec.model_name, raise_error=False)
            except Exception as cleanup_error:
                logger.warning(f"Cleanup failed: {cleanup_error}")
            raise ValueError(
                f"Failed to register model '{model_spec.model_name}': {str(e)}"
            )

    def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert model hub JSON format to the Xinference expected format.

        The input format uses a nested 'model_src' structure, but Xinference
        expects flattened fields at the spec level.

        Also handles cases where the model_specs field is missing by providing
        a default.
        """
        # If model_specs is missing, provide a default minimal spec
        if "model_specs" not in model_json or not model_json["model_specs"]:
            return {
                **model_json,
                "model_specs": [
                    {
                        "model_format": "pytorch",
                        "model_size_in_billions": None,
                        "quantization": "none",
                    }
                ],
            }

        # Check if conversion is needed (detect the model_src structure)
        needs_conversion = any(
            "model_src" in spec for spec in model_json["model_specs"]
        )
        if not needs_conversion:
            return model_json

        converted = model_json.copy()
        converted_specs = []

        for spec in model_json["model_specs"]:
            if "model_src" not in spec:
                # No model_src: keep the spec as is but ensure required fields
                converted_spec = spec.copy()
                if "quantization" not in converted_spec:
                    converted_spec["quantization"] = "none"  # Default
                converted_specs.append(converted_spec)
                continue

            # Read these only after the early continue above, so specs
            # without model_src never trigger a KeyError here
            model_format = spec["model_format"]
            model_size = spec["model_size_in_billions"]
            model_src = spec["model_src"]

            # Handle the different model sources
            if "huggingface" in model_src:
                hf_info = model_src["huggingface"]
                quantizations = hf_info.get("quantizations", ["none"])

                # Create a separate spec for each quantization
                for quant in quantizations:
                    converted_spec = {
                        "model_format": model_format,
                        "model_size_in_billions": model_size,
                        "quantization": quant,
                        "model_hub": "huggingface",
                    }

                    # Common fields
                    if "model_id" in hf_info:
                        converted_spec["model_id"] = hf_info["model_id"]
                    if "model_revision" in hf_info:
                        converted_spec["model_revision"] = hf_info["model_revision"]

                    # Format-specific fields (model_id/model_revision are
                    # already covered by the common fields above)
                    if model_format == "ggufv2":
                        if "model_file_name_template" in hf_info:
                            converted_spec["model_file_name_template"] = hf_info[
                                "model_file_name_template"
                            ]
                        else:
                            # Default template
                            model_name = model_json["model_name"]
                            converted_spec["model_file_name_template"] = (
                                f"{model_name}-{{quantization}}.gguf"
                            )

                    converted_specs.append(converted_spec)

            elif "modelscope" in model_src:
                # Handle ModelScope similarly
                ms_info = model_src["modelscope"]
                quantizations = ms_info.get("quantizations", ["none"])

                for quant in quantizations:
                    converted_spec = {
                        "model_format": model_format,
                        "model_size_in_billions": model_size,
                        "quantization": quant,
                        "model_hub": "modelscope",
                    }

                    if "model_id" in ms_info:
                        converted_spec["model_id"] = ms_info["model_id"]
                    if "model_revision" in ms_info:
                        converted_spec["model_revision"] = ms_info["model_revision"]

                    converted_specs.append(converted_spec)

            else:
                # Unknown model source: keep the original spec but add required fields
                logger.warning(
                    f"Unknown model source in spec: {list(model_src.keys())}"
                )
                converted_spec = spec.copy()
                if "quantization" not in converted_spec:
                    converted_spec["quantization"] = "none"
                converted_specs.append(converted_spec)

        converted["model_specs"] = converted_specs

        return converted

    async def _sync_register_model(
        self, model_type: str, model: str, persist: bool, model_name: str
    ):
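The conversion helper is easiest to follow with a concrete payload. Here is a minimal sketch of the transformation it performs, assuming a hub-style spec carrying two GGUF quantizations; the field names follow the code above, while the model names and IDs are invented for illustration:

```python
# Hub-style input: a nested "model_src" with a list of quantizations.
hub_json = {
    "model_name": "demo-model",
    "model_specs": [
        {
            "model_format": "ggufv2",
            "model_size_in_billions": 7,
            "model_src": {
                "huggingface": {
                    "model_id": "org/demo-model-gguf",
                    "quantizations": ["Q4_K_M", "Q8_0"],
                }
            },
        }
    ],
}

# Expected flattened output: one spec per quantization, with model_hub set
# and a default model_file_name_template derived from model_name.
expected_specs = [
    {
        "model_format": "ggufv2",
        "model_size_in_billions": 7,
        "quantization": "Q4_K_M",
        "model_hub": "huggingface",
        "model_id": "org/demo-model-gguf",
        "model_file_name_template": "demo-model-{quantization}.gguf",
    },
    {
        "model_format": "ggufv2",
        "model_size_in_billions": 7,
        "quantization": "Q8_0",
        "model_hub": "huggingface",
        "model_id": "org/demo-model-gguf",
        "model_file_name_template": "demo-model-{quantization}.gguf",
    },
]
```

Flattening one spec per quantization is what lets `model_spec_cls.parse_obj` validate each quantization entry individually before registration.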
25 changes: 24 additions & 1 deletion xinference/ui/web/ui/src/locales/en.json
@@ -124,7 +124,30 @@
    "featured": "featured",
    "all": "all",
    "cancelledSuccessfully": "Cancelled Successfully!",
    "mustBeUnique": "{{key}} must be unique"
    "mustBeUnique": "{{key}} must be unique",
    "addModel": "Add Model",
    "addModelDialog": {
      "introPrefix": "To add a model, please use",
      "platformLinkText": "Model Management Platform",
      "introSuffix": " and paste the model's URL",
      "example": "Example: The URL for {{modelName}} on the platform is {{modelUrl}}",
      "urlLabel": "URL"
    },
    "loginDialog": {
      "title": "No permission to download this model. Please log in and try again.",
      "usernameOrEmail": "Username or Email",
      "password": "Password",
      "login": "Login"
    },
    "error": {
      "cannotExtractModelId": "Unable to extract model_id from URL. Please check your input.",
      "downloadFailed": "Download failed: {{status}} {{text}}",
      "requestFailed": "Request failed",
      "loginFailedText": "Login failed: {{status}} {{text}}",
      "noTokenAfterLogin": "Login succeeded but no token was returned",
      "modelPrivate": "This model is private and requires download permission.",
      "noPermissionAfterLogin": "The logged-in account does not have permission to download this model. Please contact the administrator or use a different account."
    }
  },

  "runningModels": {
25 changes: 24 additions & 1 deletion xinference/ui/web/ui/src/locales/ja.json
@@ -124,7 +124,30 @@
    "featured": "おすすめとお気に入り",
    "all": "すべて",
    "cancelledSuccessfully": "正常にキャンセルされました!",
    "mustBeUnique": "{{key}} は一意でなければなりません"
    "mustBeUnique": "{{key}} は一意でなければなりません",
    "addModel": "モデルを追加",
    "addModelDialog": {
      "introPrefix": "モデルを追加するには",
      "platformLinkText": "モデル管理プラットフォーム",
      "introSuffix": "に基づき、対応するURLを入力してください",
      "example": "例:{{modelName}} のモデル管理プラットフォーム上のURLは {{modelUrl}} です",
      "urlLabel": "URL"
    },
    "loginDialog": {
      "title": "このモデルをダウンロードする権限がありません。ログイン後に再度お試しください",
      "usernameOrEmail": "ユーザー名またはメールアドレス",
      "password": "パスワード",
      "login": "ログイン"
    },
    "error": {
      "cannotExtractModelId": "URLから model_id を抽出できません。入力内容を確認してください",
      "downloadFailed": "ダウンロード失敗: {{status}} {{text}}",
      "requestFailed": "リクエスト失敗",
      "loginFailedText": "ログイン失敗: {{status}} {{text}}",
      "noTokenAfterLogin": "ログインは成功しましたが、トークンを取得できませんでした",
      "modelPrivate": "このモデルは非公開であり、ダウンロード権限が必要です。",
      "noPermissionAfterLogin": "このアカウントにはモデルをダウンロードする権限がありません。管理者に連絡するか、別のアカウントを使用してください。"
    }
  },

  "runningModels": {