2023-06-30 10:50:43 +03:00
//go:build cgo
// +build cgo
2023-06-30 09:59:25 +03:00
package libsql
/ *
# cgo CFLAGS : - I . . / c / include
libsql_server,bottomless: add encryption support (#928)
* namespace,replication: add LogFile encryption
Anything that uses our LogFile format can now be encrypted
on-disk.
Tested locally by seeing that `wallog` file contains garbage
and no sensible plaintext strings can be extracted from it.
* test fixups
* libsql-ffi: add libsql_generate_initial_vector and...
... libsql_generate_aes256_key to make them reachable from Rust.
* connection: expose additional encryption symbols
* libsql-server: derive aes256 from user passphrase properly
And by properly, I mean calling back to SQLite3MultipleCiphers' code.
* replication: rename Encryptor to FrameEncryptor
Encryptor sounds a little too generic for this specific use case.
* replication: add snapshot encryption
It uses the same mechanism as wallog encryption, now abstracted
away to libsql-replication crate to be reused.
* replication: add an encryption feature for compilation
* cargo fmt pass
* fix remaining SnapshotFile::open calls in tests
* logger: add an encryption test
* replication: use a single buffer for encryption
Ideally we could even encrypt in place, but WalPage is also
used in snapshots and it's buffered, and that makes it exceptionally
annoying to explain to the borrow checker.
* bottomless: restore with libsql_replication::injector
... instead of the transaction page cache. That gives us free
encryption, since the injector is encryption-aware.
This patch doesn't hook encryption_key parameter yet, it will
come in the next patch.
* bottomless: pass the encryption key in options
For WAL restoration, but also to be able to encrypt data that gets
sent to S3.
* bottomless: inherit encryption key from db config if not specified
* libsql-sys: add db_change_counter()
The helper function calls the underlying C API to extract
4 bytes from offset 24 of the database header and return it.
It's the database change counter, which we can use to compare
two databases and decide which one is newer than the other.
* bottomless: use sqlite API to read database metadata
With encryption enabled, we can no longer just go ahead and read data
from given offsets, we must go through the VFS layer instead.
Fortunately, we can just open a database connection and ask for all
the metadata we need.
* libsql-sys: make db change counter actually read from the db file
* bottomless: treat change counter == 1 as a new database
... which it is, after setting the journal mode. Otherwise we decide
too eagerly that the local database is the source of truth.
* libsql-server: fix a local embedded replica test
rebase conflict with encryption
* bottomless-cli: allow passing the encryption key
* replication: rebase new test to the new api
* snapshots: do not try to decrypt headers
They are not encrypted, so we shouldn't attempt to decrypt the data.
* logger: restore encrypted frames during recovery
Instead of decrypting and encrypting back, we just copy encrypted
frames as is during the recovery process, saves IO.
* compaction: clear unused encryption_key parameter
It wasn't used since for compaction we only need headers,
which are unencrypted.
* replication: switch to FrameBorrowed::new_zeroed
Following MarinPostma's suggestion.
Co-authored-by: Marin Postma <postma.marin@protonmail.com>
* replication: rebase chores, fixing parameters
* libsql-replication: use page_mut() to decrypt data in-place
* rustfmt
* bottomless: use 0 for disabling autocheckpoint
... instead of u32::MAX. Effectively it's similar, but 0 is the correct
choice.
* rustfmt
* libsql-server: make cbc, aes optional for encryption only
* post-rebase fixes
* libsql-replication: suppress warnings when no encryption
* libsql: add encryption support for local databases
* libsql: add bytes dependency for encryption
* libsql-ffi: build libsqlite3mc without debug symbols
Technically it should just depend on cargo build mode,
but that's left for a follow-up.
* bindings: an attempt to compile bindings in release mode
... partially to save space, but also to make them faster.
---------
Co-authored-by: Marin Postma <postma.marin@protonmail.com>
2024-02-09 15:27:39 +01:00
# cgo LDFLAGS : - L . . / . . / target / release
2023-07-20 12:34:05 +02:00
# cgo LDFLAGS : - lsql_experimental
2023-10-16 10:15:49 -07:00
# cgo LDFLAGS : - L . . / . . / libsql - sqlite3 / . libs
2023-07-20 12:34:05 +02:00
# cgo LDFLAGS : - lsqlite3
2023-06-30 10:50:43 +03:00
# cgo LDFLAGS : - lm
2023-06-30 09:59:25 +03:00
# include < libsql . h >
2023-07-22 14:35:37 +02:00
# include < stdlib . h >
2023-06-30 09:59:25 +03:00
* /
import "C"
import (
2023-07-31 16:58:49 +02:00
"context"
2023-06-30 09:59:25 +03:00
"database/sql"
2023-08-08 15:51:01 +02:00
sqldriver "database/sql/driver"
2023-06-30 10:50:43 +03:00
"fmt"
2024-01-29 15:27:42 +01:00
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/libsql/sqlite-antlr4-parser/sqliteparser"
"github.com/libsql/sqlite-antlr4-parser/sqliteparserutils"
2023-07-31 16:58:49 +02:00
"io"
2024-01-20 11:23:08 +01:00
"net/url"
2024-01-29 15:27:42 +01:00
"regexp"
2024-01-20 11:23:08 +01:00
"strings"
2023-08-09 15:37:25 +02:00
"time"
2023-07-22 14:35:37 +02:00
"unsafe"
2023-06-30 09:59:25 +03:00
)
// init makes this driver available to database/sql under the name "libsql",
// so that sql.Open("libsql", ...) is routed here.
func init() {
	var d driver
	sql.Register("libsql", d)
}
2023-09-02 15:22:15 +02:00
// NewEmbeddedReplicaConnector opens a local database at dbPath that syncs
// with primaryUrl, authenticating with authToken. The database is synced once
// while opening; afterwards it is only synced when Connector.Sync is called,
// because no background sync interval is configured.
func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
	const noAutoSync time.Duration = 0
	return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, noAutoSync)
}
2023-09-02 15:22:15 +02:00
func NewEmbeddedReplicaConnectorWithAutoSync ( dbPath , primaryUrl , authToken string , syncInterval time . Duration ) ( * Connector , error ) {
2024-01-20 11:23:08 +01:00
return openEmbeddedReplicaConnector ( dbPath , primaryUrl , authToken , syncInterval )
2023-08-09 15:37:25 +02:00
}
2023-08-08 16:24:00 +02:00
type driver struct { }
2024-01-20 11:23:08 +01:00
func ( d driver ) Open ( dbAddress string ) ( sqldriver . Conn , error ) {
connector , err := d . OpenConnector ( dbAddress )
2023-08-08 16:24:00 +02:00
if err != nil {
return nil , err
}
return connector . Connect ( context . Background ( ) )
}
2024-01-20 11:23:08 +01:00
func ( d driver ) OpenConnector ( dbAddress string ) ( sqldriver . Connector , error ) {
if strings . HasPrefix ( dbAddress , ":memory:" ) {
return openLocalConnector ( dbAddress )
}
u , err := url . Parse ( dbAddress )
if err != nil {
return nil , err
}
switch u . Scheme {
case "file" :
return openLocalConnector ( dbAddress )
case "http" :
fallthrough
case "https" :
fallthrough
case "libsql" :
authToken := u . Query ( ) . Get ( "authToken" )
u . RawQuery = ""
return openRemoteConnector ( u . String ( ) , authToken )
}
2024-02-20 16:55:23 +02:00
return nil , fmt . Errorf ( "unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file:, https:// or http://" , u . Scheme )
2023-08-09 15:37:25 +02:00
}
func libsqlSync ( nativeDbPtr C . libsql_database_t ) error {
var errMsg * C . char
statusCode := C . libsql_sync ( nativeDbPtr , & errMsg )
if statusCode != 0 {
return libsqlError ( "failed to sync database " , statusCode , errMsg )
}
return nil
2023-08-08 16:24:00 +02:00
}
2024-01-20 11:23:08 +01:00
// openLocalConnector opens dbPath with libsqlOpenLocal and wraps the handle in
// a Connector that has no background sync goroutine.
func openLocalConnector(dbPath string) (*Connector, error) {
	db, err := libsqlOpenLocal(dbPath)
	if err != nil {
		return nil, err
	}
	c := &Connector{nativeDbPtr: db}
	return c, nil
}
// openRemoteConnector opens the remote database at primaryUrl (using
// authToken) and wraps the handle in a Connector without background sync.
func openRemoteConnector(primaryUrl, authToken string) (*Connector, error) {
	db, err := libsqlOpenRemote(primaryUrl, authToken)
	if err != nil {
		return nil, err
	}
	c := &Connector{nativeDbPtr: db}
	return c, nil
}
func openEmbeddedReplicaConnector ( dbPath , primaryUrl , authToken string , syncInterval time . Duration ) ( * Connector , error ) {
2023-08-09 15:37:25 +02:00
var closeCh chan struct { }
var closeAckCh chan struct { }
2024-01-20 11:23:08 +01:00
nativeDbPtr , err := libsqlOpenWithSync ( dbPath , primaryUrl , authToken )
if err != nil {
return nil , err
}
if err := libsqlSync ( nativeDbPtr ) ; err != nil {
C . libsql_close ( nativeDbPtr )
return nil , err
}
if syncInterval != 0 {
closeCh = make ( chan struct { } , 1 )
closeAckCh = make ( chan struct { } , 1 )
go func ( ) {
for {
timerCh := make ( chan struct { } , 1 )
go func ( ) {
time . Sleep ( syncInterval )
timerCh <- struct { } { }
} ( )
select {
case <- closeCh :
closeAckCh <- struct { } { }
return
case <- timerCh :
if err := libsqlSync ( nativeDbPtr ) ; err != nil {
fmt . Println ( err )
2023-08-09 15:37:25 +02:00
}
}
2024-01-20 11:23:08 +01:00
}
} ( )
2023-08-09 15:37:25 +02:00
}
2023-08-08 16:24:00 +02:00
if err != nil {
return nil , err
}
2023-08-09 15:37:25 +02:00
return & Connector { nativeDbPtr : nativeDbPtr , closeCh : closeCh , closeAckCh : closeAckCh } , nil
2023-07-22 14:35:37 +02:00
}
2023-08-09 15:37:25 +02:00
type Connector struct {
2023-08-08 16:24:00 +02:00
nativeDbPtr C . libsql_database_t
2023-08-09 15:37:25 +02:00
closeCh chan <- struct { }
closeAckCh <- chan struct { }
2023-08-08 16:24:00 +02:00
}
2023-08-09 15:37:25 +02:00
// Sync triggers an immediate libsql_sync on the connector's database,
// independently of any background sync.
func (c *Connector) Sync() error {
	err := libsqlSync(c.nativeDbPtr)
	return err
}
func ( c * Connector ) Close ( ) error {
if c . closeCh != nil {
c . closeCh <- struct { } { }
<- c . closeAckCh
2024-01-24 11:34:44 +01:00
c . closeCh = nil
c . closeAckCh = nil
2023-08-09 15:37:25 +02:00
}
2024-01-24 11:34:44 +01:00
if c . nativeDbPtr != nil {
C . libsql_close ( c . nativeDbPtr )
}
c . nativeDbPtr = nil
2023-08-08 16:24:00 +02:00
return nil
}
2023-08-09 15:37:25 +02:00
func ( c * Connector ) Connect ( ctx context . Context ) ( sqldriver . Conn , error ) {
2023-08-08 16:24:00 +02:00
nativeConnPtr , err := libsqlConnect ( c . nativeDbPtr )
if err != nil {
return nil , err
}
return & conn { nativePtr : nativeConnPtr } , nil
}
2023-08-09 15:37:25 +02:00
// Driver implements sqldriver.Connector, returning this package's driver.
func (c *Connector) Driver() sqldriver.Driver {
	var d driver
	return d
}
2023-06-30 09:59:25 +03:00
2023-08-04 15:20:27 +02:00
// libsqlError builds a Go error for a failed native call. message describes
// the operation and statusCode is the native status. errMsg, when non-nil, is
// a native error string that is appended to the message and then released
// with libsql_free_string, so callers must not use it afterwards.
func libsqlError(message string, statusCode C.int, errMsg *C.char) error {
	code := int(statusCode)
	if errMsg == nil {
		return fmt.Errorf("%s\nerror code = %d", message, code)
	}
	detail := C.GoString(errMsg)
	C.libsql_free_string(errMsg)
	return fmt.Errorf("%s\nerror code = %d: %v", message, code, detail)
}
2024-01-20 11:23:08 +01:00
// libsqlOpenLocal opens a local database with libsql_open_file. The
// dataSourceName is passed through unchanged; OpenConnector hands it
// ":memory:" addresses and file: URLs as-is.
func libsqlOpenLocal(dataSourceName string) (C.libsql_database_t, error) {
	cPath := C.CString(dataSourceName)
	defer C.free(unsafe.Pointer(cPath))

	var (
		db     C.libsql_database_t
		errMsg *C.char
	)
	if status := C.libsql_open_file(cPath, &db, &errMsg); status != 0 {
		return nil, libsqlError(fmt.Sprint("failed to open local database ", dataSourceName), status, errMsg)
	}
	return db, nil
}
// libsqlOpenRemote opens a remote database at primaryURL with
// libsql_open_remote, authenticating with authToken. The parameter is not
// named url so that it does not shadow the net/url package.
func libsqlOpenRemote(primaryURL, authToken string) (C.libsql_database_t, error) {
	cURL := C.CString(primaryURL)
	defer C.free(unsafe.Pointer(cURL))
	cToken := C.CString(authToken)
	defer C.free(unsafe.Pointer(cToken))

	var (
		db     C.libsql_database_t
		errMsg *C.char
	)
	if status := C.libsql_open_remote(cURL, cToken, &db, &errMsg); status != 0 {
		return nil, libsqlError(fmt.Sprint("failed to open remote database ", primaryURL), status, errMsg)
	}
	return db, nil
}
2023-09-02 15:22:15 +02:00
// libsqlOpenWithSync opens the local database at dbPath with libsql_open_sync,
// configured to sync with primaryUrl using authToken.
func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) {
	cPath := C.CString(dbPath)
	defer C.free(unsafe.Pointer(cPath))
	cPrimary := C.CString(primaryUrl)
	defer C.free(unsafe.Pointer(cPrimary))
	cToken := C.CString(authToken)
	defer C.free(unsafe.Pointer(cToken))

	var (
		db     C.libsql_database_t
		errMsg *C.char
	)
	if status := C.libsql_open_sync(cPath, cPrimary, cToken, &db, &errMsg); status != 0 {
		return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), status, errMsg)
	}
	return db, nil
}
2023-07-22 14:35:37 +02:00
func libsqlConnect ( db C . libsql_database_t ) ( C . libsql_connection_t , error ) {
2023-08-05 11:01:30 +02:00
var conn C . libsql_connection_t
var errMsg * C . char
statusCode := C . libsql_connect ( db , & conn , & errMsg )
if statusCode != 0 {
return nil , libsqlError ( "failed to connect to database" , statusCode , errMsg )
2023-07-22 14:35:37 +02:00
}
return conn , nil
}
2023-06-30 09:59:25 +03:00
2023-08-08 16:24:00 +02:00
type conn struct {
nativePtr C . libsql_connection_t
2023-07-22 14:35:37 +02:00
}
2023-08-08 16:24:00 +02:00
func ( c * conn ) Prepare ( query string ) ( sqldriver . Stmt , error ) {
return c . PrepareContext ( context . Background ( ) , query )
2023-06-30 10:50:43 +03:00
}
2023-08-08 16:24:00 +02:00
func ( c * conn ) Begin ( ) ( sqldriver . Tx , error ) {
return c . BeginTx ( context . Background ( ) , sqldriver . TxOptions { } )
2023-06-30 10:50:43 +03:00
}
func ( c * conn ) Close ( ) error {
2023-08-08 16:24:00 +02:00
C . libsql_disconnect ( c . nativePtr )
return nil
2023-06-30 10:50:43 +03:00
}
2024-01-29 15:27:42 +01:00
// ParamsInfo describes the bind parameters found in a single SQL statement.
// Field order matters: parseStatement builds values with a positional literal.
type ParamsInfo struct {
	// NamedParameters holds each distinct named parameter (":x", "@x" or "$x")
	// with its leading sigil removed.
	NamedParameters []string
	// PositionalParametersCount is the number of anonymous "?" parameters.
	PositionalParametersCount int
}
// positionalParamRe matches a '?' and captures the digits directly after it
// ("?" -> "", "?12" -> "12"). It is compiled once instead of on every call.
var positionalParamRe = regexp.MustCompile(`\?([0-9]*).*`)

