from typing import cast

from flux0_core.agent_runners.api import AgentRunner, Deps, agent_runner
from flux0_core.agent_runners.context import Context
from flux0_core.sessions import (
    MessageEventData,
    StatusEventData,
)
from flux0_stream.frameworks.langchain import RunContext, filter_and_map_events, handle_event
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage


async def read_user_input(deps: Deps, context: Context) -> str:
    """Extract the user input from the last message event of the session."""
    # Read the session events and expect the last event to be the user's message.
    events = await deps.list_session_events(context.session_id)
    last_event = events[-1]
    if last_event.type != "message":
        raise ValueError(f"Expected last event to be a message, got {last_event.type}")
    user_event_data = cast(MessageEventData, last_event.data)
    # Initialize before the loop so the check below can't hit an unbound name
    # when no content part is present.
    user_input = None
    for part in user_event_data["parts"]:
        if part["type"] == "content":
            user_input = part["content"]
            break
    if not user_input:
        raise ValueError("No TextPart found in user event data")

    return str(user_input)
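# Illustration only (an assumed shape, not part of the runner): the loop above
# expects the message event's data to carry a "parts" list whose content
# entries look roughly like {"type": "content", "content": "Good morning!"}.
# The authoritative schema is MessageEventData in flux0_core.sessions.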
@agent_runner("langchain_simple")
class LangChainAgentRunner(AgentRunner):
    async def run(self, context: Context, deps: Deps) -> bool:
        # Read the agent model from the database
        agent = await deps.read_agent(context.agent_id)
        if not agent:
            deps.logger.error(f"Agent with ID {context.agent_id} not found")
            return False

        user_input = await read_user_input(deps, context)

        # Initialize and run the chat model via LangChain in streaming mode
        model = init_chat_model("gpt-4.1-nano", model_provider="openai")
        messages = [
            SystemMessage("Translate the following from English into Italian"),
            HumanMessage(user_input),
        ]
        try:
            model_events = model.astream_events(
                messages,
                stream=True,
                version="v2",
            )

            run_ctx: RunContext = RunContext(last_known_event_offset=0)
            # Iterate over the model events and stream them to the client
            async for e in filter_and_map_events(model_events, deps.logger):
                await handle_event(
                    agent,
                    deps.correlator.correlation_id,
                    e,
                    deps.event_emitter,
                    deps.logger,
                    run_ctx,
                )
        except Exception as e:
            await deps.event_emitter.enqueue_status_event(
                correlation_id=deps.correlator.correlation_id,
                data=StatusEventData(type="status", status="error", data=str(e)),
            )
            return False
        finally:
            # Mark session stream completion
            await deps.event_emitter.enqueue_status_event(
                correlation_id=deps.correlator.correlation_id,
                data=StatusEventData(type="status", status="completed", acknowledged_offset=0),
            )

        return True
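For context, here is a minimal standalone sketch (not part of the runner above) of what LangChain's astream_events(version="v2") yields and what filter_and_map_events / handle_event ultimately consume. It assumes an OPENAI_API_KEY is set and reuses the same model name; the demo function is purely illustrative.

import asyncio

from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage


async def demo() -> None:
    model = init_chat_model("gpt-4.1-nano", model_provider="openai")
    messages = [
        SystemMessage("Translate the following from English into Italian"),
        HumanMessage("Good morning, how are you?"),
    ]
    # v2 events are dicts with "event", "name" and "data" keys, e.g.
    # on_chat_model_start, on_chat_model_stream (one per token chunk),
    # and on_chat_model_end.
    async for event in model.astream_events(messages, version="v2"):
        if event["event"] == "on_chat_model_stream":
            # Each stream event carries an AIMessageChunk under data["chunk"].
            print(event["data"]["chunk"].content, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(demo())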