Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions rig-core/src/client/builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::agent::Agent;
use crate::client::ProviderClient;
use crate::completion::{CompletionRequest, Message};
use crate::completion::{CompletionRequest, GetTokenUsage, Message, Usage};
use crate::embeddings::embedding::EmbeddingModelDyn;
use crate::providers::{
anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
Expand All @@ -9,6 +9,7 @@
use crate::streaming::StreamingCompletionResponse;
use crate::transcription::TranscriptionModelDyn;
use rig::completion::CompletionModelDyn;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::panic::{RefUnwindSafe, UnwindSafe};
use thiserror::Error;
Expand Down Expand Up @@ -390,7 +391,7 @@
provider: &str,
model: &str,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
let client = self.build(provider)?;
let completion = client
.as_completion()
Expand All @@ -400,10 +401,10 @@
))?;

let model = completion.completion_model(model);
model

Check failure on line 404 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / test

mismatched types

Check failure on line 404 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / test

mismatched types

Check failure on line 404 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / test

mismatched types

Check failure on line 404 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / check rig-core wasm target

mismatched types

Check failure on line 404 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / doc

mismatched types
.stream(request)
.await
.map_err(|e| ClientBuildError::FactoryError(e.to_string()))

Check failure on line 407 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / clippy

mismatched types

error[E0308]: mismatched types --> rig-core/src/client/builder.rs:404:9 | 394 | ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> { | ------------------------------------------------------------------------------ expected `Result<StreamingCompletionResponse<FinalCompletionResponse>, ...>` because of return type ... 404 | / model 405 | | .stream(request) 406 | | .await 407 | | .map_err(|e| ClientBuildError::FactoryError(e.to_string())) | |_______________________________________________________________________^ expected `client::builder::FinalCompletionResponse`, found `()` | = note: expected enum `std::result::Result<streaming::StreamingCompletionResponse<client::builder::FinalCompletionResponse>, _>` found enum `std::result::Result<streaming::StreamingCompletionResponse<()>, _>` = note: the full name for the type has been written to '/home/runner/work/rig/rig/target/debug/deps/rig-c15a693386b1818a.long-type-3936704586244979514.txt' = note: consider using `--verbose` to print the full type name to the console

Check failure on line 407 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / clippy

mismatched types

error[E0308]: mismatched types --> rig-core/src/client/builder.rs:404:9 | 394 | ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> { | ------------------------------------------------------------------------------ expected `Result<StreamingCompletionResponse<FinalCompletionResponse>, ...>` because of return type ... 404 | / model 405 | | .stream(request) 406 | | .await 407 | | .map_err(|e| ClientBuildError::FactoryError(e.to_string())) | |_______________________________________________________________________^ expected `client::builder::FinalCompletionResponse`, found `()` | = note: expected enum `std::result::Result<streaming::StreamingCompletionResponse<client::builder::FinalCompletionResponse>, _>` found enum `std::result::Result<streaming::StreamingCompletionResponse<()>, _>` = note: the full name for the type has been written to '/home/runner/work/rig/rig/target/debug/deps/rig-57aaafb7bf6954ca.long-type-9287223982256624422.txt' = note: consider using `--verbose` to print the full type name to the console
}

/// Stream a simple prompt to the specified provider and model.
Expand All @@ -420,7 +421,7 @@
provider: &str,
model: &str,
prompt: impl Into<Message> + Send,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
let client = self.build(provider)?;
let completion = client
.as_completion()
Expand All @@ -441,7 +442,7 @@
chat_history: crate::OneOrMany::one(prompt.into()),
};

model

Check failure on line 445 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / test

mismatched types

Check failure on line 445 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / test

mismatched types

Check failure on line 445 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / check rig-core wasm target

mismatched types

Check failure on line 445 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / doc

mismatched types
.stream(request)
.await
.map_err(|e| ClientBuildError::FactoryError(e.to_string()))

Check failure on line 448 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / clippy

mismatched types

error[E0308]: mismatched types --> rig-core/src/client/builder.rs:445:9 | 424 | ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> { | ------------------------------------------------------------------------------ expected `Result<StreamingCompletionResponse<FinalCompletionResponse>, ...>` because of return type ... 445 | / model 446 | | .stream(request) 447 | | .await 448 | | .map_err(|e| ClientBuildError::FactoryError(e.to_string())) | |_______________________________________________________________________^ expected `client::builder::FinalCompletionResponse`, found `()` | = note: expected enum `std::result::Result<streaming::StreamingCompletionResponse<client::builder::FinalCompletionResponse>, _>` found enum `std::result::Result<streaming::StreamingCompletionResponse<()>, _>` = note: the full name for the type has been written to '/home/runner/work/rig/rig/target/debug/deps/rig-57aaafb7bf6954ca.long-type-9287223982256624422.txt' = note: consider using `--verbose` to print the full type name to the console
Expand All @@ -463,7 +464,7 @@
model: &str,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
let client = self.build(provider)?;
let completion = client
.as_completion()
Expand All @@ -488,13 +489,25 @@
.unwrap_or_else(|_| crate::OneOrMany::one(Message::user(""))),
};