// isPositionalParameter reports whether the bind-parameter token param is an
// anonymous positional parameter ("?"). A token without a '?' is not
// positional. An indexed positional parameter such as "?1" yields ok == true
// together with an error, because this driver does not support it.
func isPositionalParameter(param string) (ok bool, err error) {
	match := positionalParamRe.FindStringSubmatch(param)
	if match == nil {
		return false, nil
	}
	if match[1] == "" {
		return true, nil
	}
	return true, fmt.Errorf("unsupported positional parameter. This driver does not accept positional parameters with indexes (like ?<number>)")
}
// removeParamPrefix strips the leading ':', '@' or '$' sigil from a named bind
// parameter (":id" -> "id"). It returns an error for a name without one of
// those sigils, including the empty string, which previously caused an
// index-out-of-range panic.
func removeParamPrefix(paramName string) (string, error) {
	if paramName != "" {
		switch paramName[0] {
		case ':', '@', '$':
			return paramName[1:], nil
		}
	}
	return "", fmt.Errorf("all named parameters must start with ':', or '@' or '$'")
}
func extractParameters ( stmt string ) ( nameParams [ ] string , positionalParamsCount int , err error ) {
statementStream := antlr . NewInputStream ( stmt )
sqliteparser . NewSQLiteLexer ( statementStream )
lexer := sqliteparser . NewSQLiteLexer ( statementStream )
allTokens := lexer . GetAllTokens ( )
nameParamsSet := make ( map [ string ] bool )
for _ , token := range allTokens {
tokenType := token . GetTokenType ( )
if tokenType == sqliteparser . SQLiteLexerBIND_PARAMETER {
parameter := token . GetText ( )
isPositionalParameter , err := isPositionalParameter ( parameter )
if err != nil {
return [ ] string { } , 0 , err
}
if isPositionalParameter {
positionalParamsCount ++
} else {
paramWithoutPrefix , err := removeParamPrefix ( parameter )
if err != nil {
return [ ] string { } , 0 , err
} else {
nameParamsSet [ paramWithoutPrefix ] = true
}
}
}
}
nameParams = make ( [ ] string , 0 , len ( nameParamsSet ) )
for k := range nameParamsSet {
nameParams = append ( nameParams , k )
}
return nameParams , positionalParamsCount , nil
}
// parseStatement splits query into individual statements and extracts the
// bind parameters of each; the i-th ParamsInfo describes the i-th statement.
// The parameter is not named sql so it does not shadow database/sql.
func parseStatement(query string) ([]string, []ParamsInfo, error) {
	stmts, _ := sqliteparserutils.SplitStatement(query)
	infos := make([]ParamsInfo, 0, len(stmts))
	for _, s := range stmts {
		named, positional, err := extractParameters(s)
		if err != nil {
			return nil, nil, err
		}
		infos = append(infos, ParamsInfo{NamedParameters: named, PositionalParametersCount: positional})
	}
	return stmts, infos, nil
}
2023-08-08 16:24:00 +02:00
// PrepareContext implements sqldriver.ConnPrepareContext. The query is parsed
// locally to check that it holds exactly one statement and to compute the
// value NumInput reports. That value is the number of "?" placeholders, or -1
// ("unknown", so database/sql skips its argument-count check) when named
// parameters are used. Nothing is sent to the database here.
func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
	stmts, paramInfos, err := parseStatement(query)
	if err != nil {
		return nil, err
	}
	if n := len(stmts); n != 1 {
		return nil, fmt.Errorf("only one statement is supported got %d", n)
	}
	info := paramInfos[0]
	numInput := -1
	if len(info.NamedParameters) == 0 {
		numInput = info.PositionalParametersCount
	}
	return &stmt{conn: c, sql: query, numInput: numInput}, nil
}
2023-08-08 16:24:00 +02:00
func ( c * conn ) BeginTx ( ctx context . Context , opts sqldriver . TxOptions ) ( sqldriver . Tx , error ) {
2024-01-25 14:44:37 +01:00
if opts . ReadOnly {
return nil , fmt . Errorf ( "read only transactions are not supported" )
}
if opts . Isolation != sqldriver . IsolationLevel ( sql . LevelDefault ) {
return nil , fmt . Errorf ( "isolation level %d is not supported" , opts . Isolation )
}
_ , err := c . ExecContext ( ctx , "BEGIN" , nil )
if err != nil {
return nil , err
}
return & tx { c } , nil
2023-06-30 09:59:25 +03:00
}
2023-07-31 16:58:49 +02:00
2024-01-25 11:05:40 +01:00
// executeNoArgs runs query directly with libsql_execute, without preparing a
// statement, and returns the native rows handle. Callers in this file release
// it with libsql_free_rows, either directly or through rows.Close.
func (c *conn) executeNoArgs(query string) (C.libsql_rows_t, error) {
	cQuery := C.CString(query)
	defer C.free(unsafe.Pointer(cQuery))

	var (
		nativeRows C.libsql_rows_t
		errMsg     *C.char
	)
	if status := C.libsql_execute(c.nativePtr, cQuery, &nativeRows, &errMsg); status != 0 {
		return nil, libsqlError(fmt.Sprint("failed to execute query ", query), status, errMsg)
	}
	return nativeRows, nil
}
2024-01-25 11:05:40 +01:00
// execute runs query on the connection and returns the native rows handle.
// Without arguments the query goes straight to executeNoArgs. Otherwise it is
// prepared with libsql_prepare, each argument is bound at its Ordinal
// position, and the statement is executed. The prepared statement is always
// freed.
//
// Accepted argument types are int64, float64, []byte, string and nil, plus
// bool and time.Time. database/sql's default parameter converter also
// produces the latter two, and they were previously rejected. Bools are bound
// as the integers 0/1. Times are bound as text in the first layout rows.Next
// tries when decoding TEXT columns, so they read back as an equal time.Time.
func (c *conn) execute(query string, args []sqldriver.NamedValue) (C.libsql_rows_t, error) {
	const timeBindLayout = "2006-01-02 15:04:05.999999999-07:00"

	if len(args) == 0 {
		return c.executeNoArgs(query)
	}
	queryCString := C.CString(query)
	defer C.free(unsafe.Pointer(queryCString))
	var nativeStmt C.libsql_stmt_t
	var errMsg *C.char
	statusCode := C.libsql_prepare(c.nativePtr, queryCString, &nativeStmt, &errMsg)
	if statusCode != 0 {
		return nil, libsqlError(fmt.Sprint("failed to prepare query ", query), statusCode, errMsg)
	}
	defer C.libsql_free_stmt(nativeStmt)
	for _, arg := range args {
		var bindErr *C.char
		var status C.int
		idx := C.int(arg.Ordinal)
		switch v := arg.Value.(type) {
		case int64:
			status = C.libsql_bind_int(nativeStmt, idx, C.longlong(v), &bindErr)
		case bool:
			var asInt C.longlong
			if v {
				asInt = 1
			}
			status = C.libsql_bind_int(nativeStmt, idx, asInt, &bindErr)
		case float64:
			status = C.libsql_bind_float(nativeStmt, idx, C.double(v), &bindErr)
		case []byte:
			nativeBlob := C.CBytes(v)
			status = C.libsql_bind_blob(nativeStmt, idx, (*C.uchar)(nativeBlob), C.int(len(v)), &bindErr)
			C.free(nativeBlob)
		case string:
			valueStr := C.CString(v)
			status = C.libsql_bind_string(nativeStmt, idx, valueStr, &bindErr)
			C.free(unsafe.Pointer(valueStr))
		case time.Time:
			valueStr := C.CString(v.Format(timeBindLayout))
			status = C.libsql_bind_string(nativeStmt, idx, valueStr, &bindErr)
			C.free(unsafe.Pointer(valueStr))
		case nil:
			status = C.libsql_bind_null(nativeStmt, idx, &bindErr)
		default:
			return nil, fmt.Errorf("unsupported type %T", arg.Value)
		}
		if status != 0 {
			return nil, libsqlError(fmt.Sprintf("failed to bind argument no. %d with value %v and type %T", arg.Ordinal, arg.Value, arg.Value), status, bindErr)
		}
	}
	var nativeRows C.libsql_rows_t
	statusCode = C.libsql_execute_stmt(nativeStmt, &nativeRows, &errMsg)
	if statusCode != 0 {
		return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
	}
	return nativeRows, nil
}
2024-01-25 14:44:37 +01:00
// execResult is the sqldriver.Result produced by ExecContext. Both values are
// captured eagerly, so the accessors never fail. Field order matters:
// ExecContext may build it with a positional literal.
type execResult struct {
	id      int64 // last inserted rowid
	changes int64 // number of rows changed
}

