Skip to content
4 changes: 4 additions & 0 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ pub trait Connection: DynClone + Send + Sync {
))
}

async fn begin(&self) -> Result<()>;
async fn commit(&self) -> Result<()>;
async fn rollback(&self) -> Result<()>;

async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
let mut total_count: usize = 0;
let mut total_size: usize = 0;
Expand Down
17 changes: 16 additions & 1 deletion driver/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ impl Connection for FlightSQLConnection {
"STREAM LOAD unavailable for FlightSQL".to_string(),
))
}

async fn begin(&self) -> Result<()> {
self.exec("BEGIN").await.unwrap();
Ok(())
}

async fn commit(&self) -> Result<()> {
self.exec("COMMIT").await.unwrap();
Ok(())
}

async fn rollback(&self) -> Result<()> {
self.exec("ROLLBACK").await.unwrap();
Ok(())
}
}

impl FlightSQLConnection {
Expand Down Expand Up @@ -273,7 +288,7 @@ impl Args {
return Err(Error::BadArgument(format!(
"Invalid value for sslmode: {}",
v.as_ref()
)))
)));
}
},
"tls_ca_file" => args.tls_ca_file = Some(v.to_string()),
Expand Down
15 changes: 15 additions & 0 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,21 @@ impl Connection for RestAPIConnection {
let stats = self.load_data(sql, reader, size, None, None).await?;
Ok(stats)
}

async fn begin(&self) -> Result<()> {
self.exec("BEGIN").await.unwrap();
Ok(())
}

async fn commit(&self) -> Result<()> {
self.exec("COMMIT").await.unwrap();
Ok(())
}

async fn rollback(&self) -> Result<()> {
self.exec("ROLLBACK").await.unwrap();
Ok(())
}
}

impl<'o> RestAPIConnection {
Expand Down
1 change: 1 addition & 0 deletions driver/tests/driver/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ mod load;
mod select_iter;
mod select_simple;
mod session;
mod transaction;
57 changes: 57 additions & 0 deletions driver/tests/driver/transaction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use databend_driver::Client;

use crate::common::DEFAULT_DSN;

#[tokio::test]
async fn test_commit() {
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
let client = Client::new(dsn.to_string());
let conn = client.get_conn().await.unwrap();

conn.exec("CREATE OR REPLACE TABLE t(c int);")
.await
.unwrap();
conn.begin().await.unwrap();
conn.exec("INSERT INTO t VALUES(1);").await.unwrap();
let row = conn.query_row("SELECT * FROM t").await.unwrap();
let row = row.unwrap();
let (val,): (i32,) = row.try_into().unwrap();
assert_eq!(val, 1);
conn.commit().await.unwrap();
let row = conn.query_row("select 1").await.unwrap();
let row = row.unwrap();
println!("{:?}", row);
}

#[tokio::test]
async fn test_rollback() {
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
let client = Client::new(dsn.to_string());
let conn = client.get_conn().await.unwrap();

conn.exec("CREATE OR REPLACE TABLE t(c int);")
.await
.unwrap();
conn.begin().await.unwrap();
conn.exec("INSERT INTO t VALUES(1);").await.unwrap();
let row = conn.query_row("SELECT * FROM t").await.unwrap();
let row = row.unwrap();
let (val,): (i32,) = row.try_into().unwrap();
assert_eq!(val, 1);

conn.rollback().await.unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should recheck after rollback?

}
3 changes: 3 additions & 0 deletions sql/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ impl Value {
impl TryFrom<(&DataType, &str)> for Value {
type Error = Error;

#[allow(deprecated)]
fn try_from((t, v): (&DataType, &str)) -> Result<Self> {
match t {
DataType::Null => Ok(Self::Null),
Expand Down Expand Up @@ -520,6 +521,7 @@ impl_try_from_number_value!(f64);

impl TryFrom<Value> for NaiveDateTime {
type Error = Error;
#[allow(deprecated)]
fn try_from(val: Value) -> Result<Self> {
match val {
Value::Timestamp(i) => {
Expand Down Expand Up @@ -615,6 +617,7 @@ impl std::fmt::Display for Value {
}

// Compatible with Databend, inner values of nested types are quoted.
#[allow(deprecated)]
fn encode_value(f: &mut std::fmt::Formatter<'_>, val: &Value, raw: bool) -> std::fmt::Result {
match val {
Value::Null => write!(f, "NULL"),
Expand Down