model

Check failure on line 492 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / test

mismatched types

Check failure on line 492 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / test

mismatched types

Check failure on line 492 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / check rig-core wasm target

mismatched types

Check failure on line 492 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / doc

mismatched types
.stream(request)
.await
.map_err(|e| ClientBuildError::FactoryError(e.to_string()))

Check failure on line 495 in rig-core/src/client/builder.rs

View workflow job for this annotation

GitHub Actions / stable / clippy

mismatched types

error[E0308]: mismatched types --> rig-core/src/client/builder.rs:492:9 | 467 | ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> { | ------------------------------------------------------------------------------ expected `Result<StreamingCompletionResponse<FinalCompletionResponse>, ...>` because of return type ... 492 | / model 493 | | .stream(request) 494 | | .await 495 | | .map_err(|e| ClientBuildError::FactoryError(e.to_string())) | |_______________________________________________________________________^ expected `client::builder::FinalCompletionResponse`, found `()` | = note: expected enum `std::result::Result<streaming::StreamingCompletionResponse<client::builder::FinalCompletionResponse>, _>` found enum `std::result::Result<streaming::StreamingCompletionResponse<()>, _>` = note: the full name for the type has been written to '/home/runner/work/rig/rig/target/debug/deps/rig-57aaafb7bf6954ca.long-type-9287223982256624422.txt' = note: consider using `--verbose` to print the full type name to the console
}
}

/// The final streaming response from a dynamic client.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct FinalCompletionResponse {
pub usage: Option<Usage>,
}

impl GetTokenUsage for FinalCompletionResponse {
fn token_usage(&self) -> Option<Usage> {
self.usage
}
}