// LastInsertId implements sqldriver.Result.
func (r execResult) LastInsertId() (int64, error) { return r.id, nil }

// RowsAffected implements sqldriver.Result.
func (r execResult) RowsAffected() (int64, error) { return r.changes, nil }
2023-08-08 15:51:01 +02:00
func ( c * conn ) ExecContext ( ctx context . Context , query string , args [ ] sqldriver . NamedValue ) ( sqldriver . Result , error ) {
2024-01-25 11:05:40 +01:00
rows , err := c . execute ( query , args )
2023-08-05 11:30:28 +02:00
if err != nil {
return nil , err
}
2024-01-25 14:44:37 +01:00
id := int64 ( C . libsql_last_insert_rowid ( c . nativePtr ) )
changes := int64 ( C . libsql_changes ( c . nativePtr ) )
2023-07-31 16:58:49 +02:00
if rows != nil {
C . libsql_free_rows ( rows )
}
2024-01-25 14:44:37 +01:00
return execResult { id , changes } , nil
}
2024-01-29 15:27:42 +01:00
// stmt is the sqldriver.Stmt returned by conn.PrepareContext. It keeps only
// the SQL text and hands it to the owning connection on every execution.
// Field order matters for positional literals.
type stmt struct {
	conn     *conn
	sql      string
	numInput int
}

