Skip to content

Commit 5092f0f

Browse files
authored
collab: Sync model request overages to Stripe (#29583)
This PR adds syncing of model request overages to Stripe. Release Notes: - N/A
1 parent 3a212e7 commit 5092f0f

File tree

9 files changed

+318
-16
lines changed

9 files changed

+318
-16
lines changed

crates/collab/src/api/billing.rs

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,9 @@ async fn create_billing_subscription(
393393
zed_llm_client::LanguageModelProvider::Anthropic,
394394
"claude-3-7-sonnet",
395395
)?;
396-
let stripe_model = stripe_billing.register_model(default_model).await?;
396+
let stripe_model = stripe_billing
397+
.register_model_for_token_based_usage(default_model)
398+
.await?;
397399
stripe_billing
398400
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
399401
.await?
@@ -1303,7 +1305,9 @@ async fn sync_token_usage_with_stripe(
13031305
.parse()
13041306
.context("failed to parse stripe customer id from db")?;
13051307

1306-
let stripe_model = stripe_billing.register_model(&model).await?;
1308+
let stripe_model = stripe_billing
1309+
.register_model_for_token_based_usage(&model)
1310+
.await?;
13071311
stripe_billing
13081312
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
13091313
.await?;
@@ -1315,3 +1319,106 @@ async fn sync_token_usage_with_stripe(
13151319

13161320
Ok(())
13171321
}
1322+
1323+
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1324+
1325+
pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1326+
let Some(stripe_billing) = app.stripe_billing.clone() else {
1327+
log::warn!("failed to retrieve Stripe billing object");
1328+
return;
1329+
};
1330+
let Some(llm_db) = app.llm_db.clone() else {
1331+
log::warn!("failed to retrieve LLM database");
1332+
return;
1333+
};
1334+
1335+
let executor = app.executor.clone();
1336+
executor.spawn_detached({
1337+
let executor = executor.clone();
1338+
async move {
1339+
loop {
1340+
sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1341+
.await
1342+
.context("failed to sync LLM request usage to Stripe")
1343+
.trace_err();
1344+
executor
1345+
.sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1346+
.await;
1347+
}
1348+
}
1349+
});
1350+
}
1351+
1352+
async fn sync_model_request_usage_with_stripe(
1353+
app: &Arc<AppState>,
1354+
llm_db: &Arc<LlmDatabase>,
1355+
stripe_billing: &Arc<StripeBilling>,
1356+
) -> anyhow::Result<()> {
1357+
let usage_meters = llm_db
1358+
.get_current_subscription_usage_meters(Utc::now())
1359+
.await?;
1360+
let user_ids = usage_meters
1361+
.iter()
1362+
.map(|(_, usage)| usage.user_id)
1363+
.collect::<HashSet<UserId>>();
1364+
let billing_subscriptions = app
1365+
.db
1366+
.get_active_zed_pro_billing_subscriptions(user_ids)
1367+
.await?;
1368+
1369+
let claude_3_5_sonnet = stripe_billing
1370+
.find_price_by_lookup_key("claude-3-5-sonnet-requests")
1371+
.await?;
1372+
let claude_3_7_sonnet = stripe_billing
1373+
.find_price_by_lookup_key("claude-3-7-sonnet-requests")
1374+
.await?;
1375+
1376+
for (usage_meter, usage) in usage_meters {
1377+
maybe!(async {
1378+
let Some((billing_customer, billing_subscription)) =
1379+
billing_subscriptions.get(&usage.user_id)
1380+
else {
1381+
bail!(
1382+
"Attempted to sync usage meter for user who is not a Stripe customer: {}",
1383+
usage.user_id
1384+
);
1385+
};
1386+
1387+
let stripe_customer_id = billing_customer
1388+
.stripe_customer_id
1389+
.parse::<stripe::CustomerId>()
1390+
.context("failed to parse Stripe customer ID from database")?;
1391+
let stripe_subscription_id = billing_subscription
1392+
.stripe_subscription_id
1393+
.parse::<stripe::SubscriptionId>()
1394+
.context("failed to parse Stripe subscription ID from database")?;
1395+
1396+
let model = llm_db.model_by_id(usage_meter.model_id)?;
1397+
1398+
let (price_id, meter_event_name) = match model.name.as_str() {
1399+
"claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
1400+
"claude-3-7-sonnet" => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
1401+
model_name => {
1402+
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1403+
}
1404+
};
1405+
1406+
stripe_billing
1407+
.subscribe_to_price(&stripe_subscription_id, price_id)
1408+
.await?;
1409+
stripe_billing
1410+
.bill_model_request_usage(
1411+
&stripe_customer_id,
1412+
meter_event_name,
1413+
usage_meter.requests,
1414+
)
1415+
.await?;
1416+
1417+
Ok(())
1418+
})
1419+
.await
1420+
.log_err();
1421+
}
1422+
1423+
Ok(())
1424+
}

crates/collab/src/db/queries/billing_subscriptions.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,38 @@ impl Database {
191191
.await
192192
}
193193

194+
pub async fn get_active_zed_pro_billing_subscriptions(
195+
&self,
196+
user_ids: HashSet<UserId>,
197+
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
198+
self.transaction(|tx| {
199+
let user_ids = user_ids.clone();
200+
async move {
201+
let mut rows = billing_subscription::Entity::find()
202+
.inner_join(billing_customer::Entity)
203+
.select_also(billing_customer::Entity)
204+
.filter(billing_customer::Column::UserId.is_in(user_ids))
205+
.filter(
206+
billing_subscription::Column::StripeSubscriptionStatus
207+
.eq(StripeSubscriptionStatus::Active),
208+
)
209+
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
210+
.order_by_asc(billing_subscription::Column::Id)
211+
.stream(&*tx)
212+
.await?;
213+
214+
let mut subscriptions = HashMap::default();
215+
while let Some(row) = rows.next().await {
216+
if let (subscription, Some(customer)) = row? {
217+
subscriptions.insert(customer.user_id, (customer, subscription));
218+
}
219+
}
220+
Ok(subscriptions)
221+
}
222+
})
223+
.await
224+
}
225+
194226
/// Returns whether the user has an active billing subscription.
195227
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
196228
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)

