mirror of
https://github.com/tursodatabase/libsql.git
synced 2025-07-27 09:34:50 +00:00
Add Connection::transactional_batch (#1366)
* Add Connection::execute_transactional_batch This commit contains only plumbing. There are 3 implementations that need to be provided and they are currently implemented as a `todo!()`. Next commits will fill in those missing implementations. Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com> * Implement execute_transactional_batch for local connection Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com> * Implement execute_transactional_batch for HRANA connection Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com> * Implement execute_transactional_batch for GRPC connection Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com> --------- Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>
This commit is contained in:
committed by
GitHub
parent
372311a008
commit
41e17aa253
@@ -292,6 +292,33 @@ impl Batch {
|
||||
replication_index: None,
|
||||
}
|
||||
}
|
||||
pub fn transactional<T: IntoIterator<Item = Stmt>>(stmts: T) -> Self {
|
||||
let mut steps = Vec::new();
|
||||
steps.push(BatchStep {
|
||||
condition: None,
|
||||
stmt: Stmt::new("BEGIN TRANSACTION", false),
|
||||
});
|
||||
let mut count = 0u32;
|
||||
for (step, stmt) in stmts.into_iter().enumerate() {
|
||||
count += 1;
|
||||
let condition = Some(BatchCond::Ok { step: step as u32 });
|
||||
steps.push(BatchStep { condition, stmt });
|
||||
}
|
||||
steps.push(BatchStep {
|
||||
condition: Some(BatchCond::Ok { step: count }),
|
||||
stmt: Stmt::new("COMMIT", false),
|
||||
});
|
||||
steps.push(BatchStep {
|
||||
condition: Some(BatchCond::Not {
|
||||
cond: Box::new(BatchCond::Ok { step: count + 1 }),
|
||||
}),
|
||||
stmt: Stmt::new("ROLLBACK", false),
|
||||
});
|
||||
Batch {
|
||||
steps,
|
||||
replication_index: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<Stmt> for Batch {
|
||||
|
@@ -12,6 +12,8 @@ pub(crate) trait Conn {
|
||||
|
||||
async fn execute_batch(&self, sql: &str) -> Result<()>;
|
||||
|
||||
async fn execute_transactional_batch(&self, sql: &str) -> Result<()>;
|
||||
|
||||
async fn prepare(&self, sql: &str) -> Result<Statement>;
|
||||
|
||||
async fn transaction(&self, tx_behavior: TransactionBehavior) -> Result<Transaction>;
|
||||
@@ -57,6 +59,12 @@ impl Connection {
|
||||
self.conn.execute_batch(sql).await
|
||||
}
|
||||
|
||||
/// Execute a batch set of statements atomically in a transaction.
|
||||
pub async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
|
||||
tracing::trace!("executing batch transactional `{}`", sql);
|
||||
self.conn.execute_transactional_batch(sql).await
|
||||
}
|
||||
|
||||
/// Execute sql query provided some type that implements [`IntoParams`] returning
|
||||
/// on success the [`Rows`].
|
||||
///
|
||||
|
@@ -48,6 +48,8 @@ pub enum Error {
|
||||
InvalidParserState(String),
|
||||
#[error("TLS error: {0}")]
|
||||
InvalidTlsConfiguration(std::io::Error),
|
||||
#[error("Transactional batch error: {0}")]
|
||||
TransactionalBatchError(String),
|
||||
}
|
||||
|
||||
#[cfg(feature = "hrana")]
|
||||
|
@@ -7,7 +7,7 @@ use crate::hrana::{bind_params, unwrap_err, HranaError, HttpSend, Result};
|
||||
use crate::params::Params;
|
||||
use crate::transaction::Tx;
|
||||
use crate::util::ConnectorService;
|
||||
use crate::{Rows, Statement};
|
||||
use crate::{Error, Rows, Statement};
|
||||
use bytes::Bytes;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::{Stream, TryStreamExt};
|
||||
@@ -121,6 +121,10 @@ impl Conn for HttpConnection<HttpSender> {
|
||||
self.current_stream().execute_batch(sql).await
|
||||
}
|
||||
|
||||
async fn execute_transactional_batch(&self, sql: &str) -> crate::Result<()> {
|
||||
self.current_stream().execute_transactional_batch(sql).await
|
||||
}
|
||||
|
||||
async fn prepare(&self, sql: &str) -> crate::Result<Statement> {
|
||||
let stream = self.current_stream().clone();
|
||||
let stmt = crate::hrana::Statement::new(stream, sql.to_string(), true)?;
|
||||
@@ -273,6 +277,23 @@ impl Conn for HranaStream<HttpSender> {
|
||||
unwrap_err(res)
|
||||
}
|
||||
|
||||
async fn execute_transactional_batch(&self, sql: &str) -> crate::Result<()> {
|
||||
let mut stmts = Vec::new();
|
||||
let parse = crate::parser::Statement::parse(sql);
|
||||
for s in parse {
|
||||
let s = s?;
|
||||
if s.kind == crate::parser::StmtKind::TxnBegin || s.kind == crate::parser::StmtKind::TxnBeginReadOnly || s.kind == crate::parser::StmtKind::TxnEnd {
|
||||
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
|
||||
}
|
||||
stmts.push(Stmt::new(s.stmt, false));
|
||||
}
|
||||
let res = self
|
||||
.batch_inner(Batch::transactional(stmts), true)
|
||||
.await
|
||||
.map_err(|e| crate::Error::Hrana(e.into()))?;
|
||||
unwrap_err(res)
|
||||
}
|
||||
|
||||
async fn prepare(&self, sql: &str) -> crate::Result<Statement> {
|
||||
let stmt = crate::hrana::Statement::new(self.clone(), sql.to_string(), true)?;
|
||||
Ok(Statement {
|
||||
|
@@ -161,6 +161,62 @@ impl Connection {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn execute_transactional_batch_inner<S>(&self, sql: S) -> Result<()>
|
||||
where
|
||||
S: Into<String>,
|
||||
{
|
||||
let sql = sql.into();
|
||||
let mut sql = sql.as_str();
|
||||
while !sql.is_empty() {
|
||||
let stmt = self.prepare(sql)?;
|
||||
|
||||
let tail = stmt.tail();
|
||||
let stmt_sql = if tail == 0 || tail >= sql.len() {
|
||||
sql
|
||||
} else {
|
||||
&sql[..tail]
|
||||
};
|
||||
let prefix_count = stmt_sql
|
||||
.chars()
|
||||
.take_while(|c| c.is_whitespace())
|
||||
.count();
|
||||
let stmt_sql = &stmt_sql[prefix_count..];
|
||||
if stmt_sql.starts_with("BEGIN") || stmt_sql.starts_with("COMMIT") || stmt_sql.starts_with("ROLLBACK") || stmt_sql.starts_with("END") {
|
||||
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
|
||||
}
|
||||
|
||||
if !stmt.inner.raw_stmt.is_null() {
|
||||
stmt.step()?;
|
||||
}
|
||||
|
||||
if tail == 0 || tail >= sql.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
sql = &sql[tail..];
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn execute_transactional_batch<S>(&self, sql: S) -> Result<()>
|
||||
where
|
||||
S: Into<String>,
|
||||
{
|
||||
self.execute("BEGIN TRANSACTION", Params::None)?;
|
||||
|
||||
match self.execute_transactional_batch_inner(sql) {
|
||||
Ok(_) => {
|
||||
self.execute("COMMIT", Params::None)?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
self.execute("ROLLBACK", Params::None)?;
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the SQL statement synchronously.
|
||||
///
|
||||
/// If you execute a SQL query statement (e.g. `SELECT` statement) that
|
||||
|
@@ -26,6 +26,10 @@ impl Conn for LibsqlConnection {
|
||||
self.conn.execute_batch(sql)
|
||||
}
|
||||
|
||||
async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
|
||||
self.conn.execute_transactional_batch(sql)
|
||||
}
|
||||
|
||||
async fn prepare(&self, sql: &str) -> Result<Statement> {
|
||||
let sql = sql.to_string();
|
||||
|
||||
|
@@ -3,10 +3,7 @@
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use libsql_replication::rpc::proxy::{
|
||||
describe_result, query_result::RowResult, DescribeResult, ExecuteResults, ResultRows,
|
||||
State as RemoteState,
|
||||
};
|
||||
use libsql_replication::rpc::proxy::{describe_result, query_result::RowResult, DescribeResult, ExecuteResults, ResultRows, State as RemoteState, Step, Query, Cond, OkCond, NotCond, Positional};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::parser;
|
||||
@@ -207,6 +204,34 @@ impl RemoteConnection {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub(self) async fn execute_steps_remote(
|
||||
&self,
|
||||
steps: Vec<Step>,
|
||||
) -> Result<ExecuteResults> {
|
||||
let Some(ref writer) = self.writer else {
|
||||
return Err(Error::Misuse(
|
||||
"Cannot delegate write in local replica mode.".into(),
|
||||
));
|
||||
};
|
||||
let res = writer
|
||||
.execute_steps(steps)
|
||||
.await
|
||||
.map_err(|e| Error::WriteDelegation(e.into()))?;
|
||||
|
||||
{
|
||||
let mut inner = self.inner.lock();
|
||||
inner.state = RemoteState::try_from(res.state)
|
||||
.expect("Invalid state enum")
|
||||
.into();
|
||||
}
|
||||
|
||||
if let Some(replicator) = writer.replicator() {
|
||||
replicator.sync_oneshot().await?;
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub(self) async fn describe(&self, stmt: impl Into<String>) -> Result<DescribeResult> {
|
||||
let Some(ref writer) = self.writer else {
|
||||
return Err(Error::Misuse(
|
||||
@@ -321,6 +346,108 @@ impl Conn for RemoteConnection {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
|
||||
let mut stmts = Vec::new();
|
||||
let parse = crate::parser::Statement::parse(sql);
|
||||
for s in parse {
|
||||
let s = s?;
|
||||
if s.kind == StmtKind::TxnBegin || s.kind == StmtKind::TxnBeginReadOnly || s.kind == StmtKind::TxnEnd {
|
||||
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
|
||||
}
|
||||
stmts.push(s);
|
||||
}
|
||||
|
||||
if self.should_execute_local(&stmts[..])? {
|
||||
self.local.execute_transactional_batch(sql).await?;
|
||||
|
||||
if !self.maybe_execute_rollback().await? {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let mut steps = Vec::with_capacity(stmts.len() + 3);
|
||||
steps.push(Step {
|
||||
query: Some(Query {
|
||||
stmt: "BEGIN TRANSACTION".to_string(),
|
||||
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
});
|
||||
let count = stmts.len() as i64;
|
||||
for (idx, stmt) in stmts.into_iter().enumerate() {
|
||||
let step = Step {
|
||||
cond: Some(Cond {
|
||||
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
|
||||
step: idx as i64,
|
||||
..Default::default()
|
||||
})),
|
||||
}),
|
||||
query: Some(Query {
|
||||
stmt: stmt.stmt,
|
||||
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
steps.push(step);
|
||||
}
|
||||
steps.push(Step {
|
||||
cond: Some(Cond {
|
||||
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
|
||||
step: count,
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
}),
|
||||
query: Some(Query {
|
||||
stmt: "COMMIT".to_string(),
|
||||
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
});
|
||||
steps.push(Step {
|
||||
cond: Some(Cond {
|
||||
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Not(Box::new(NotCond {
|
||||
cond: Some(Box::new(Cond{
|
||||
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
|
||||
step: count + 1,
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
}))),
|
||||
..Default::default()
|
||||
}),
|
||||
query: Some(Query {
|
||||
stmt: "ROLLBACK".to_string(),
|
||||
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
let res = self.execute_steps_remote(steps).await?;
|
||||
|
||||
for result in res.results {
|
||||
match result.row_result {
|
||||
Some(RowResult::Row(row)) => self.update_state(&row),
|
||||
Some(RowResult::Error(e)) => {
|
||||
return Err(Error::RemoteSqliteFailure(
|
||||
e.code,
|
||||
e.extended_code,
|
||||
e.message,
|
||||
))
|
||||
}
|
||||
None => panic!("unexpected empty result row"),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn prepare(&self, sql: &str) -> Result<Statement> {
|
||||
let stmt = RemoteStatement::prepare(self.clone(), sql).await?;
|
||||
|
||||
|
@@ -72,6 +72,13 @@ impl Writer {
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.execute_steps(steps).await
|
||||
}
|
||||
|
||||
pub(crate) async fn execute_steps(
|
||||
&self,
|
||||
steps: Vec<Step>,
|
||||
) -> anyhow::Result<ExecuteResults> {
|
||||
self.client
|
||||
.execute_program(ProgramReq {
|
||||
client_id: self.client.client_id(),
|
||||
|
@@ -51,6 +51,145 @@ async fn connection_query() {
|
||||
assert_eq!(row.get::<String>(1).unwrap(), "Alice");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_execute_transactional_batch_success() {
|
||||
let conn = setup().await;
|
||||
|
||||
conn.execute_transactional_batch(
|
||||
"CREATE TABLE foo(x INTEGER);
|
||||
CREATE TABLE bar(y TEXT);",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut rows = conn
|
||||
.query(
|
||||
"SELECT
|
||||
name
|
||||
FROM
|
||||
sqlite_schema
|
||||
WHERE
|
||||
type ='table' AND
|
||||
name NOT LIKE 'sqlite_%';",
|
||||
(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let row = rows.next().await.unwrap().unwrap();
|
||||
assert_eq!(row.get::<String>(0).unwrap(), "users");
|
||||
|
||||
let row = rows.next().await.unwrap().unwrap();
|
||||
assert_eq!(row.get::<String>(0).unwrap(), "foo");
|
||||
|
||||
let row = rows.next().await.unwrap().unwrap();
|
||||
assert_eq!(row.get::<String>(0).unwrap(), "bar");
|
||||
|
||||
assert!(rows.next().await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_execute_transactional_batch_fail() {
|
||||
let conn = setup().await;
|
||||
|
||||
let res = conn
|
||||
.execute_transactional_batch(
|
||||
"CREATE TABLE unexpected_foo(x INTEGER);
|
||||
CREATE TABLE sqlite_schema(y TEXT);
|
||||
CREATE TABLE unexpected_bar(y TEXT);",
|
||||
)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
|
||||
let mut rows = conn
|
||||
.query(
|
||||
"SELECT
|
||||
name
|
||||
FROM
|
||||
sqlite_schema
|
||||
WHERE
|
||||
type ='table' AND
|
||||
name NOT LIKE 'sqlite_%';",
|
||||
(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let row = rows.next().await.unwrap().unwrap();
|
||||
assert_eq!(row.get::<String>(0).unwrap(), "users");
|
||||
|
||||
assert!(rows.next().await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_execute_transactional_batch_transaction_fail() {
|
||||
let conn = setup().await;
|
||||
|
||||
let res = conn
|
||||
.execute_transactional_batch(
|
||||
"BEGIN;
|
||||
CREATE TABLE unexpected_foo(x INTEGER);
|
||||
COMMIT;
|
||||
CREATE TABLE sqlite_schema(y TEXT);
|
||||
CREATE TABLE unexpected_bar(y TEXT);",
|
||||
)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
|
||||
let mut rows = conn
|
||||
.query(
|
||||
"SELECT
|
||||
name
|
||||
FROM
|
||||
sqlite_schema
|
||||
WHERE
|
||||
type ='table' AND
|
||||
name NOT LIKE 'sqlite_%';",
|
||||
(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let row = rows.next().await.unwrap().unwrap();
|
||||
assert_eq!(row.get::<String>(0).unwrap(), "users");
|
||||
|
||||
assert!(rows.next().await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_execute_transactional_batch_transaction_incorrect() {
|
||||
let conn = setup().await;
|
||||
|
||||
let res = conn
|
||||
.execute_transactional_batch(
|
||||
"COMMIT;
|
||||
CREATE TABLE unexpected_foo(x INTEGER);
|
||||
CREATE TABLE sqlite_schema(y TEXT);
|
||||
CREATE TABLE unexpected_bar(y TEXT);",
|
||||
)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
|
||||
let mut rows = conn
|
||||
.query(
|
||||
"SELECT
|
||||
name
|
||||
FROM
|
||||
sqlite_schema
|
||||
WHERE
|
||||
type ='table' AND
|
||||
name NOT LIKE 'sqlite_%';",
|
||||
(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let row = rows.next().await.unwrap().unwrap();
|
||||
assert_eq!(row.get::<String>(0).unwrap(), "users");
|
||||
|
||||
assert!(rows.next().await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_execute_batch() {
|
||||
let conn = setup().await;
|
||||
|
Reference in New Issue
Block a user