Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions rig-core/src/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ impl<'a> DynClientBuilder {
provider: &str,
model: &str,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
let client = self.build(provider)?;
let completion = client
.as_completion()
Expand Down Expand Up @@ -420,7 +420,7 @@ impl<'a> DynClientBuilder {
provider: &str,
model: &str,
prompt: impl Into<Message> + Send,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
let client = self.build(provider)?;
let completion = client
.as_completion()
Expand Down Expand Up @@ -463,7 +463,7 @@ impl<'a> DynClientBuilder {
model: &str,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
let client = self.build(provider)?;
let completion = client
.as_completion()
Expand Down Expand Up @@ -528,7 +528,7 @@ impl<'builder> ProviderModelId<'builder, '_> {
pub async fn stream_completion(
self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
self.builder
.stream_completion(self.provider, self.model, request)
.await
Expand All @@ -544,7 +544,7 @@ impl<'builder> ProviderModelId<'builder, '_> {
pub async fn stream_prompt(
self,
prompt: impl Into<Message> + Send,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
self.builder
.stream_prompt(self.provider, self.model, prompt)
.await
Expand All @@ -562,7 +562,7 @@ impl<'builder> ProviderModelId<'builder, '_> {
self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
self.builder
.stream_chat(self.provider, self.model, prompt, chat_history)
.await
Expand Down Expand Up @@ -643,7 +643,7 @@ mod audio {
}
}
use crate::agent::AgentBuilder;
use crate::client::completion::CompletionModelHandle;
use crate::client::completion::{CompletionModelHandle, FinalCompletionResponse};
#[cfg(feature = "audio")]
pub use audio::*;
use rig::providers::mistral;
Expand Down
16 changes: 14 additions & 2 deletions rig-core/src/client/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::agent::AgentBuilder;
use crate::client::{AsCompletion, ProviderClient};
use crate::completion::{
CompletionError, CompletionModel, CompletionModelDyn, CompletionRequest, CompletionResponse,
GetTokenUsage,
GetTokenUsage, Usage,
};
use crate::extractor::ExtractorBuilder;
use crate::streaming::StreamingCompletionResponse;
Expand Down Expand Up @@ -59,6 +59,18 @@ pub trait CompletionClient: ProviderClient + Clone {
}
}

/// Final chunk payload emitted when streaming through a dynamic client.
///
/// Where the concrete provider's own final-response type is erased, this
/// struct is what remains of it: only the token usage is retained.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FinalCompletionResponse {
    /// Token usage reported for the completed stream; `None` when no usage
    /// information is available.
    pub usage: Option<Usage>,
}

impl GetTokenUsage for FinalCompletionResponse {
    /// Returns a copy of the usage stored on this response, or `None` when
    /// the `usage` field is unset.
    fn token_usage(&self) -> Option<Usage> {
        // Destructure so that adding a field to the struct forces this
        // accessor to be revisited.
        let Self { usage } = self;
        *usage
    }
}

/// Wraps a CompletionModel in a dyn-compatible way for AgentBuilder.
#[derive(Clone)]
pub struct CompletionModelHandle<'a> {
Expand All @@ -67,7 +79,7 @@ pub struct CompletionModelHandle<'a> {

impl CompletionModel for CompletionModelHandle<'_> {
type Response = ();
type StreamingResponse = ();
type StreamingResponse = FinalCompletionResponse;

fn completion(
&self,
Expand Down
7 changes: 4 additions & 3 deletions rig-core/src/completion/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
//! the individual traits, structs, and enums defined in this module.

use super::message::{AssistantContent, DocumentMediaType};
use crate::client::completion::CompletionModelHandle;
use crate::client::completion::{CompletionModelHandle, FinalCompletionResponse};
use crate::message::ToolChoice;
use crate::streaming::StreamingCompletionResponse;
use crate::{OneOrMany, streaming};
Expand Down Expand Up @@ -356,7 +356,7 @@ pub trait CompletionModelDyn: Send + Sync {
fn stream(
&self,
request: CompletionRequest,
) -> BoxFuture<'_, Result<StreamingCompletionResponse<()>, CompletionError>>;
) -> BoxFuture<'_, Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>>;

fn completion_request(
&self,
Expand Down Expand Up @@ -387,7 +387,8 @@ where
fn stream(
&self,
request: CompletionRequest,
) -> BoxFuture<'_, Result<StreamingCompletionResponse<()>, CompletionError>> {
) -> BoxFuture<'_, Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>>
{
Box::pin(async move {
let resp = self.stream(request).await?;
let inner = resp.inner;
Expand Down
15 changes: 9 additions & 6 deletions rig-core/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use crate::OneOrMany;
use crate::agent::Agent;
use crate::agent::prompt_request::streaming::StreamingPromptRequest;
use crate::client::completion::FinalCompletionResponse;
use crate::completion::{
CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
Message, Usage,
Expand Down Expand Up @@ -309,12 +310,12 @@ pub trait StreamingCompletion<M: CompletionModel> {
) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}

pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
pub(crate) inner: StreamingResult<R>,
}

impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
type Item = Result<RawStreamingChoice<()>, CompletionError>;
impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let stream = self.get_mut();
Expand All @@ -324,9 +325,11 @@ impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(chunk))) => match chunk {
RawStreamingChoice::FinalResponse(_) => {
Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
}
RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
RawStreamingChoice::FinalResponse(FinalCompletionResponse {
usage: res.token_usage(),
}),
))),
RawStreamingChoice::Message(m) => {
Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
}
Expand Down