Skip to content

Commit 2709bb6

Browse files
committed
refactor: persist hazard information in message.additional_kwargs, instead of as a standalone message
1 parent b2165c6 commit 2709bb6

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

api/chatbot/agent.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from langchain_core.prompts import ChatPromptTemplate
88
from langgraph.graph import END, START, MessagesState, StateGraph
99

10-
from chatbot.safety import create_hazard_classifier
10+
from chatbot.safety import create_hazard_classifier, hazard_categories
1111

1212
if TYPE_CHECKING:
1313
from langchain_core.language_models import BaseChatModel
@@ -37,15 +37,16 @@ def create_agent(
3737

3838
async def input_guard(state: MessagesState) -> MessagesState:
3939
if hazard_classifier is not None:
40+
last_message = state["messages"][-1]
4041
flag, category = await hazard_classifier.ainvoke(
41-
input={"messages": state["messages"][-1:]}
42+
input={"messages": [last_message]}
4243
)
4344
if flag == "unsafe" and category is not None:
44-
content = f"""The user input may contain inproper content related to:
45-
{category}
46-
47-
Please respond with care and professionalism. Avoid engaging with harmful or unethical content. Instead, guide the user towards more constructive and respectful communication."""
48-
return {"messages": [SystemMessage(content=content)]}
45+
# attach the hazard category to the last message via additional_kwargs
46+
last_message.additional_kwargs = last_message.additional_kwargs | {
47+
"hazard": category
48+
}
49+
return {"messages": [last_message]}
4950
return {"messages": []}
5051

5152
async def run_output_guard(state: MessagesState) -> MessagesState:
@@ -78,12 +79,20 @@ async def chatbot(state: MessagesState) -> MessagesState:
7879

7980
bound = prompt | chat_model
8081

81-
windowed_messages = trim_messages(
82+
windowed_messages: list[BaseMessage] = trim_messages(
8283
state["messages"],
8384
token_counter=token_counter,
8485
max_tokens=max_tokens,
8586
start_on="human", # This means that the first message should be from the user after trimming.
8687
)
88+
if hazard := windowed_messages[-1].additional_kwargs.get("hazard"):
89+
hint_message = SystemMessage(
90+
content=f"""The user input may contain improper content related to:
91+
{hazard_categories.get(hazard)}
92+
93+
Please respond with care and professionalism. Avoid engaging with harmful or unethical content. Instead, guide the user towards more constructive and respectful communication."""
94+
)
95+
windowed_messages.append(hint_message)
8796

8897
messages = await bound.ainvoke(
8998
{
@@ -93,9 +102,7 @@ async def chatbot(state: MessagesState) -> MessagesState:
93102
).date(), # TODO: get the current date from the user?
94103
}
95104
)
96-
return {
97-
"messages": [messages],
98-
}
105+
return {"messages": [messages]}
99106

100107
builder = StateGraph(MessagesState)
101108
builder.add_node(input_guard)

api/chatbot/safety.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def parse(self, text: str) -> tuple[str, str | None]:
4646

4747
flag, category = text.split("\n", 1)
4848
if flag.lower() == "unsafe":
49-
return "unsafe", hazard_categories.get(category)
49+
return "unsafe", category
5050
return "unknown", None
5151

5252

0 commit comments

Comments
 (0)