diff --git a/.changeset/when-lizards-fly.md b/.changeset/when-lizards-fly.md
new file mode 100644
index 000000000..e467f3232
--- /dev/null
+++ b/.changeset/when-lizards-fly.md
@@ -0,0 +1,5 @@
+---
+'@sap-ai-sdk/orchestration': minor
+---
+
+[Improvement] Add utility functions `getContent()`, `getRefusal()`, `getAllMessages()`, `getAssistantMessage()`, and `getResponse()` to the stream response.
\ No newline at end of file
diff --git a/packages/orchestration/src/orchestration-client.test.ts b/packages/orchestration/src/orchestration-client.test.ts
index 354e22ca1..dbfef3311 100644
--- a/packages/orchestration/src/orchestration-client.test.ts
+++ b/packages/orchestration/src/orchestration-client.test.ts
@@ -842,6 +842,7 @@ describe('orchestration service client', () => {
             "name": "add",
           },
           "id": "call_HPgxxSmD2ctYfcJ3gp1JBc7i",
+          "index": 0,
           "type": "function",
         },
         {
@@ -850,9 +851,74 @@
             "name": "multiply",
          },
           "id": "call_PExve0Dd9hxD8hOk4Uhr1yhO",
+          "index": 1,
           "type": "function",
         },
       ]
     `);
   });
+
+  describe('OrchestrationClient Stream Error Handling', () => {
+    it('should abort controller and re-throw error when network request fails', async () => {
+      const config: OrchestrationModuleConfig = {
+        llm: {
+          model_name: 'gpt-4o',
+          model_params: {}
+        },
+        templating: {
+          template: [
+            {
+              role: 'user',
+              content: 'Test prompt'
+            }
+          ]
+        }
+      };
+
+      const controller = new AbortController();
+
+      // Mock network failure
+      mockInference(
+        {
+          data: constructCompletionPostRequest(config, undefined, true)
+        },
+        {
+          status: 500,
+          data: { error: 'Internal Server Error' }
+        },
+        {
+          url: 'inference/deployments/1234/completion'
+        }
+      );
+
+      const client = new OrchestrationClient(config);
+
+      await expect(client.stream(undefined, controller)).rejects.toThrow();
+    });
+
+    it('should throw error when stream is called with already aborted controller', async () => {
+      const config: OrchestrationModuleConfig = {
+        llm: {
+          model_name: 'gpt-4o',
+          model_params: {}
+        },
+        templating: {
+          template: [
+            {
+              role: 'user',
+              content: 'Test prompt'
+            }
+          ]
+        }
+      };
+
+      const controller = new AbortController();
+
+      // Abort immediately
+      controller.abort();
+
+      const client = new OrchestrationClient(config);
+
+      await expect(client.stream(undefined, controller)).rejects.toThrow();
+    });
+  });
 });
diff --git a/packages/orchestration/src/orchestration-client.ts b/packages/orchestration/src/orchestration-client.ts
index 4bdfc3ca3..a1662ae9f 100644
--- a/packages/orchestration/src/orchestration-client.ts
+++ b/packages/orchestration/src/orchestration-client.ts
@@ -73,21 +73,26 @@ export class OrchestrationClient {
     options?: StreamOptions,
     requestConfig?: CustomRequestConfig
   ): Promise<OrchestrationStreamResponse<OrchestrationStreamChunkResponse>> {
-    if (typeof this.config === 'string' && options) {
-      logger.warn(
-        'Stream options are not supported when using a JSON module config.'
+    try {
+      if (typeof this.config === 'string' && options) {
+        logger.warn(
+          'Stream options are not supported when using a JSON module config.'
+        );
+      }
+
+      return await this.createStreamResponse(
+        {
+          prompt,
+          requestConfig,
+          stream: true,
+          streamOptions: options
+        },
+        controller
       );
+    } catch (error) {
+      controller.abort();
+      throw error;
     }
-
-    return this.createStreamResponse(
-      {
-        prompt,
-        requestConfig,
-        stream: true,
-        streamOptions: options
-      },
-      controller
-    );
   }

   private async executeRequest(options: RequestOptions): Promise<HttpResponse> {
@@ -143,9 +148,11 @@
     const stream = OrchestrationStream._create(streamResponse, controller);
     response.stream = stream
       ._pipe(OrchestrationStream._processChunk)
-      ._pipe(OrchestrationStream._processToolCalls, response)
-      ._pipe(OrchestrationStream._processFinishReason, response)
-      ._pipe(OrchestrationStream._processTokenUsage, response);
+      ._pipe(
+        OrchestrationStream._processOrchestrationStreamChunkResponse,
+        response
+      )
+      ._pipe(OrchestrationStream._processStreamEnd, response);

     return response;
   }
diff --git a/packages/orchestration/src/orchestration-stream-response.ts b/packages/orchestration/src/orchestration-stream-response.ts
index f04201789..1f3612ded 100644
--- a/packages/orchestration/src/orchestration-stream-response.ts
+++ b/packages/orchestration/src/orchestration-stream-response.ts
@@ -1,5 +1,8 @@
-import type { ToolCallAccumulator } from './util/index.js';
 import type {
+  AssistantChatMessage,
+  ChatMessage,
+  ChatMessages,
+  CompletionPostResponse,
   MessageToolCalls,
   TokenUsage
 } from './client/api/schema/index.js';
@@ -9,31 +12,19 @@ import type { OrchestrationStream } from './orchestration-stream.js';
  * Orchestration stream response.
  */
 export class OrchestrationStreamResponse<T> {
-  private _usage: TokenUsage | undefined;
-  /**
-   * Finish reasons for all choices.
-   */
-  private _finishReasons: Map<number, string> = new Map();
-  private _toolCallsAccumulators: Map<
-    number,
-    Map<number, ToolCallAccumulator>
-  > = new Map();
+  public _openStream = true;
+  public _data: Partial<CompletionPostResponse> = {};
   private _stream: OrchestrationStream<T> | undefined;
-  private _toolCalls: Map<number, MessageToolCalls> = new Map();

   /**
    * Gets the token usage for the response.
    * @returns The token usage for the response.
    */
   public getTokenUsage(): TokenUsage | undefined {
-    return this._usage;
-  }
-
-  /**
-   * @internal
-   */
-  _setTokenUsage(usage: TokenUsage): void {
-    this._usage = usage;
+    if (this.isStreamOpen()) {
+      return;
+    }
+    return this._data.orchestration_result?.usage;
   }

   /**
@@ -42,44 +33,86 @@
    * Gets the finish reason for a specific choice index.
    * @param choiceIndex - The index of the choice to get the finish reason for.
    * @returns The finish reason for the specified choice index.
    */
   public getFinishReason(choiceIndex = 0): string | undefined {
-    return this._finishReasons.get(choiceIndex);
+    if (this.isStreamOpen()) {
+      return;
+    }
+    return this.findChoiceByIndex(choiceIndex)?.finish_reason;
   }

   /**
-   * @internal
+   * Parses the orchestration response and returns the content.
+   * If the stream is still open, an error is thrown.
+   * @param choiceIndex - The index of the choice to parse.
+   * @returns The message content.
    */
-  _getFinishReasons(): Map<number, string> {
-    return this._finishReasons;
+  public getContent(choiceIndex = 0): string | undefined {
+    if (this.isStreamOpen()) {
+      return;
+    }
+    const choice = this.findChoiceByIndex(choiceIndex);
+    return choice?.message?.content;
   }

   /**
-   * @internal
+   * Parses the orchestration response and returns the tool calls generated by the model.
+   * @param choiceIndex - The index of the choice to parse.
+   * @returns The message tool calls.
   */
-  _setFinishReasons(finishReasons: Map<number, string>): void {
-    this._finishReasons = finishReasons;
+  public getToolCalls(choiceIndex = 0): MessageToolCalls | undefined {
+    if (this.isStreamOpen()) {
+      return;
+    }
+    const choice = this.findChoiceByIndex(choiceIndex);
+    return choice?.message?.tool_calls;
   }

   /**
-   * Gets the tool calls for a specific choice index.
-   * @param choiceIndex - The index of the choice to get the tool calls for.
-   * @returns The tool calls for the specified choice index.
+   * Parses the orchestration response and returns the refusal message generated by the model.
+   * @param choiceIndex - The index of the choice to parse.
+   * @returns The refusal string.
    */
-  public getToolCalls(choiceIndex = 0): MessageToolCalls | undefined {
-    return this._toolCalls.get(choiceIndex);
+  public getRefusal(choiceIndex = 0): string | undefined {
+    if (this.isStreamOpen()) {
+      return;
+    }
+    const choice = this.findChoiceByIndex(choiceIndex);
+    return choice?.message?.refusal;
   }

   /**
-   * @internal
+   * Messages that can be used for subsequent prompts as message history.
+   * @param choiceIndex - The index of the choice to parse.
+   * @returns A list of all messages.
    */
-  _setToolCalls(choiceIndex: number, toolCalls: MessageToolCalls): void {
-    this._toolCalls.set(choiceIndex, toolCalls);
+  public getAllMessages(choiceIndex = 0): ChatMessages | undefined {
+    if (this.isStreamOpen()) {
+      return;
+    }
+    const messages: ChatMessage[] = this._data.module_results?.templating ?? [];
+    const content = this.findChoiceByIndex(choiceIndex)?.message;
+    return content ? [...messages, content] : messages;
   }

   /**
-   * @internal
+   * Gets the assistant message from the response.
+   * @param choiceIndex - The index of the choice to use (default is 0).
+   * @returns The assistant message.
    */
-  _getToolCallsAccumulators(): Map<number, Map<number, ToolCallAccumulator>> {
-    return this._toolCallsAccumulators;
+  public getAssistantMessage(
+    choiceIndex = 0
+  ): AssistantChatMessage | undefined {
+    if (this.isStreamOpen()) {
+      return;
+    }
+    return this.findChoiceByIndex(choiceIndex)?.message;
+  }
+
+  /**
+   * Gets the complete response once the stream is closed.
+   * @returns The completion post response.
+   */
+  public getResponse(): CompletionPostResponse | undefined {
+    if (this.isStreamOpen()) {
+      return;
+    }
+    return this._data as CompletionPostResponse;
   }

   get stream(): OrchestrationStream<T> {
@@ -89,6 +122,23 @@
     return this._stream;
   }

+  private getChoices() {
+    return this._data.orchestration_result?.choices ?? [];
+  }
+
+  private findChoiceByIndex(index: number) {
+    return this.getChoices().find((c: { index: number }) => c.index === index);
+  }
+
+  private isStreamOpen(): boolean {
+    if (this._openStream) {
+      throw new Error(
+        'The stream is still open; the requested data is not available yet. Please wait until the stream is closed.'
+      );
+    }
+    return this._openStream;
+  }
+
   /**
    * @internal
    */
diff --git a/packages/orchestration/src/orchestration-stream.test.ts b/packages/orchestration/src/orchestration-stream.test.ts
index 870877dac..6bf4b3372 100644
--- a/packages/orchestration/src/orchestration-stream.test.ts
+++ b/packages/orchestration/src/orchestration-stream.test.ts
@@ -53,16 +53,20 @@ describe('Orchestration chat completion stream', () => {
   it('should process the finish reasons', async () => {
     const logger = createLogger({
       package: 'orchestration',
-      messageContext: 'orchestration-chat-completion-stream'
+      messageContext: 'stream-util'
     });
     const debugSpy = jest.spyOn(logger, 'debug');
     const asyncGeneratorChunk = OrchestrationStream._processChunk(
       originalChatCompletionStream
     );
-    const asyncGeneratorFinishReason = OrchestrationStream._processFinishReason(
-      new OrchestrationStream(() => asyncGeneratorChunk, new AbortController()),
-      new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>()
-    );
+    const asyncGeneratorFinishReason =
+      OrchestrationStream._processOrchestrationStreamChunkResponse(
+        new OrchestrationStream(
+          () => asyncGeneratorChunk,
+          new AbortController()
+        ),
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>()
+      );

     for await (const chunk of asyncGeneratorFinishReason) {
       expect(chunk).toBeDefined();
@@ -73,16 +77,20 @@ describe('Orchestration chat completion stream', () => {
   it('should process the token usage', async () => {
     const logger = createLogger({
       package: 'orchestration',
-      messageContext: 'orchestration-chat-completion-stream'
+      messageContext: 'stream-util'
     });
     const debugSpy = jest.spyOn(logger, 'debug');
     const asyncGeneratorChunk = OrchestrationStream._processChunk(
       originalChatCompletionStream
     );
-    const asyncGeneratorTokenUsage = OrchestrationStream._processTokenUsage(
-      new OrchestrationStream(() => asyncGeneratorChunk, new AbortController()),
-      new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>()
-    );
+    const asyncGeneratorTokenUsage =
+      OrchestrationStream._processOrchestrationStreamChunkResponse(
+        new OrchestrationStream(
+          () => asyncGeneratorChunk,
+          new AbortController()
+        ),
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>()
+      );

     for await (const chunk of asyncGeneratorTokenUsage) {
       expect(chunk).toBeDefined();
diff --git a/packages/orchestration/src/orchestration-stream.ts b/packages/orchestration/src/orchestration-stream.ts
index fd278451c..4c138018e 100644
--- a/packages/orchestration/src/orchestration-stream.ts
+++ b/packages/orchestration/src/orchestration-stream.ts
@@ -1,23 +1,10 @@
-import { createLogger } from '@sap-cloud-sdk/util';
 import { SseStream } from '@sap-ai-sdk/core';
 import { OrchestrationStreamChunkResponse } from './orchestration-stream-chunk-response.js';
-import {
-  isMessageToolCall,
-  mergeToolCallChunk,
-  type ToolCallAccumulator
-} from './internal.js';
-import type {
-  CompletionPostResponseStreaming,
-  MessageToolCalls
-} from './client/api/schema/index.js';
+import { mergeStreamResponse } from './util/index.js';
+import type { CompletionPostResponseStreaming } from './client/api/schema/index.js';
 import type { HttpResponse } from '@sap-cloud-sdk/http-client';
 import type { OrchestrationStreamResponse } from './orchestration-stream-response.js';

-const logger = createLogger({
-  package: 'orchestration',
-  messageContext: 'orchestration-chat-completion-stream'
-});
-
 /**
  * Orchestration stream containing post-processing functions.
  */
@@ -53,122 +40,33 @@ export class OrchestrationStream<Item> extends SseStream<Item> {
     }
   }

-  /**
-   * @internal
-   */
-  static async *_processToolCalls(
+  /**
+   * @internal
+   */
+  static async *_processOrchestrationStreamChunkResponse(
     stream: OrchestrationStream<OrchestrationStreamChunkResponse>,
     response?: OrchestrationStreamResponse<OrchestrationStreamChunkResponse>
   ): AsyncGenerator<OrchestrationStreamChunkResponse> {
     if (!response) {
-      throw new Error('Response is required to process tool calls.');
+      throw new Error(
+        'Response is required to process the streaming completion response.'
+      );
     }
     for await (const chunk of stream) {
-      chunk.data.orchestration_result?.choices.forEach(choice => {
-        const choiceIndex = choice.index;
-        const toolCallsChunks = chunk.getDeltaToolCalls(choiceIndex);
-        if (toolCallsChunks) {
-          let toolCallAccumulators = response
-            ._getToolCallsAccumulators()
-            .get(choiceIndex);
-          if (!toolCallAccumulators) {
-            toolCallAccumulators = new Map();
-            response
-              ._getToolCallsAccumulators()
-              .set(choiceIndex, toolCallAccumulators);
-          }
-          toolCallsChunks.map(toolCallChunk => {
-            const toolCallId = toolCallChunk.index;
-            const toolCallAccumulator = mergeToolCallChunk(
-              toolCallChunk,
-              toolCallAccumulators.get(toolCallId)
-            );
-            toolCallAccumulators.set(toolCallId, toolCallAccumulator);
-          });
-        }
-      });
+      mergeStreamResponse(response, chunk.data);
       yield chunk;
     }
-
-    for (const [
-      choiceIndex,
-      toolCallsAccumulators
-    ] of response._getToolCallsAccumulators()) {
-      const toolCalls: MessageToolCalls = [];
-      for (const [id, acc] of toolCallsAccumulators.entries()) {
-        if (isMessageToolCall(acc)) {
-          toolCalls.push(acc);
-        } else {
-          logger.error(
-            `Error while parsing tool calls for choice index ${choiceIndex}: Tool call with id ${id} was incomplete.`
-          );
-        }
-      }
-      response._setToolCalls(choiceIndex, toolCalls);
-    }
   }

-  /**
-   * @internal
-   */
-  static async *_processFinishReason(
+  /**
+   * @internal
+   */
+  static async *_processStreamEnd(
     stream: OrchestrationStream<OrchestrationStreamChunkResponse>,
     response?: OrchestrationStreamResponse<OrchestrationStreamChunkResponse>
   ): AsyncGenerator<OrchestrationStreamChunkResponse> {
     if (!response) {
-      throw new Error('Response is required to process finish reasons.');
+      throw new Error('Response is required to process stream end.');
     }
     for await (const chunk of stream) {
-      chunk.data.orchestration_result?.choices.forEach(choice => {
-        const choiceIndex = choice.index;
-        const finishReason = chunk.getFinishReason(choiceIndex);
-        if (finishReason) {
-          response._getFinishReasons().set(choiceIndex, finishReason);
-          switch (finishReason) {
-            case 'content_filter':
-              logger.error(
-                `Choice ${choiceIndex}: Stream finished with content filter hit.`
-              );
-              break;
-            case 'length':
-              logger.error(
-                `Choice ${choiceIndex}: Stream finished with token length exceeded.`
-              );
-              break;
-            case 'stop':
-            case 'tool_calls':
-            case 'function_call':
-              logger.debug(`Choice ${choiceIndex}: Stream finished.`);
-              break;
-            default:
-              logger.error(
-                `Choice ${choiceIndex}: Stream finished with unknown reason '${finishReason}'.`
-              );
-          }
-        }
-      });
       yield chunk;
     }
-  }

-  /**
-   * @internal
-   */
-  static async *_processTokenUsage(
-    stream: OrchestrationStream<OrchestrationStreamChunkResponse>,
-    response?: OrchestrationStreamResponse<OrchestrationStreamChunkResponse>
-  ): AsyncGenerator<OrchestrationStreamChunkResponse> {
-    if (!response) {
-      throw new Error('Response is required to process token usage.');
-    }
-    for await (const chunk of stream) {
-      const usage = chunk.getTokenUsage();
-      if (usage) {
-        response._setTokenUsage(usage);
-        logger.debug(`Token usage: ${JSON.stringify(usage)}`);
-      }
-      yield chunk;
-    }
+    response._openStream = false;
   }

   /**
    * @internal
    */
diff --git a/packages/orchestration/src/util/index.ts b/packages/orchestration/src/util/index.ts
index e605e922f..0e5865093 100644
--- a/packages/orchestration/src/util/index.ts
+++ b/packages/orchestration/src/util/index.ts
@@ -3,4 +3,4 @@ export * from './grounding.js';
 export * from './module-config.js';
 export * from './masking.js';
 export * from './translation.js';
-export * from './tool-calls.js';
+export * from './stream.js';
diff --git a/packages/orchestration/src/util/stream.test.ts b/packages/orchestration/src/util/stream.test.ts
new file mode 100644
index 000000000..555b4824a
--- /dev/null
+++ b/packages/orchestration/src/util/stream.test.ts
@@ -0,0 +1,544 @@
+import { OrchestrationStreamResponse } from '../index.js';
+import { mergeStreamResponse } from './stream.js';
+import type {
+  CompletionPostResponseStreaming,
+  OrchestrationStreamChunkResponse
+} from '../index.js';
+
+const llmBase = {
+  id: 'orchestration-id-1',
+  object: 'chat.completion.chunk',
+  created: 1752575616,
+  model: 'gpt-4o-2024-08-06',
+  system_fingerprint: 'fp_ee1d74bde0',
+  usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
+};
+
+describe('stream-util', () => {
+  describe('mergeStreamResponse', () => {
+    it('merges basic stream response properties', () => {
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        module_results: {},
+        orchestration_result: {
+          ...llmBase,
+          choices: [],
+          usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(response._data.request_id).toBe('test-request-123');
+      expect(response._data.orchestration_result).toBeDefined();
+      expect(response._data.orchestration_result?.usage).toEqual({
+        prompt_tokens: 10,
+        completion_tokens: 5,
+        total_tokens: 15
+      });
+    });
+
+    it('merges module results with llm module', () => {
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        module_results: {
+          llm: {
+            ...llmBase,
+            usage: {
+              prompt_tokens: 20,
+              completion_tokens: 10,
+              total_tokens: 30
+            },
+            choices: [
+              {
+                index: 0,
+                delta: { content: 'Hello' }
+              }
+            ]
+          }
+        }
+      };
+
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+
+      mergeStreamResponse(response, chunk);
+
+      expect(response._data.module_results?.llm?.usage).toEqual({
+        prompt_tokens: 20,
+        completion_tokens: 10,
+        total_tokens: 30
+      });
+      expect(response._data.module_results?.llm?.choices).toEqual([
+        {
+          index: 0,
+          message: {
+            role: 'assistant',
+            content: 'Hello'
+          },
+          finish_reason: ''
+        }
+      ]);
+    });
+
+    it('merges output_unmasking module results', () => {
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        module_results: {
+          output_unmasking: [
+            {
+              index: 0,
+              delta: { content: 'Unmasked content' }
+            }
+          ]
+        }
+      };
+
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+
+      mergeStreamResponse(response, chunk);
+
+      expect(response._data.module_results?.output_unmasking).toHaveLength(1);
+      expect(
+        response._data.module_results?.output_unmasking?.[0].message.content
+      ).toBe('Unmasked content');
+    });
+  });
+
+  describe('token usage merging', () => {
+    it('merges token usage with existing values', () => {
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          usage: { prompt_tokens: 15, completion_tokens: 8, total_tokens: 23 },
+          choices: []
+        }
+      };
+
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          usage: { prompt_tokens: 1, completion_tokens: 2, total_tokens: 3 },
+          choices: []
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(response._data.orchestration_result?.usage).toEqual({
+        prompt_tokens: 15,
+        completion_tokens: 8,
+        total_tokens: 23
+      });
+    });
+
+    it('handles missing token usage gracefully', () => {
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: []
+        }
+      };
+
+      delete chunk.orchestration_result?.usage;
+
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+          choices: []
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(response._data.orchestration_result?.usage).toEqual({
+        prompt_tokens: 10,
+        completion_tokens: 5,
+        total_tokens: 15
+      });
+    });
+  });
+
+  describe('choice merging', () => {
+    it('merges content from multiple chunks', () => {
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              message: {
+                role: 'assistant',
+                content: 'Hello'
+              },
+              finish_reason: '',
+              logprobs: {
+                content: [],
+                refusal: []
+              }
+            }
+          ]
+        }
+      };
+
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              delta: { content: ' World' },
+              finish_reason: 'stop'
+            }
+          ]
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(
+        response._data.orchestration_result?.choices[0].message.content
+      ).toBe('Hello World');
+      expect(
+        response._data.orchestration_result?.choices[0].finish_reason
+      ).toBe('stop');
+    });
+
+    it('adds new choice when index does not exist', () => {
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              message: {
+                role: 'assistant',
+                content: 'First choice'
+              },
+              finish_reason: 'stop'
+            }
+          ]
+        }
+      };
+
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 1,
+              delta: { content: 'Second choice' },
+              finish_reason: 'stop'
+            }
+          ]
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(response._data.orchestration_result?.choices).toHaveLength(2);
+      expect(
+        response._data.orchestration_result?.choices[1].message.content
+      ).toBe('Second choice');
+    });
+
+    it('handles finish reasons correctly', () => {
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              message: {
+                role: 'assistant',
+                content: 'Test'
+              },
+              finish_reason: ''
+            }
+          ]
+        }
+      };
+
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              delta: { content: '' },
+              finish_reason: 'content_filter'
+            }
+          ]
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(
+        response._data.orchestration_result?.choices[0].finish_reason
+      ).toBe('content_filter');
+    });
+  });
+
+  describe('tool call merging', () => {
+    it('merges tool call arguments from multiple chunks', () => {
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              message: {
+                role: 'assistant',
+                content: '',
+                tool_calls: [
+                  {
+                    index: 0,
+                    id: 'tool-call-1',
+                    type: 'function',
+                    function: {
+                      name: 'test_function',
+                      arguments: '{"param1":'
+                    }
+                  }
+                ]
+              },
+              finish_reason: ''
+            }
+          ]
+        }
+      };
+
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              delta: {
+                content: '',
+                tool_calls: [
+                  {
+                    index: 0,
+                    id: 'tool-call-1',
+                    type: 'function',
+                    function: {
+                      name: 'test_function',
+                      arguments: '"value1"}'
+                    }
+                  }
+                ]
+              }
+            }
+          ]
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(
+        response._data.orchestration_result?.choices[0].message.tool_calls?.[0]
+          .function.arguments
+      ).toBe('{"param1":"value1"}');
+    });
+
+    it('adds new tool call when index does not exist', () => {
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        module_results: {},
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              message: {
+                role: 'assistant',
+                content: '',
+                tool_calls: [
+                  {
+                    index: 0,
+                    id: 'tool-call-1',
+                    type: 'function',
+                    function: {
+                      name: 'first_function',
+                      arguments: '{"param1":"value1"}'
+                    }
+                  }
+                ]
+              },
+              finish_reason: ''
+            }
+          ]
+        }
+      };
+
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              delta: {
+                content: '',
+                tool_calls: [
+                  {
+                    index: 1,
+                    id: 'tool-call-2',
+                    type: 'function',
+                    function: {
+                      name: 'second_function',
+                      arguments: '{"param2":"value2"}'
+                    }
+                  }
+                ]
+              }
+            }
+          ]
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(
+        response._data.orchestration_result?.choices[0].message.tool_calls
+      ).toHaveLength(2);
+      expect(
+        response._data.orchestration_result?.choices[0].message.tool_calls?.[1]
+          .function.name
+      ).toBe('second_function');
+    });
+  });
+
+  describe('logprobs merging', () => {
+    it('merges logprobs content arrays', () => {
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        module_results: {},
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              message: {
+                role: 'assistant',
+                content: 'Test'
+              },
+              finish_reason: '',
+              logprobs: {
+                content: [{ token: 'Test', logprob: -0.1 }]
+              }
+            }
+          ]
+        }
+      };
+
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        module_results: {},
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              delta: { content: ' message' },
+              logprobs: {
+                content: [{ token: ' message', logprob: -0.2 }]
+              }
+            }
+          ]
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(
+        response._data.orchestration_result?.choices[0]?.logprobs?.content
+      ).toHaveLength(2);
+      expect(
+        response._data.orchestration_result?.choices?.[0]?.logprobs
+          ?.content?.[1]?.token ?? ''
+      ).toBe(' message');
+    });
+
+    it('handles missing logprobs gracefully', () => {
+      const response =
+        new OrchestrationStreamResponse<OrchestrationStreamChunkResponse>();
+      response._data = {
+        request_id: 'test-request-123',
+        module_results: {},
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              message: {
+                role: 'assistant',
+                content: 'Test'
+              },
+              finish_reason: ''
+            }
+          ]
+        }
+      };
+
+      const chunk: CompletionPostResponseStreaming = {
+        request_id: 'test-request-123',
+        orchestration_result: {
+          ...llmBase,
+          choices: [
+            {
+              index: 0,
+              delta: { content: ' message' },
+              logprobs: {
+                content: [{ token: ' message', logprob: -0.2 }]
+              }
+            }
+          ]
+        }
+      };
+
+      mergeStreamResponse(response, chunk);
+
+      expect(
+        response._data.orchestration_result?.choices[0].logprobs?.content
+      ).toHaveLength(1);
+      expect(
+        response._data.orchestration_result?.choices[0].logprobs?.content?.[0]
+          ?.token ?? ''
+      ).toBe(' message');
+    });
+  });
+});
diff --git a/packages/orchestration/src/util/stream.ts b/packages/orchestration/src/util/stream.ts
new file mode 100644
index 000000000..a030193f2
--- /dev/null
+++ b/packages/orchestration/src/util/stream.ts
@@ -0,0 +1,261 @@
+import { createLogger } from '@sap-cloud-sdk/util';
+import type {
+  ChatDelta,
+  ChoiceLogprobs,
+  CompletionPostResponseStreaming,
+  LlmChoice,
+  LlmChoiceStreaming,
+  LlmModuleResult,
+  LLMModuleResultStreaming,
+  MessageToolCall,
+  ModuleResults,
+  ModuleResultsStreaming,
+  OrchestrationStreamChunkResponse,
+  OrchestrationStreamResponse,
+  ResponseChatMessage,
+  ToolCallChunk
+} from '../index.js';
+
+const logger = createLogger({
+  package: 'orchestration',
+  messageContext: 'stream-util'
+});
+
+/**
+ * @internal
+ */
+export function mergeStreamResponse(
+  response: OrchestrationStreamResponse<OrchestrationStreamChunkResponse>,
+  chunk: CompletionPostResponseStreaming
+): void {
+  const data = response._data;
+  data.request_id = chunk.request_id;
+  data.module_results = mergeModuleResults(
+    data.module_results,
+    chunk.module_results
+  );
+  data.orchestration_result = mergeLlmModule(
+    data.orchestration_result,
+    chunk.orchestration_result
+  );
+}
+
+function mergeModuleResults(
+  existing: ModuleResults | undefined,
+  incoming: ModuleResultsStreaming | undefined
+): ModuleResults {
+  const mergedModuleResults = { ...existing };
+  for (const [moduleName, moduleResult] of Object.entries(incoming || {})) {
+    switch (moduleName) {
+      case 'llm':
+        mergedModuleResults[moduleName] = mergeLlmModule(
+          mergedModuleResults[moduleName],
+          moduleResult
+        );
+        break;
+      case 'output_unmasking':
+        mergedModuleResults[moduleName] = mergeLlmChoices(
+          mergedModuleResults[moduleName],
+          moduleResult
+        );
+        break;
+      default:
+        mergedModuleResults[moduleName] = moduleResult;
+    }
+  }
+  return mergedModuleResults;
+}
+
+function mergeLlmModule(
+  existing: LlmModuleResult | undefined,
+  incoming: LLMModuleResultStreaming | undefined
+): LlmModuleResult | undefined {
+  if (!incoming) {
+    return existing;
+  }
+  const mergedModuleResults = {
+    ...incoming,
+    usage: mergeTokenUsage(existing?.usage, incoming.usage),
+    choices: mergeLlmChoices(existing?.choices, incoming?.choices)
+  };
+  return mergedModuleResults;
+}
+
+function mergeTokenUsage(
+  existing:
+    | { prompt_tokens: number; completion_tokens: number; total_tokens: number }
+    | undefined,
+  incoming:
+    | { prompt_tokens: number; completion_tokens: number; total_tokens: number }
+    | undefined
+): { prompt_tokens: number; completion_tokens: number; total_tokens: number } {
+  if (incoming) {
+    logger.debug(`Token usage: ${JSON.stringify(incoming)}`);
+  }
+  return {
+    prompt_tokens: incoming?.prompt_tokens ?? existing?.prompt_tokens ?? 0,
+    completion_tokens:
+      incoming?.completion_tokens ?? existing?.completion_tokens ?? 0,
+    total_tokens: incoming?.total_tokens ?? existing?.total_tokens ?? 0
+  };
+}
+
+function mergeLlmChoices(
+  existing: LlmChoice[] | undefined,
+  incoming: LlmChoiceStreaming[] | undefined
+): LlmChoice[] {
+  const mergedChoices = [...(existing ?? [])];
+  for (const choice of incoming ?? []) {
+    const existingChoice = mergedChoices.find(c => c.index === choice.index);
+    if (existingChoice) {
+      // Merge existing choice with incoming choice
+      existingChoice.finish_reason = handleFinishReason(
+        existingChoice.finish_reason,
+        choice.finish_reason,
+        choice.index
+      );
+      existingChoice.logprobs = mergeLogProbs(
+        existingChoice.logprobs,
+        choice.logprobs
+      );
+      existingChoice.message = mergeMessage(
+        existingChoice.message,
+        choice.delta
+      );
+    } else {
+      // Add new choice
+      mergedChoices.push(transformStreamingChoice(choice));
+    }
+  }
+  return mergedChoices;
+}
+
+function mergeMessage(
+  existing: ResponseChatMessage,
+  incoming: ChatDelta | undefined
+): ResponseChatMessage {
+  if (!incoming) {
+    return existing;
+  }
+  return {
+    role: existing.role,
+    content: existing.content + (incoming.content ?? ''),
+    tool_calls: mergeToolCalls(existing.tool_calls, incoming.tool_calls),
+    refusal: incoming.refusal ?? existing.refusal
+  };
+}
+
+function mergeToolCalls(
+  existing: MessageToolCall[] | undefined,
+  incoming: ToolCallChunk[] | undefined
+): MessageToolCall[] | undefined {
+  if (!incoming || incoming.length === 0) {
+    return existing;
+  }
+  if (!existing || existing.length === 0) {
+    return transformStreamingToolCalls(incoming);
+  }
+  const mergedToolCalls = [...existing];
+  for (const toolCall of incoming) {
+    const existingToolCall = mergedToolCalls.find(
+      tc => tc.index === toolCall.index
+    );
+    if (existingToolCall) {
+      // Merge existing tool call with incoming tool call
+      existingToolCall.function.name =
+        toolCall.function?.name ?? existingToolCall.function.name;
+      existingToolCall.function.arguments =
+        existingToolCall.function.arguments +
+        (toolCall.function?.arguments ?? '');
+    } else {
+      // Add new tool call
+      mergedToolCalls.push(transformStreamingToolCall(toolCall));
+    }
+  }
+  return mergedToolCalls;
+}
+
+function mergeLogProbs(
+  existing: ChoiceLogprobs | undefined,
+  incoming: ChoiceLogprobs | undefined
+): ChoiceLogprobs | undefined {
+  if (!incoming) {
+    return existing;
+  }
+  if (!existing) {
+    return incoming;
+  }
+  return {
+    content: [...(existing.content ?? []), ...(incoming.content ?? [])],
+    refusal: [...(existing.refusal ?? []), ...(incoming.refusal ?? [])]
+  };
+}
+
+function handleFinishReason(
+  existing: string | undefined,
+  incoming: string | undefined,
+  choiceIndex: number
+): string {
+  if (!incoming) {
+    return existing ?? '';
+  }
+
+  switch (incoming) {
+    case 'content_filter':
+      logger.error(
+        `Choice ${choiceIndex}: Stream finished with content filter hit.`
+      );
+      break;
+    case 'length':
+      logger.error(
+        `Choice ${choiceIndex}: Stream finished with token length exceeded.`
+      );
+      break;
+    case 'stop':
+    case 'tool_calls':
+    case 'function_call':
+      logger.debug(`Choice ${choiceIndex}: Stream finished.`);
+      break;
+    default:
+      logger.error(
+        `Choice ${choiceIndex}: Stream finished with unknown reason '${incoming}'.`
+      );
+  }
+
+  return incoming;
+}
+
+function transformStreamingChoice(choice: LlmChoiceStreaming): LlmChoice {
+  return {
+    index: choice.index,
+    message: {
+      role: 'assistant',
+      content: choice.delta.content,
+      tool_calls: transformStreamingToolCalls(choice.delta.tool_calls),
+      refusal: choice.delta.refusal
+    },
+    finish_reason: choice.finish_reason ?? '',
+    logprobs: choice.logprobs
+  };
+}
+
+function transformStreamingToolCalls(
+  toolCalls: ToolCallChunk[] | undefined
+): MessageToolCall[] | undefined {
+  if (!toolCalls || toolCalls.length === 0) {
+    return undefined;
+  }
+  return toolCalls?.map(toolCall => transformStreamingToolCall(toolCall));
+}
+
+function transformStreamingToolCall(toolCall: ToolCallChunk): MessageToolCall {
+  return {
+    index: toolCall.index,
+    id: toolCall.id ?? '',
+    type: toolCall.type ?? 'function',
+    function: {
+      name: toolCall.function?.name ?? '',
+      arguments: toolCall.function?.arguments ?? ''
+    }
+  };
+}
diff --git a/packages/orchestration/src/util/tool-calls.ts b/packages/orchestration/src/util/tool-calls.ts
deleted file mode 100644
index 1129b0cff..000000000
--- a/packages/orchestration/src/util/tool-calls.ts
+++ /dev/null
@@ -1,78 +0,0 @@
-import type {
-  MessageToolCall,
-  ToolCallChunk
-} from '../client/api/schema/index.js';
-
-/**
- * @internal
- */
-export type ToolCallAccumulator = {
-  id?: string;
-  type: 'function';
-  function: {
-    name?: string;
-    arguments?: string;
-  } & Record<string, any>;
-} & Record<string, any>;
-
-/**
- * @internal
- * Check if the accumulator is a MessageToolCall.
- */
-export function isMessageToolCall(
-  acc: ToolCallAccumulator
-): acc is MessageToolCall {
-  return (
-    typeof acc.id === 'string' &&
-    typeof acc.function.name === 'string' &&
-    typeof acc.function.arguments === 'string'
-  );
-}
-
-/**
- * Merge a stream of ToolCallChunk into a single MessageToolCall.
- * @throws If the final object is missing required fields.
- * @internal
- */
-export function mergeToolCallChunk(
-  chunk: ToolCallChunk,
-  acc?: ToolCallAccumulator
-): ToolCallAccumulator {
-  const accumulator: ToolCallAccumulator = acc
-    ? { ...acc }
-    : {
-        type: 'function',
-        function: {}
-      };
-
-  if (chunk.id) {
-    accumulator.id = chunk.id;
-  }
-
-  // Merge any extra top-level props
-  for (const key of Object.keys(chunk)) {
-    if (!['index', 'id', 'type', 'function'].includes(key)) {
-      accumulator[key] = chunk[key];
-    }
-  }
-
-  if (chunk.function) {
-    if (chunk.function.name) {
-      accumulator.function.name = chunk.function.name;
-    }
-
-    if (chunk.function.arguments) {
-      accumulator.function.arguments =
-        (accumulator.function.arguments || '') + chunk.function.arguments;
-    }
-
-    // Merge any extra function-scoped fields
-    for (const key of Object.keys(chunk.function)) {
-      if (!['name', 'arguments'].includes(key)) {
-        accumulator.function[key] = (chunk.function as any)[key];
-      }
-    }
-  }
-
-  return accumulator;
-}
diff --git a/tests/e2e-tests/src/orchestration.test.ts b/tests/e2e-tests/src/orchestration.test.ts
index 76d588f3d..d37d7e12a 100644
--- a/tests/e2e-tests/src/orchestration.test.ts
+++ b/tests/e2e-tests/src/orchestration.test.ts
@@ -197,6 +197,7 @@ describe('orchestration', () => {
             "name": "add",
           },
           "id": "mock_id",
+          "index": 0,
           "type": "function",
         },
         {
@@ -205,6 +206,7 @@
             "name": "add",
          },
           "id": "mock_id",
+          "index": 1,
           "type": "function",
         },
       ]
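
For reviewers, a minimal consumption sketch for the new stream-response getters. This is not part of the change itself: it assumes a configured AI Core binding and a reachable deployment, and uses `getDeltaContent()`, the existing accessor on `OrchestrationStreamChunkResponse`:

```ts
import { OrchestrationClient } from '@sap-ai-sdk/orchestration';

async function main(): Promise<void> {
  const client = new OrchestrationClient({
    llm: { model_name: 'gpt-4o', model_params: {} },
    templating: {
      template: [{ role: 'user', content: 'Give me a short greeting.' }]
    }
  });

  const response = await client.stream();

  // Consume the SSE stream first: the new getters throw while it is open.
  for await (const chunk of response.stream) {
    process.stdout.write(chunk.getDeltaContent() ?? '');
  }

  // Once the stream is closed, the merged data is available.
  console.log(response.getFinishReason()); // e.g. 'stop'
  console.log(response.getTokenUsage());
  console.log(response.getContent());
  console.log(response.getAllMessages()); // templating history + assistant message
}

main().catch(console.error);
```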
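A second sketch for the new error path in `OrchestrationClient.stream()`, reusing `client` from the sketch above: on a failed request the method aborts the passed controller before re-throwing, so any other work tied to the same signal is cancelled as well.

```ts
const controller = new AbortController();
controller.signal.addEventListener('abort', () => {
  // Fires both when the caller aborts manually and when stream()
  // aborts the controller after a failed request.
  console.log('Streaming aborted.');
});

try {
  const response = await client.stream(undefined, controller);
  for await (const chunk of response.stream) {
    process.stdout.write(chunk.getDeltaContent() ?? '');
  }
} catch (error) {
  // stream() re-throws the original error after aborting the controller.
  console.error('Streaming failed:', error);
}
```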