0
0
mirror of https://github.com/tursodatabase/libsql.git synced 2025-07-27 09:34:50 +00:00

Add Connection::transactional_batch (#1366)

* Add Connection::execute_transactional_batch

This commit contains only plumbing.
There are 3 implementations that need to be provided
and they are currently implemented as a `todo!()`.
Next commits will fill in those missing implementations.

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>

* Implement execute_transactional_batch for local connection

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>

* Implement execute_transactional_batch for HRANA connection

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>

* Implement execute_transactional_batch for GRPC connection

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>

---------

Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>
This commit is contained in:
Piotr Jastrzębski
2024-05-05 11:18:22 +02:00
committed by GitHub
parent 372311a008
commit 41e17aa253
9 changed files with 396 additions and 5 deletions

View File

@@ -292,6 +292,33 @@ impl Batch {
replication_index: None,
}
}
pub fn transactional<T: IntoIterator<Item = Stmt>>(stmts: T) -> Self {
let mut steps = Vec::new();
steps.push(BatchStep {
condition: None,
stmt: Stmt::new("BEGIN TRANSACTION", false),
});
let mut count = 0u32;
for (step, stmt) in stmts.into_iter().enumerate() {
count += 1;
let condition = Some(BatchCond::Ok { step: step as u32 });
steps.push(BatchStep { condition, stmt });
}
steps.push(BatchStep {
condition: Some(BatchCond::Ok { step: count }),
stmt: Stmt::new("COMMIT", false),
});
steps.push(BatchStep {
condition: Some(BatchCond::Not {
cond: Box::new(BatchCond::Ok { step: count + 1 }),
}),
stmt: Stmt::new("ROLLBACK", false),
});
Batch {
steps,
replication_index: None,
}
}
}
impl FromIterator<Stmt> for Batch {

View File

@@ -12,6 +12,8 @@ pub(crate) trait Conn {
async fn execute_batch(&self, sql: &str) -> Result<()>;
async fn execute_transactional_batch(&self, sql: &str) -> Result<()>;
async fn prepare(&self, sql: &str) -> Result<Statement>;
async fn transaction(&self, tx_behavior: TransactionBehavior) -> Result<Transaction>;
@@ -57,6 +59,12 @@ impl Connection {
self.conn.execute_batch(sql).await
}
/// Execute a batch set of statements atomically in a transaction.
pub async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
tracing::trace!("executing batch transactional `{}`", sql);
self.conn.execute_transactional_batch(sql).await
}
/// Execute sql query provided some type that implements [`IntoParams`] returning
/// on success the [`Rows`].
///

View File

@@ -48,6 +48,8 @@ pub enum Error {
InvalidParserState(String),
#[error("TLS error: {0}")]
InvalidTlsConfiguration(std::io::Error),
#[error("Transactional batch error: {0}")]
TransactionalBatchError(String),
}
#[cfg(feature = "hrana")]

View File

@@ -7,7 +7,7 @@ use crate::hrana::{bind_params, unwrap_err, HranaError, HttpSend, Result};
use crate::params::Params;
use crate::transaction::Tx;
use crate::util::ConnectorService;
use crate::{Rows, Statement};
use crate::{Error, Rows, Statement};
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::{Stream, TryStreamExt};
@@ -121,6 +121,10 @@ impl Conn for HttpConnection<HttpSender> {
self.current_stream().execute_batch(sql).await
}
async fn execute_transactional_batch(&self, sql: &str) -> crate::Result<()> {
self.current_stream().execute_transactional_batch(sql).await
}
async fn prepare(&self, sql: &str) -> crate::Result<Statement> {
let stream = self.current_stream().clone();
let stmt = crate::hrana::Statement::new(stream, sql.to_string(), true)?;
@@ -273,6 +277,23 @@ impl Conn for HranaStream<HttpSender> {
unwrap_err(res)
}
async fn execute_transactional_batch(&self, sql: &str) -> crate::Result<()> {
let mut stmts = Vec::new();
let parse = crate::parser::Statement::parse(sql);
for s in parse {
let s = s?;
if s.kind == crate::parser::StmtKind::TxnBegin || s.kind == crate::parser::StmtKind::TxnBeginReadOnly || s.kind == crate::parser::StmtKind::TxnEnd {
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
}
stmts.push(Stmt::new(s.stmt, false));
}
let res = self
.batch_inner(Batch::transactional(stmts), true)
.await
.map_err(|e| crate::Error::Hrana(e.into()))?;
unwrap_err(res)
}
async fn prepare(&self, sql: &str) -> crate::Result<Statement> {
let stmt = crate::hrana::Statement::new(self.clone(), sql.to_string(), true)?;
Ok(Statement {

View File

@@ -161,6 +161,62 @@ impl Connection {
Ok(())
}
fn execute_transactional_batch_inner<S>(&self, sql: S) -> Result<()>
where
S: Into<String>,
{
let sql = sql.into();
let mut sql = sql.as_str();
while !sql.is_empty() {
let stmt = self.prepare(sql)?;
let tail = stmt.tail();
let stmt_sql = if tail == 0 || tail >= sql.len() {
sql
} else {
&sql[..tail]
};
let prefix_count = stmt_sql
.chars()
.take_while(|c| c.is_whitespace())
.count();
let stmt_sql = &stmt_sql[prefix_count..];
if stmt_sql.starts_with("BEGIN") || stmt_sql.starts_with("COMMIT") || stmt_sql.starts_with("ROLLBACK") || stmt_sql.starts_with("END") {
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
}
if !stmt.inner.raw_stmt.is_null() {
stmt.step()?;
}
if tail == 0 || tail >= sql.len() {
break;
}
sql = &sql[tail..];
}
Ok(())
}
pub fn execute_transactional_batch<S>(&self, sql: S) -> Result<()>
where
S: Into<String>,
{
self.execute("BEGIN TRANSACTION", Params::None)?;
match self.execute_transactional_batch_inner(sql) {
Ok(_) => {
self.execute("COMMIT", Params::None)?;
Ok(())
}
Err(e) => {
self.execute("ROLLBACK", Params::None)?;
Err(e)
}
}
}
/// Execute the SQL statement synchronously.
///
/// If you execute a SQL query statement (e.g. `SELECT` statement) that

View File

@@ -26,6 +26,10 @@ impl Conn for LibsqlConnection {
self.conn.execute_batch(sql)
}
async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
self.conn.execute_transactional_batch(sql)
}
async fn prepare(&self, sql: &str) -> Result<Statement> {
let sql = sql.to_string();

View File

@@ -3,10 +3,7 @@
use std::str::FromStr;
use std::sync::Arc;
use libsql_replication::rpc::proxy::{
describe_result, query_result::RowResult, DescribeResult, ExecuteResults, ResultRows,
State as RemoteState,
};
use libsql_replication::rpc::proxy::{describe_result, query_result::RowResult, DescribeResult, ExecuteResults, ResultRows, State as RemoteState, Step, Query, Cond, OkCond, NotCond, Positional};
use parking_lot::Mutex;
use crate::parser;
@@ -207,6 +204,34 @@ impl RemoteConnection {
Ok(res)
}
pub(self) async fn execute_steps_remote(
&self,
steps: Vec<Step>,
) -> Result<ExecuteResults> {
let Some(ref writer) = self.writer else {
return Err(Error::Misuse(
"Cannot delegate write in local replica mode.".into(),
));
};
let res = writer
.execute_steps(steps)
.await
.map_err(|e| Error::WriteDelegation(e.into()))?;
{
let mut inner = self.inner.lock();
inner.state = RemoteState::try_from(res.state)
.expect("Invalid state enum")
.into();
}
if let Some(replicator) = writer.replicator() {
replicator.sync_oneshot().await?;
}
Ok(res)
}
pub(self) async fn describe(&self, stmt: impl Into<String>) -> Result<DescribeResult> {
let Some(ref writer) = self.writer else {
return Err(Error::Misuse(
@@ -321,6 +346,108 @@ impl Conn for RemoteConnection {
Ok(())
}
async fn execute_transactional_batch(&self, sql: &str) -> Result<()> {
let mut stmts = Vec::new();
let parse = crate::parser::Statement::parse(sql);
for s in parse {
let s = s?;
if s.kind == StmtKind::TxnBegin || s.kind == StmtKind::TxnBeginReadOnly || s.kind == StmtKind::TxnEnd {
return Err(Error::TransactionalBatchError("Transactions forbidden inside transactional batch".to_string()));
}
stmts.push(s);
}
if self.should_execute_local(&stmts[..])? {
self.local.execute_transactional_batch(sql).await?;
if !self.maybe_execute_rollback().await? {
return Ok(());
}
}
let mut steps = Vec::with_capacity(stmts.len() + 3);
steps.push(Step {
query: Some(Query {
stmt: "BEGIN TRANSACTION".to_string(),
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
..Default::default()
}),
..Default::default()
});
let count = stmts.len() as i64;
for (idx, stmt) in stmts.into_iter().enumerate() {
let step = Step {
cond: Some(Cond {
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
step: idx as i64,
..Default::default()
})),
}),
query: Some(Query {
stmt: stmt.stmt,
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
..Default::default()
}),
..Default::default()
};
steps.push(step);
}
steps.push(Step {
cond: Some(Cond {
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
step: count,
..Default::default()
})),
..Default::default()
}),
query: Some(Query {
stmt: "COMMIT".to_string(),
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
..Default::default()
}),
..Default::default()
});
steps.push(Step {
cond: Some(Cond {
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Not(Box::new(NotCond {
cond: Some(Box::new(Cond{
cond: Some(libsql_replication::rpc::proxy::cond::Cond::Ok(OkCond {
step: count + 1,
..Default::default()
})),
..Default::default()
})),
..Default::default()
}))),
..Default::default()
}),
query: Some(Query {
stmt: "ROLLBACK".to_string(),
params: Some(libsql_replication::rpc::proxy::query::Params::Positional(Positional::default())),
..Default::default()
}),
..Default::default()
});
let res = self.execute_steps_remote(steps).await?;
for result in res.results {
match result.row_result {
Some(RowResult::Row(row)) => self.update_state(&row),
Some(RowResult::Error(e)) => {
return Err(Error::RemoteSqliteFailure(
e.code,
e.extended_code,
e.message,
))
}
None => panic!("unexpected empty result row"),
};
}
Ok(())
}
async fn prepare(&self, sql: &str) -> Result<Statement> {
let stmt = RemoteStatement::prepare(self.clone(), sql).await?;

View File

@@ -72,6 +72,13 @@ impl Writer {
})
.collect();
self.execute_steps(steps).await
}
pub(crate) async fn execute_steps(
&self,
steps: Vec<Step>,
) -> anyhow::Result<ExecuteResults> {
self.client
.execute_program(ProgramReq {
client_id: self.client.client_id(),

View File

@@ -51,6 +51,145 @@ async fn connection_query() {
assert_eq!(row.get::<String>(1).unwrap(), "Alice");
}
#[tokio::test]
async fn connection_execute_transactional_batch_success() {
let conn = setup().await;
conn.execute_transactional_batch(
"CREATE TABLE foo(x INTEGER);
CREATE TABLE bar(y TEXT);",
)
.await
.unwrap();
let mut rows = conn
.query(
"SELECT
name
FROM
sqlite_schema
WHERE
type ='table' AND
name NOT LIKE 'sqlite_%';",
(),
)
.await
.unwrap();
let row = rows.next().await.unwrap().unwrap();
assert_eq!(row.get::<String>(0).unwrap(), "users");
let row = rows.next().await.unwrap().unwrap();
assert_eq!(row.get::<String>(0).unwrap(), "foo");
let row = rows.next().await.unwrap().unwrap();
assert_eq!(row.get::<String>(0).unwrap(), "bar");
assert!(rows.next().await.unwrap().is_none());
}
#[tokio::test]
async fn connection_execute_transactional_batch_fail() {
let conn = setup().await;
let res = conn
.execute_transactional_batch(
"CREATE TABLE unexpected_foo(x INTEGER);
CREATE TABLE sqlite_schema(y TEXT);
CREATE TABLE unexpected_bar(y TEXT);",
)
.await;
assert!(res.is_err());
let mut rows = conn
.query(
"SELECT
name
FROM
sqlite_schema
WHERE
type ='table' AND
name NOT LIKE 'sqlite_%';",
(),
)
.await
.unwrap();
let row = rows.next().await.unwrap().unwrap();
assert_eq!(row.get::<String>(0).unwrap(), "users");
assert!(rows.next().await.unwrap().is_none());
}
#[tokio::test]
async fn connection_execute_transactional_batch_transaction_fail() {
let conn = setup().await;
let res = conn
.execute_transactional_batch(
"BEGIN;
CREATE TABLE unexpected_foo(x INTEGER);
COMMIT;
CREATE TABLE sqlite_schema(y TEXT);
CREATE TABLE unexpected_bar(y TEXT);",
)
.await;
assert!(res.is_err());
let mut rows = conn
.query(
"SELECT
name
FROM
sqlite_schema
WHERE
type ='table' AND
name NOT LIKE 'sqlite_%';",
(),
)
.await
.unwrap();
let row = rows.next().await.unwrap().unwrap();
assert_eq!(row.get::<String>(0).unwrap(), "users");
assert!(rows.next().await.unwrap().is_none());
}
#[tokio::test]
async fn connection_execute_transactional_batch_transaction_incorrect() {
let conn = setup().await;
let res = conn
.execute_transactional_batch(
"COMMIT;
CREATE TABLE unexpected_foo(x INTEGER);
CREATE TABLE sqlite_schema(y TEXT);
CREATE TABLE unexpected_bar(y TEXT);",
)
.await;
assert!(res.is_err());
let mut rows = conn
.query(
"SELECT
name
FROM
sqlite_schema
WHERE
type ='table' AND
name NOT LIKE 'sqlite_%';",
(),
)
.await
.unwrap();
let row = rows.next().await.unwrap().unwrap();
assert_eq!(row.get::<String>(0).unwrap(), "users");
assert!(rows.next().await.unwrap().is_none());
}
#[tokio::test]
async fn connection_execute_batch() {
let conn = setup().await;