Skip to content

Commit aa146bf

Browse files
fix: linux lints (#1789)
* fix: linux lints * store credentials only in sqlite
1 parent d558d85 commit aa146bf

File tree

7 files changed

+103
-269
lines changed

7 files changed

+103
-269
lines changed

crates/chat-cli/src/auth/builder_id.rs

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,9 @@ use crate::auth::AuthError;
5252
use crate::auth::consts::*;
5353
use crate::auth::scope::is_scopes;
5454
use crate::aws_common::app_name;
55-
use crate::database::Database;
56-
use crate::database::secret_store::{
55+
use crate::database::{
56+
Database,
5757
Secret,
58-
SecretStore,
5958
};
6059

6160
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -123,8 +122,8 @@ impl DeviceRegistration {
123122
}
124123

125124
/// Loads the OIDC registered client from the secret store, deleting it if it is expired.
126-
async fn load_from_secret_store(secret_store: &SecretStore, region: &Region) -> Result<Option<Self>, AuthError> {
127-
let device_registration = secret_store.get(Self::SECRET_KEY).await?;
125+
async fn load_from_secret_store(database: &Database, region: &Region) -> Result<Option<Self>, AuthError> {
126+
let device_registration = database.get_secret(Self::SECRET_KEY).await?;
128127

129128
if let Some(device_registration) = device_registration {
130129
// check that the data is not expired, assume it is invalid if not present
@@ -138,7 +137,7 @@ impl DeviceRegistration {
138137
}
139138

140139
// delete the data if its expired or invalid
141-
if let Err(err) = secret_store.delete(Self::SECRET_KEY).await {
140+
if let Err(err) = database.delete_secret(Self::SECRET_KEY).await {
142141
error!(?err, "Failed to delete device registration from keychain");
143142
}
144143

@@ -152,7 +151,7 @@ impl DeviceRegistration {
152151
client: &Client,
153152
region: &Region,
154153
) -> Result<Self, AuthError> {
155-
match Self::load_from_secret_store(&database.secret_store, region).await {
154+
match Self::load_from_secret_store(database, region).await {
156155
Ok(Some(registration)) if registration.oauth_flow == OAuthFlow::DeviceCode => match &registration.scopes {
157156
Some(scopes) if is_scopes(scopes) => return Ok(registration),
158157
_ => warn!("Invalid scopes in device registration, ignoring"),
@@ -181,17 +180,17 @@ impl DeviceRegistration {
181180
SCOPES.iter().map(|s| (*s).to_owned()).collect(),
182181
);
183182

184-
if let Err(err) = device_registration.save(&database.secret_store).await {
183+
if let Err(err) = device_registration.save(database).await {
185184
error!(?err, "Failed to write device registration to keychain");
186185
}
187186

188187
Ok(device_registration)
189188
}
190189

191190
/// Saves to the passed secret store.
192-
pub async fn save(&self, secret_store: &SecretStore) -> Result<(), AuthError> {
191+
pub async fn save(&self, secret_store: &Database) -> Result<(), AuthError> {
193192
secret_store
194-
.set(Self::SECRET_KEY, &serde_json::to_string(&self)?)
193+
.set_secret(Self::SECRET_KEY, &serde_json::to_string(&self)?)
195194
.await?;
196195
Ok(())
197196
}
@@ -285,8 +284,8 @@ impl BuilderIdToken {
285284
}
286285

287286
/// Load the token from the keychain, refresh the token if it is expired and return it
288-
pub async fn load(database: &mut Database) -> Result<Option<Self>, AuthError> {
289-
match database.secret_store.get(Self::SECRET_KEY).await {
287+
pub async fn load(database: &Database) -> Result<Option<Self>, AuthError> {
288+
match database.get_secret(Self::SECRET_KEY).await {
290289
Ok(Some(secret)) => {
291290
let token: Option<Self> = serde_json::from_str(&secret.0)?;
292291
match token {
@@ -296,7 +295,7 @@ impl BuilderIdToken {
296295
let client = client(region.clone());
297296
// if token is expired try to refresh
298297
if token.is_expired() {
299-
token.refresh_token(&client, &database.secret_store, &region).await
298+
token.refresh_token(&client, database, &region).await
300299
} else {
301300
Ok(Some(token))
302301
}
@@ -316,19 +315,19 @@ impl BuilderIdToken {
316315
pub async fn refresh_token(
317316
&self,
318317
client: &Client,
319-
secret_store: &SecretStore,
318+
database: &Database,
320319
region: &Region,
321320
) -> Result<Option<Self>, AuthError> {
322321
let Some(refresh_token) = &self.refresh_token else {
323322
// if the token is expired and has no refresh token, delete it
324-
if let Err(err) = self.delete(secret_store).await {
323+
if let Err(err) = self.delete(database).await {
325324
error!(?err, "Failed to delete builder id token");
326325
}
327326

328327
return Ok(None);
329328
};
330329

331-
let registration = match DeviceRegistration::load_from_secret_store(secret_store, region).await? {
330+
let registration = match DeviceRegistration::load_from_secret_store(database, region).await? {
332331
Some(registration) if registration.oauth_flow == self.oauth_flow => registration,
333332
// If the OIDC client registration is for a different oauth flow or doesn't exist, then
334333
// we can't refresh the token.
@@ -365,7 +364,7 @@ impl BuilderIdToken {
365364
);
366365
debug!("Refreshed access token, new token: {:?}", token);
367366

368-
if let Err(err) = token.save(secret_store).await {
367+
if let Err(err) = token.save(database).await {
369368
error!(?err, "Failed to store builder id access token");
370369
};
371370

@@ -378,7 +377,7 @@ impl BuilderIdToken {
378377
// if the error is the client's fault, clear the token
379378
if let SdkError::ServiceError(service_err) = &err {
380379
if !service_err.err().is_slow_down_exception() {
381-
if let Err(err) = self.delete(secret_store).await {
380+
if let Err(err) = self.delete(database).await {
382381
error!(?err, "Failed to delete builder id token");
383382
}
384383
}
@@ -398,16 +397,16 @@ impl BuilderIdToken {
398397
}
399398

400399
/// Save the token to the keychain
401-
pub async fn save(&self, secret_store: &SecretStore) -> Result<(), AuthError> {
402-
secret_store
403-
.set(Self::SECRET_KEY, &serde_json::to_string(self)?)
400+
pub async fn save(&self, database: &Database) -> Result<(), AuthError> {
401+
database
402+
.set_secret(Self::SECRET_KEY, &serde_json::to_string(self)?)
404403
.await?;
405404
Ok(())
406405
}
407406

408407
/// Delete the token from the keychain
409-
pub async fn delete(&self, secret_store: &SecretStore) -> Result<(), AuthError> {
410-
secret_store.delete(Self::SECRET_KEY).await?;
408+
pub async fn delete(&self, database: &Database) -> Result<(), AuthError> {
409+
database.delete_secret(Self::SECRET_KEY).await?;
411410
Ok(())
412411
}
413412

@@ -479,7 +478,7 @@ pub async fn poll_create_token(
479478
let token: BuilderIdToken =
480479
BuilderIdToken::from_output(output, region, start_url, OAuthFlow::DeviceCode, scopes);
481480

482-
if let Err(err) = token.save(&database.secret_store).await {
481+
if let Err(err) = token.save(database).await {
483482
error!(?err, "Failed to store builder id token");
484483
};
485484

@@ -500,13 +499,13 @@ pub async fn is_logged_in(database: &mut Database) -> bool {
500499
}
501500

502501
pub async fn logout(database: &mut Database) -> Result<(), AuthError> {
503-
let Ok(secret_store) = SecretStore::new().await else {
502+
let Ok(secret_store) = Database::new().await else {
504503
return Ok(());
505504
};
506505

507506
let (builder_res, device_res) = tokio::join!(
508-
secret_store.delete(BuilderIdToken::SECRET_KEY),
509-
secret_store.delete(DeviceRegistration::SECRET_KEY),
507+
secret_store.delete_secret(BuilderIdToken::SECRET_KEY),
508+
secret_store.delete_secret(DeviceRegistration::SECRET_KEY),
510509
);
511510

512511
let profile_res = database.unset_auth_profile();

crates/chat-cli/src/auth/pkce.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ impl PkceRegistration {
231231
/// then the access and refresh tokens will be saved.
232232
///
233233
/// Only the first connection will be served.
234-
pub async fn finish<C: PkceClient>(self, client: &C, database: Option<&Database>) -> Result<(), AuthError> {
234+
pub async fn finish<C: PkceClient>(self, client: &C, database: Option<&mut Database>) -> Result<(), AuthError> {
235235
let code = tokio::select! {
236236
code = Self::recv_code(self.listener, self.state) => {
237237
code?
@@ -270,11 +270,11 @@ impl PkceRegistration {
270270
);
271271

272272
if let Some(database) = database {
273-
if let Err(err) = device_registration.save(&database.secret_store).await {
273+
if let Err(err) = device_registration.save(database).await {
274274
error!(?err, "Failed to store pkce registration to secret store");
275275
}
276276

277-
if let Err(err) = token.save(&database.secret_store).await {
277+
if let Err(err) = token.save(database).await {
278278
error!(?err, "Failed to store builder id token");
279279
};
280280
}

crates/chat-cli/src/cli/debug.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
use clap::{
2-
Subcommand,
3-
ValueEnum,
4-
};
1+
use clap::ValueEnum;
52

63
#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)]
74
pub enum Build {
@@ -65,7 +62,7 @@ pub enum TISAction {
6562
use std::path::PathBuf;
6663

6764
#[cfg(target_os = "macos")]
68-
#[derive(Debug, Subcommand, Clone, PartialEq, Eq)]
65+
#[derive(Debug, clap::Subcommand, Clone, PartialEq, Eq)]
6966
pub enum InputMethodDebugAction {
7067
Install {
7168
bundle_path: Option<PathBuf>,

crates/chat-cli/src/database/mod.rs

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
pub mod secret_store;
21
pub mod settings;
32

43
use std::ops::Deref;
@@ -17,7 +16,6 @@ use rusqlite::{
1716
ToSql,
1817
params,
1918
};
20-
use secret_store::SecretStore;
2119
use serde::de::DeserializeOwned;
2220
use serde::{
2321
Deserialize,
@@ -94,6 +92,25 @@ impl From<amzn_codewhisperer_client::types::Profile> for AuthProfile {
9492
}
9593
}
9694

95+
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
96+
#[serde(transparent)]
97+
pub struct Secret(pub String);
98+
99+
impl std::fmt::Debug for Secret {
100+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101+
f.debug_struct("Secret").finish()
102+
}
103+
}
104+
105+
impl<T> From<T> for Secret
106+
where
107+
T: Into<String>,
108+
{
109+
fn from(value: T) -> Self {
110+
Self(value.into())
111+
}
112+
}
113+
97114
// A cloneable error
98115
#[derive(Debug, Clone, thiserror::Error)]
99116
#[error("Failed to open database: {}", .0)]
@@ -117,9 +134,6 @@ pub enum DatabaseError {
117134
DbOpenError(#[from] DbOpenError),
118135
#[error("{}", .0)]
119136
PoisonError(String),
120-
#[cfg(target_os = "macos")]
121-
#[error("Security error: {}", .0)]
122-
Security(String),
123137
#[error(transparent)]
124138
StringFromUtf8(#[from] std::string::FromUtf8Error),
125139
#[error(transparent)]
@@ -140,8 +154,7 @@ pub enum Table {
140154
State,
141155
/// The conversations tables contains user chat conversations.
142156
Conversations,
143-
#[cfg(not(target_os = "macos"))]
144-
/// The auth table contains
157+
/// The auth table contains SSO and Builder ID credentials.
145158
Auth,
146159
}
147160

@@ -150,7 +163,6 @@ impl std::fmt::Display for Table {
150163
match self {
151164
Table::State => write!(f, "state"),
152165
Table::Conversations => write!(f, "conversations"),
153-
#[cfg(not(target_os = "macos"))]
154166
Table::Auth => write!(f, "auth_kv"),
155167
}
156168
}
@@ -166,7 +178,6 @@ struct Migration {
166178
pub struct Database {
167179
pool: Pool<SqliteConnectionManager>,
168180
pub settings: Settings,
169-
pub secret_store: SecretStore,
170181
}
171182

172183
impl Database {
@@ -176,7 +187,6 @@ impl Database {
176187
return Self {
177188
pool: Pool::builder().build(SqliteConnectionManager::memory()).unwrap(),
178189
settings: Settings::new().await?,
179-
secret_store: SecretStore::new().await?,
180190
}
181191
.migrate();
182192
},
@@ -209,7 +219,6 @@ impl Database {
209219
Ok(Self {
210220
pool,
211221
settings: Settings::new().await?,
212-
secret_store: SecretStore::new().await?,
213222
}
214223
.migrate()
215224
.map_err(|e| DbOpenError(e.to_string()))?)
@@ -419,6 +428,19 @@ impl Database {
419428

420429
Ok(map)
421430
}
431+
432+
pub async fn get_secret(&self, key: &str) -> Result<Option<Secret>, DatabaseError> {
433+
Ok(self.get_entry::<String>(Table::Auth, key)?.map(Into::into))
434+
}
435+
436+
pub async fn set_secret(&self, key: &str, value: &str) -> Result<(), DatabaseError> {
437+
self.set_entry(Table::Auth, key, value)?;
438+
Ok(())
439+
}
440+
441+
pub async fn delete_secret(&self, key: &str) -> Result<(), DatabaseError> {
442+
self.delete_entry(Table::Auth, key)
443+
}
422444
}
423445

424446
fn max_migration<C: Deref<Target = Connection>>(conn: &C) -> Option<i64> {
@@ -502,4 +524,43 @@ mod tests {
502524
assert!(db.get_entry::<f32>(Table::State, "float").unwrap().is_some());
503525
assert!(db.get_entry::<bool>(Table::State, "bool").unwrap().is_some());
504526
}
527+
528+
#[tokio::test]
529+
#[ignore = "not on ci"]
530+
async fn test_set_password() {
531+
let key = "test_set_password";
532+
let store = Database::new().await.unwrap();
533+
store.set_secret(key, "test").await.unwrap();
534+
assert_eq!(store.get_secret(key).await.unwrap().unwrap().0, "test");
535+
store.delete_secret(key).await.unwrap();
536+
}
537+
538+
#[tokio::test]
539+
#[ignore = "not on ci"]
540+
async fn secret_get_time() {
541+
let key = "test_secret_get_time";
542+
let store = Database::new().await.unwrap();
543+
store.set_secret(key, "1234").await.unwrap();
544+
545+
let now = std::time::Instant::now();
546+
for _ in 0..100 {
547+
store.get_secret(key).await.unwrap();
548+
}
549+
550+
println!("duration: {:?}", now.elapsed() / 100);
551+
552+
store.delete_secret(key).await.unwrap();
553+
}
554+
555+
#[tokio::test]
556+
#[ignore = "not on ci"]
557+
async fn secret_delete() {
558+
let key = "test_secret_delete";
559+
560+
let store = Database::new().await.unwrap();
561+
store.set_secret(key, "1234").await.unwrap();
562+
assert_eq!(store.get_secret(key).await.unwrap().unwrap().0, "1234");
563+
store.delete_secret(key).await.unwrap();
564+
assert_eq!(store.get_secret(key).await.unwrap(), None);
565+
}
505566
}

0 commit comments

Comments
 (0)