Skip to content

Commit 3a8016a

Browse files
JanusChoipre-commit-ci[bot]dlqqq
authored
Fix JSON serialization error in Ollama models (#1129)
* fixing #1128 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * simplify impl and verify dict() requires no args * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: David L. Qiu <david@qiu.dev>
1 parent 2c196d9 commit 3a8016a

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,31 @@
1+
import inspect
2+
import json
3+
14
from langchain_core.callbacks import BaseCallbackHandler
25
from langchain_core.outputs import LLMResult
36

47

8+
def requires_no_arguments(func):
9+
sig = inspect.signature(func)
10+
for param in sig.parameters.values():
11+
if param.default is param.empty and param.kind in (
12+
param.POSITIONAL_ONLY,
13+
param.POSITIONAL_OR_KEYWORD,
14+
param.KEYWORD_ONLY,
15+
):
16+
return False
17+
return True
18+
19+
20+
def convert_to_serializable(obj):
21+
"""Convert an object to a JSON serializable format"""
22+
if hasattr(obj, "dict") and callable(obj.dict) and requires_no_arguments(obj.dict):
23+
return obj.dict()
24+
if hasattr(obj, "__dict__"):
25+
return obj.__dict__
26+
return str(obj)
27+
28+
529
class MetadataCallbackHandler(BaseCallbackHandler):
630
"""
731
When passed as a callback handler, this stores the LLMResult's
@@ -23,4 +47,9 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None:
2347
if not (len(response.generations) and len(response.generations[0])):
2448
return
2549

26-
self.jai_metadata = response.generations[0][0].generation_info or {}
50+
metadata = response.generations[0][0].generation_info or {}
51+
52+
# Convert any non-serializable objects in metadata
53+
self.jai_metadata = json.loads(
54+
json.dumps(metadata, default=convert_to_serializable)
55+
)

packages/jupyter-ai/jupyter_ai/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from typing import Any, Dict, List, Literal, Optional, Union
23

34
from jupyter_ai_magics import Persona
@@ -128,6 +129,15 @@ class AgentStreamChunkMessage(BaseModel):
128129
on `BaseAgentMessage.metadata` for information.
129130
"""
130131

132+
@validator("metadata")
133+
def validate_metadata(cls, v):
134+
"""Ensure metadata values are JSON serializable"""
135+
try:
136+
json.dumps(v)
137+
return v
138+
except TypeError as e:
139+
raise ValueError(f"Metadata must be JSON serializable: {str(e)}")
140+
131141

132142
class HumanChatMessage(BaseModel):
133143
type: Literal["human"] = "human"

0 commit comments

Comments
 (0)