crates/collab/src/llm/db/queries.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ use super::*;
22

33
pub mod billing_events;
44
pub mod providers;
5+
pub mod subscription_usage_meters;
56
pub mod subscription_usages;
67
pub mod usages;
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
2+
3+
use super::*;
4+
5+
impl LlmDatabase {
6+
/// Returns all current subscription usage meters as of the given timestamp.
7+
pub async fn get_current_subscription_usage_meters(
8+
&self,
9+
now: DateTimeUtc,
10+
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
11+
let now = convert_chrono_to_time(now)?;
12+
13+
self.transaction(|tx| async move {
14+
let result = subscription_usage_meter::Entity::find()
15+
.inner_join(subscription_usage::Entity)
16+
.filter(
17+
subscription_usage::Column::PeriodStartAt
18+
.lte(now)
19+
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
20+
)
21+
.select_also(subscription_usage::Entity)
22+
.all(&*tx)
23+
.await?;
24+
25+
let result = result
26+
.into_iter()
27+
.filter_map(|(meter, usage)| {
28+
let usage = usage?;
29+
Some((meter, usage))
30+
})
31+
.collect();
32+
33+
Ok(result)
34+
})
35+
.await
36+
}
37+
}

crates/collab/src/llm/db/queries/subscription_usages.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::db::{UserId, billing_subscription};
66

77
use super::*;
88

9-
fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
9+
pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
1010
use chrono::{Datelike as _, Timelike as _};
1111

1212
let date = time::Date::from_calendar_date(

crates/collab/src/llm/db/tables.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ pub mod model;
33
pub mod monthly_usage;
44
pub mod provider;
55
pub mod subscription_usage;
6+
pub mod subscription_usage_meter;
67
pub mod usage;
78
pub mod usage_measure;
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use sea_orm::entity::prelude::*;
2+
3+
use crate::llm::db::ModelId;
4+
5+
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
6+
#[sea_orm(table_name = "subscription_usage_meters")]
7+
pub struct Model {
8+
#[sea_orm(primary_key)]
9+
pub id: i32,
10+
pub subscription_usage_id: i32,
11+
pub model_id: ModelId,
12+
pub requests: i32,
13+
}
14+
15+
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
16+
pub enum Relation {
17+
#[sea_orm(
18+
belongs_to = "super::subscription_usage::Entity",
19+
from = "Column::SubscriptionUsageId",
20+
to = "super::subscription_usage::Column::Id"
21+
)]
22+
SubscriptionUsage,
23+
#[sea_orm(
24+
belongs_to = "super::model::Entity",
25+
from = "Column::ModelId",
26+
to = "super::model::Column::Id"
27+
)]
28+
Model,
29+
}
30+
31+
impl Related<super::subscription_usage::Entity> for Entity {
32+
fn to() -> RelationDef {
33+
Relation::SubscriptionUsage.def()
34+
}
35+
}
36+
37+
impl Related<super::model::Entity> for Entity {
38+
fn to() -> RelationDef {
39+
Relation::Model.def()
40+
}
41+
}
42+
43+
impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/main.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ use axum::{
88
};
99

1010
use collab::api::CloudflareIpCountryHeader;
11-
use collab::api::billing::sync_llm_token_usage_with_stripe_periodically;
11+
use collab::api::billing::{
12+
sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically,
13+
};
1214
use collab::llm::db::LlmDatabase;
1315
use collab::migrations::run_database_migrations;
1416
use collab::user_backfiller::spawn_user_backfiller;
@@ -152,6 +154,7 @@ async fn main() -> Result<()> {
152154

153155
if let Some(mut llm_db) = llm_db {
154156
llm_db.initialize().await?;
157+
sync_llm_request_usage_with_stripe_periodically(state.clone());
155158
sync_llm_token_usage_with_stripe_periodically(state.clone());
156159
}
157160

0 commit comments

Comments
 (0)