// Close implements sqldriver.Stmt; there is nothing native to release.
func (s *stmt) Close() error { return nil }

// NumInput implements sqldriver.Stmt, returning the placeholder count computed
// in PrepareContext (-1 when unknown).
func (s *stmt) NumInput() int { return s.numInput }
// convertToNamed adapts positional sqldriver.Value arguments, as used by the
// legacy Stmt.Exec/Stmt.Query API, to sqldriver.NamedValue. Ordinals are
// 1-based, as documented for NamedValue and as database/sql itself supplies
// them. execute binds by Ordinal, so the previous 0-based numbering shifted
// every argument by one placeholder. An empty argument list yields nil, which
// makes execute take its no-arguments path.
func convertToNamed(args []sqldriver.Value) []sqldriver.NamedValue {
	if len(args) == 0 {
		return nil
	}
	result := make([]sqldriver.NamedValue, len(args))
	for i, v := range args {
		result[i] = sqldriver.NamedValue{Ordinal: i + 1, Value: v}
	}
	return result
}
// Exec implements the legacy sqldriver.Stmt.Exec by converting args and
// delegating to ExecContext with a background context.
func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
	named := convertToNamed(args)
	return s.ExecContext(context.Background(), named)
}

// Query implements the legacy sqldriver.Stmt.Query by converting args and
// delegating to QueryContext with a background context.
func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
	named := convertToNamed(args)
	return s.QueryContext(context.Background(), named)
}