pub struct ProviderModelId<'builder, 'id> {
builder: &'builder DynClientBuilder,
provider: &'id str,
Expand Down Expand Up @@ -528,7 +541,7 @@
pub async fn stream_completion(
self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
self.builder
.stream_completion(self.provider, self.model, request)
.await
Expand All @@ -544,7 +557,7 @@
pub async fn stream_prompt(
self,
prompt: impl Into<Message> + Send,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
self.builder
.stream_prompt(self.provider, self.model, prompt)
.await
Expand All @@ -562,7 +575,7 @@
self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
self.builder
.stream_chat(self.provider, self.model, prompt, chat_history)
.await
Expand Down
3 changes: 2 additions & 1 deletion rig-core/src/client/completion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::agent::AgentBuilder;
use crate::client::builder::FinalCompletionResponse;
use crate::client::{AsCompletion, ProviderClient};
use crate::completion::{
CompletionError, CompletionModel, CompletionModelDyn, CompletionRequest, CompletionResponse,
Expand Down Expand Up @@ -68,7 +69,7 @@

impl CompletionModel for CompletionModelHandle<'_> {
type Response = ();
type StreamingResponse = ();
type StreamingResponse = FinalCompletionResponse;

fn completion(
&self,
Expand All @@ -81,7 +82,7 @@
fn stream(
&self,
request: CompletionRequest,
) -> impl Future<

Check failure on line 85 in rig-core/src/client/completion.rs

View workflow job for this annotation

GitHub Actions / stable / test

expected `Pin<Box<dyn Future<Output = Result<StreamingCompletionResponse<()>, CompletionError>> + Send>>` to be a future that resolves to `Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>`, but it resolves to `Result<StreamingCompletionResponse<()>, CompletionError>`

Check failure on line 85 in rig-core/src/client/completion.rs

View workflow job for this annotation

GitHub Actions / stable / test

expected `Pin<Box<dyn Future<Output = Result<StreamingCompletionResponse<()>, CompletionError>> + Send>>` to be a future that resolves to `Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>`, but it resolves to `Result<StreamingCompletionResponse<()>, CompletionError>`

Check failure on line 85 in rig-core/src/client/completion.rs

View workflow job for this annotation

GitHub Actions / stable / check rig-core wasm target

expected `Pin<Box<dyn Future<Output = Result<StreamingCompletionResponse<()>, CompletionError>>>>` to be a future that resolves to `Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>`, but it resolves to `Result<StreamingCompletionResponse<()>, CompletionError>`

Check failure on line 85 in rig-core/src/client/completion.rs

View workflow job for this annotation

GitHub Actions / stable / doc

expected `Pin<Box<dyn Future<Output = Result<StreamingCompletionResponse<()>, CompletionError>> + Send>>` to be a future that resolves to `Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>`, but it resolves to `Result<StreamingCompletionResponse<()>, CompletionError>`
Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
> + WasmCompatSend {
self.inner.stream(request)
Expand Down
1 change: 1 addition & 0 deletions rig-core/src/completion/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
//! the individual traits, structs, and enums defined in this module.

use super::message::{AssistantContent, DocumentMediaType};
use crate::client::builder::FinalCompletionResponse;

Check failure on line 67 in rig-core/src/completion/request.rs

View workflow job for this annotation

GitHub Actions / stable / clippy

unused import: `crate::client::builder::FinalCompletionResponse`

error: unused import: `crate::client::builder::FinalCompletionResponse` --> rig-core/src/completion/request.rs:67:5 | 67 | use crate::client::builder::FinalCompletionResponse; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = note: `-D unused-imports` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(unused_imports)]`

Check failure on line 67 in rig-core/src/completion/request.rs

View workflow job for this annotation

GitHub Actions / stable / clippy

unused import: `crate::client::builder::FinalCompletionResponse`

error: unused import: `crate::client::builder::FinalCompletionResponse` --> rig-core/src/completion/request.rs:67:5 | 67 | use crate::client::builder::FinalCompletionResponse; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = note: `-D unused-imports` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(unused_imports)]`

Check failure on line 67 in rig-core/src/completion/request.rs

View workflow job for this annotation

GitHub Actions / stable / clippy

unused import: `crate::client::builder::FinalCompletionResponse`

error: unused import: `crate::client::builder::FinalCompletionResponse` --> rig-core/src/completion/request.rs:67:5 | 67 | use crate::client::builder::FinalCompletionResponse; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = note: `-D unused-imports` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(unused_imports)]`

Check failure on line 67 in rig-core/src/completion/request.rs

View workflow job for this annotation

GitHub Actions / stable / test

unused import: `crate::client::builder::FinalCompletionResponse`

Check failure on line 67 in rig-core/src/completion/request.rs

View workflow job for this annotation

GitHub Actions / stable / check rig-core wasm target

unused import: `crate::client::builder::FinalCompletionResponse`

Check failure on line 67 in rig-core/src/completion/request.rs

View workflow job for this annotation

GitHub Actions / stable / doc

unused import: `crate::client::builder::FinalCompletionResponse`
use crate::client::completion::CompletionModelHandle;
use crate::message::ToolChoice;
use crate::streaming::StreamingCompletionResponse;
Expand Down Expand Up @@ -395,7 +396,7 @@
&self,
request: CompletionRequest,
) -> WasmBoxedFuture<'_, Result<StreamingCompletionResponse<()>, CompletionError>> {
Box::pin(async move {

Check failure on line 399 in rig-core/src/completion/request.rs

View workflow job for this annotation

GitHub Actions / stable / check rig-core wasm target

expected `{async block@rig-core/src/completion/request.rs:399:18: 399:28}` to be a future that resolves to `Result<StreamingCompletionResponse<()>, CompletionError>`, but it resolves to `Result<StreamingCompletionResponse<FinalCompletionResponse>, _>`

Check failure on line 399 in rig-core/src/completion/request.rs

View workflow job for this annotation

GitHub Actions / stable / doc

expected `{async block@rig-core/src/completion/request.rs:399:18: 399:28}` to be a future that resolves to `Result<StreamingCompletionResponse<()>, CompletionError>`, but it resolves to `Result<StreamingCompletionResponse<FinalCompletionResponse>, _>`
let resp = self.stream(request).await?;
let inner = resp.inner;

Expand Down
15 changes: 9 additions & 6 deletions rig-core/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use crate::OneOrMany;
use crate::agent::Agent;
use crate::agent::prompt_request::streaming::StreamingPromptRequest;
use crate::client::builder::FinalCompletionResponse;
use crate::completion::{
CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
Message, Usage,
Expand Down Expand Up @@ -312,12 +313,12 @@ pub trait StreamingCompletion<M: CompletionModel> {
) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}

pub(crate) struct StreamingResultDyn<R: Clone + Unpin> {
pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
pub(crate) inner: StreamingResult<R>,
}

impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
type Item = Result<RawStreamingChoice<()>, CompletionError>;
impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let stream = self.get_mut();
Expand All @@ -327,9 +328,11 @@ impl<R: Clone + Unpin> Stream for StreamingResultDyn<R> {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(chunk))) => match chunk {
RawStreamingChoice::FinalResponse(_) => {
Poll::Ready(Some(Ok(RawStreamingChoice::FinalResponse(()))))
}
RawStreamingChoice::FinalResponse(res) => Poll::Ready(Some(Ok(
RawStreamingChoice::FinalResponse(FinalCompletionResponse {
usage: res.token_usage(),
}),
))),
RawStreamingChoice::Message(m) => {
Poll::Ready(Some(Ok(RawStreamingChoice::Message(m))))
}
Expand Down
Loading