// ExecContext implements sqldriver.StmtExecContext by running the statement's
// SQL on its connection.
func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) {
	owner := s.conn
	return owner.ExecContext(ctx, s.sql, args)
}

// QueryContext implements sqldriver.StmtQueryContext by running the
// statement's SQL on its connection.
func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
	owner := s.conn
	return owner.QueryContext(ctx, s.sql, args)
}
2024-01-25 14:44:37 +01:00
type tx struct {
conn * conn
}
func ( t tx ) Commit ( ) error {
_ , err := t . conn . ExecContext ( context . Background ( ) , "COMMIT" , nil )
return err
}
func ( t tx ) Rollback ( ) error {
_ , err := t . conn . ExecContext ( context . Background ( ) , "ROLLBACK" , nil )
return err
2023-07-31 16:58:49 +02:00
}
// Column type codes as returned by libsql_column_type and switched on in
// rows.Next. The names are kept for compatibility even though they are not
// MixedCaps.
const (
	TYPE_INT   int = 1
	TYPE_FLOAT int = 2
	TYPE_TEXT  int = 3
	TYPE_BLOB  int = 4
	TYPE_NULL  int = 5
)
2023-08-05 11:55:52 +02:00
// newRows wraps a native rows handle in a *rows, reading all column names up
// front. A nil handle yields an empty *rows whose Next reports io.EOF
// immediately.
//
// If a column name cannot be read, the native handle is released before the
// error is returned. No *rows value exists to free it later, and callers such
// as QueryContext hand ownership over to newRows.
func newRows(nativePtr C.libsql_rows_t) (*rows, error) {
	if nativePtr == nil {
		return &rows{nil, nil}, nil
	}
	columnCount := int(C.libsql_column_count(nativePtr))
	columns := make([]string, columnCount)
	for i := range columns {
		var ptr *C.char
		var errMsg *C.char
		statusCode := C.libsql_column_name(nativePtr, C.int(i), &ptr, &errMsg)
		if statusCode != 0 {
			err := libsqlError(fmt.Sprint("failed to get column name for index ", i), statusCode, errMsg)
			C.libsql_free_rows(nativePtr)
			return nil, err
		}
		columns[i] = C.GoString(ptr)
		C.libsql_free_string(ptr)
	}
	return &rows{nativePtr, columns}, nil
}
type rows struct {
nativePtr C . libsql_rows_t
2023-08-05 11:55:52 +02:00
columnNames [ ] string
2023-07-31 16:58:49 +02:00
}
func ( r * rows ) Columns ( ) [ ] string {
2023-08-05 11:55:52 +02:00
return r . columnNames
2023-07-31 16:58:49 +02:00
}
func ( r * rows ) Close ( ) error {
if r . nativePtr != nil {
C . libsql_free_rows ( r . nativePtr )
r . nativePtr = nil
}
return nil
}
2023-08-08 15:51:01 +02:00
func ( r * rows ) Next ( dest [ ] sqldriver . Value ) error {
2023-07-31 16:58:49 +02:00
if r . nativePtr == nil {
return io . EOF
}
2023-08-05 15:10:14 +02:00
var row C . libsql_row_t
var errMsg * C . char
statusCode := C . libsql_next_row ( r . nativePtr , & row , & errMsg )
if statusCode != 0 {
return libsqlError ( "failed to get next row" , statusCode , errMsg )
}
2023-07-31 16:58:49 +02:00
if row == nil {
r . Close ( )
return io . EOF
}
defer C . libsql_free_row ( row )
count := len ( dest )
2024-01-20 11:23:08 +01:00
if count > len ( r . columnNames ) {
count = len ( r . columnNames )
2023-07-31 16:58:49 +02:00
}
for i := 0 ; i < count ; i ++ {
2024-02-08 16:39:39 +01:00
var columnType C . int
var errMsg * C . char
statusCode := C . libsql_column_type ( r . nativePtr , row , C . int ( i ) , & columnType , & errMsg )
if statusCode != 0 {
return libsqlError ( fmt . Sprint ( "failed to get column type for index " , i ) , statusCode , errMsg )
}
switch int ( columnType ) {
2023-07-31 16:58:49 +02:00
case TYPE_NULL :
dest [ i ] = nil
case TYPE_INT :
2023-08-05 15:36:32 +02:00
var value C . longlong
var errMsg * C . char
statusCode := C . libsql_get_int ( row , C . int ( i ) , & value , & errMsg )
if statusCode != 0 {
return libsqlError ( fmt . Sprint ( "failed to get integer for column " , i ) , statusCode , errMsg )
}
dest [ i ] = int64 ( value )
2023-07-31 16:58:49 +02:00
case TYPE_FLOAT :
2023-08-05 15:42:15 +02:00
var value C . double
var errMsg * C . char
statusCode := C . libsql_get_float ( row , C . int ( i ) , & value , & errMsg )
if statusCode != 0 {
return libsqlError ( fmt . Sprint ( "failed to get float for column " , i ) , statusCode , errMsg )
}
dest [ i ] = float64 ( value )
2023-07-31 16:58:49 +02:00
case TYPE_BLOB :
2023-08-05 15:49:04 +02:00
var nativeBlob C . blob
var errMsg * C . char
statusCode := C . libsql_get_blob ( row , C . int ( i ) , & nativeBlob , & errMsg )
if statusCode != 0 {
return libsqlError ( fmt . Sprint ( "failed to get blob for column " , i ) , statusCode , errMsg )
}
2023-07-31 16:58:49 +02:00
dest [ i ] = C . GoBytes ( unsafe . Pointer ( nativeBlob . ptr ) , C . int ( nativeBlob . len ) )
C . libsql_free_blob ( nativeBlob )
case TYPE_TEXT :
2023-08-05 15:27:53 +02:00
var ptr * C . char
var errMsg * C . char
statusCode := C . libsql_get_string ( row , C . int ( i ) , & ptr , & errMsg )
if statusCode != 0 {
return libsqlError ( fmt . Sprint ( "failed to get string for column " , i ) , statusCode , errMsg )
}
2024-01-30 11:57:01 +01:00
str := C . GoString ( ptr )
2023-07-31 16:58:49 +02:00
C . libsql_free_string ( ptr )
2024-01-30 11:57:01 +01:00
for _ , format := range [ ] string {
"2006-01-02 15:04:05.999999999-07:00" ,
"2006-01-02T15:04:05.999999999-07:00" ,
"2006-01-02 15:04:05.999999999" ,
"2006-01-02T15:04:05.999999999" ,
"2006-01-02 15:04:05" ,
"2006-01-02T15:04:05" ,
"2006-01-02 15:04" ,
"2006-01-02T15:04" ,
"2006-01-02" ,
} {
if t , err := time . ParseInLocation ( format , str , time . UTC ) ; err == nil {
dest [ i ] = t
return nil
}
}
dest [ i ] = str
2023-07-31 16:58:49 +02:00
}
}
return nil
}
2023-08-08 15:51:01 +02:00
func ( c * conn ) QueryContext ( ctx context . Context , query string , args [ ] sqldriver . NamedValue ) ( sqldriver . Rows , error ) {
2024-01-25 11:05:40 +01:00
rowsNativePtr , err := c . execute ( query , args )
2023-08-05 11:30:28 +02:00
if err != nil {
return nil , err
2023-07-31 16:58:49 +02:00
}
2023-08-05 11:55:52 +02:00
return newRows ( rowsNativePtr )
2023-07-31 16:58:49 +02:00
}