2023-06-30 10:50:43 +03:00
|
|
|
//go:build cgo
|
|
|
|
// +build cgo
|
|
|
|
|
2023-06-30 09:59:25 +03:00
|
|
|
package libsql
|
|
|
|
|
|
|
|
/*
|
|
|
|
#cgo CFLAGS: -I../c/include
|
libsql_server,bottomless: add encryption support (#928)
* namespace,replication: add LogFile encryption
Anything that uses our LogFile format can now be encrypted
on-disk.
Tested locally by seeing that `wallog` file contains garbage
and no sensible plaintext strings can be extracted from it.
* test fixups
* libsql-ffi: add libsql_generate_initial_vector and...
... libsql_generate_aes256_key to make them reachable from Rust.
* connection: expose additional encryption symbols
* libsql-server: derive aes256 from user passphrase properly
And by properly, I mean calling back to SQLite3MultipleCiphers' code.
* replication: rename Encryptor to FrameEncryptor
Encryptor sounds a little too generic for this specific use case.
* replication: add snapshot encryption
It uses the same mechanism as wallog encryption, now abstracted
away to libsql-replication crate to be reused.
* replication: add an encryption feature for compilation
* cargo fmt pass
* fix remaining SnapshotFile::open calls in tests
* logger: add an encryption test
* replication: use a single buffer for encryption
Ideally we could even encrypt in place, but WalPage is also
used in snapshots and it's buffered, and that makes it exceptionally
annoying to explain to the borrow checker.
* bottomless: restore with libsql_replication::injector
... instead of the transaction page cache. That gives us free
encryption, since the injector is encryption-aware.
This patch doesn't hook encryption_key parameter yet, it will
come in the next patch.
* bottomless: pass the encryption key in options
For WAL restoration, but also to be able to encrypt data that gets
sent to S3.
* bottomless: inherit encryption key from db config if not specified
* libsql-sys: add db_change_counter()
The helper function calls the underlying C API to extract
4 bytes from offset 24 of the database header and return it.
It's the database change counter, which we can use to compare
two databases and decide which one is newer than the other.
* bottomless: use sqlite API to read database metadata
With encryption enabled, we can no longer just go ahead and read data
from given offsets, we must go through the VFS layer instead.
Fortunately, we can just open a database connection and ask for all
the metadata we need.
* libsql-sys: make db change counter actually read from the db file
* bottomless: treat change counter == 1 as a new database
... which it is, after setting the journal mode. Otherwise we decide
too eagerly that the local database is the source of truth.
* libsql-server: fix a local embedded replica test
rebase conflict with encryption
* bottomless-cli: allow passing the encryption key
* replication: rebase new test to the new api
* snapshots: do not try to decrypt headers
They are not encrypted, so we shouldn't attempt to decrypt the data.
* logger: restore encrypted frames during recovery
Instead of decrypting and encrypting back, we just copy encrypted
frames as is during the recovery process, saves IO.
* compaction: clear unused encryption_key parameter
It wasn't used since for compaction we only need headers,
which are unencrypted.
* replication: switch to FrameBorrowed::new_zeroed
Following MarinPostma's suggestion.
Co-authored-by: Marin Postma <postma.marin@protonmail.com>
* replication: rebase chores, fixing parameters
* libsql-replication: use page_mut() to decrypt data in-place
* rustfmt
* bottomless: use 0 for disabling autocheckpoint
... instead of u32::MAX. Effectively it's similar, but 0 is the correct
choice.
* rustfmt
* libsql-server: make cbc, aes optional for encryption only
* post-rebase fixes
* libsql-replication: suppress warnings when no encryption
* libsql: add encryption support for local databases
* libsql: add bytes dependency for encryption
* libsql-ffi: build libsqlite3mc without debug symbols
Technically it should just depend on cargo build mode,
but that's left for a follow-up.
* bindings: an attempt to compile bindings with releasemode
... partially to save space, but also to make them faster.
---------
Co-authored-by: Marin Postma <postma.marin@protonmail.com>
2024-02-09 15:27:39 +01:00
|
|
|
#cgo LDFLAGS: -L../../target/release
|
2023-07-20 12:34:05 +02:00
|
|
|
#cgo LDFLAGS: -lsql_experimental
|
2023-10-16 10:15:49 -07:00
|
|
|
#cgo LDFLAGS: -L../../libsql-sqlite3/.libs
|
2023-07-20 12:34:05 +02:00
|
|
|
#cgo LDFLAGS: -lsqlite3
|
2023-06-30 10:50:43 +03:00
|
|
|
#cgo LDFLAGS: -lm
|
2023-06-30 09:59:25 +03:00
|
|
|
#include <libsql.h>
|
2023-07-22 14:35:37 +02:00
|
|
|
#include <stdlib.h>
|
2023-06-30 09:59:25 +03:00
|
|
|
*/
|
|
|
|
import "C"
|
|
|
|
|
|
|
|
import (
|
2023-07-31 16:58:49 +02:00
|
|
|
"context"
|
2023-06-30 09:59:25 +03:00
|
|
|
"database/sql"
|
2023-08-08 15:51:01 +02:00
|
|
|
sqldriver "database/sql/driver"
|
2023-06-30 10:50:43 +03:00
|
|
|
"fmt"
|
2024-01-29 15:27:42 +01:00
|
|
|
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
|
|
|
|
"github.com/libsql/sqlite-antlr4-parser/sqliteparser"
|
|
|
|
"github.com/libsql/sqlite-antlr4-parser/sqliteparserutils"
|
2023-07-31 16:58:49 +02:00
|
|
|
"io"
|
2024-01-20 11:23:08 +01:00
|
|
|
"net/url"
|
2024-01-29 15:27:42 +01:00
|
|
|
"regexp"
|
2024-01-20 11:23:08 +01:00
|
|
|
"strings"
|
2023-08-09 15:37:25 +02:00
|
|
|
"time"
|
2023-07-22 14:35:37 +02:00
|
|
|
"unsafe"
|
2023-06-30 09:59:25 +03:00
|
|
|
)
|
|
|
|
|
|
|
|
func init() {
|
2023-08-08 16:24:00 +02:00
|
|
|
sql.Register("libsql", driver{})
|
2023-06-30 09:59:25 +03:00
|
|
|
}
|
|
|
|
|
2023-09-02 15:22:15 +02:00
|
|
|
func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
|
2024-01-20 11:23:08 +01:00
|
|
|
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, 0)
|
2023-08-09 15:37:25 +02:00
|
|
|
}
|
|
|
|
|
2023-09-02 15:22:15 +02:00
|
|
|
func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
|
2024-01-20 11:23:08 +01:00
|
|
|
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, syncInterval)
|
2023-08-09 15:37:25 +02:00
|
|
|
}
|
|
|
|
|
2023-08-08 16:24:00 +02:00
|
|
|
// driver is the database/sql driver registered as "libsql". It carries no
// state of its own; Open and OpenConnector create the native database handles.
type driver struct{}
|
|
|
|
|
2024-01-20 11:23:08 +01:00
|
|
|
func (d driver) Open(dbAddress string) (sqldriver.Conn, error) {
|
|
|
|
connector, err := d.OpenConnector(dbAddress)
|
2023-08-08 16:24:00 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return connector.Connect(context.Background())
|
|
|
|
}
|
|
|
|
|
2024-01-20 11:23:08 +01:00
|
|
|
func (d driver) OpenConnector(dbAddress string) (sqldriver.Connector, error) {
|
|
|
|
if strings.HasPrefix(dbAddress, ":memory:") {
|
|
|
|
return openLocalConnector(dbAddress)
|
|
|
|
}
|
|
|
|
u, err := url.Parse(dbAddress)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
switch u.Scheme {
|
|
|
|
case "file":
|
|
|
|
return openLocalConnector(dbAddress)
|
|
|
|
case "http":
|
|
|
|
fallthrough
|
|
|
|
case "https":
|
|
|
|
fallthrough
|
|
|
|
case "libsql":
|
|
|
|
authToken := u.Query().Get("authToken")
|
|
|
|
u.RawQuery = ""
|
|
|
|
return openRemoteConnector(u.String(), authToken)
|
|
|
|
}
|
2024-02-20 16:55:23 +02:00
|
|
|
return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file:, https:// or http://", u.Scheme)
|
2023-08-09 15:37:25 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func libsqlSync(nativeDbPtr C.libsql_database_t) error {
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_sync(nativeDbPtr, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return libsqlError("failed to sync database ", statusCode, errMsg)
|
|
|
|
}
|
|
|
|
return nil
|
2023-08-08 16:24:00 +02:00
|
|
|
}
|
|
|
|
|
2024-01-20 11:23:08 +01:00
|
|
|
func openLocalConnector(dbPath string) (*Connector, error) {
|
|
|
|
nativeDbPtr, err := libsqlOpenLocal(dbPath)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return &Connector{nativeDbPtr: nativeDbPtr}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func openRemoteConnector(primaryUrl, authToken string) (*Connector, error) {
|
|
|
|
nativeDbPtr, err := libsqlOpenRemote(primaryUrl, authToken)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return &Connector{nativeDbPtr: nativeDbPtr}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
|
2023-08-09 15:37:25 +02:00
|
|
|
var closeCh chan struct{}
|
|
|
|
var closeAckCh chan struct{}
|
2024-01-20 11:23:08 +01:00
|
|
|
nativeDbPtr, err := libsqlOpenWithSync(dbPath, primaryUrl, authToken)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if err := libsqlSync(nativeDbPtr); err != nil {
|
|
|
|
C.libsql_close(nativeDbPtr)
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if syncInterval != 0 {
|
|
|
|
closeCh = make(chan struct{}, 1)
|
|
|
|
closeAckCh = make(chan struct{}, 1)
|
|
|
|
go func() {
|
|
|
|
for {
|
|
|
|
timerCh := make(chan struct{}, 1)
|
|
|
|
go func() {
|
|
|
|
time.Sleep(syncInterval)
|
|
|
|
timerCh <- struct{}{}
|
|
|
|
}()
|
|
|
|
select {
|
|
|
|
case <-closeCh:
|
|
|
|
closeAckCh <- struct{}{}
|
|
|
|
return
|
|
|
|
case <-timerCh:
|
|
|
|
if err := libsqlSync(nativeDbPtr); err != nil {
|
|
|
|
fmt.Println(err)
|
2023-08-09 15:37:25 +02:00
|
|
|
}
|
|
|
|
}
|
2024-01-20 11:23:08 +01:00
|
|
|
}
|
|
|
|
}()
|
2023-08-09 15:37:25 +02:00
|
|
|
}
|
2023-08-08 16:24:00 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2023-08-09 15:37:25 +02:00
|
|
|
return &Connector{nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil
|
2023-07-22 14:35:37 +02:00
|
|
|
}
|
|
|
|
|
2023-08-09 15:37:25 +02:00
|
|
|
type Connector struct {
|
2023-08-08 16:24:00 +02:00
|
|
|
nativeDbPtr C.libsql_database_t
|
2023-08-09 15:37:25 +02:00
|
|
|
closeCh chan<- struct{}
|
|
|
|
closeAckCh <-chan struct{}
|
2023-08-08 16:24:00 +02:00
|
|
|
}
|
|
|
|
|
2023-08-09 15:37:25 +02:00
|
|
|
func (c *Connector) Sync() error {
|
|
|
|
return libsqlSync(c.nativeDbPtr)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Connector) Close() error {
|
|
|
|
if c.closeCh != nil {
|
|
|
|
c.closeCh <- struct{}{}
|
|
|
|
<-c.closeAckCh
|
2024-01-24 11:34:44 +01:00
|
|
|
c.closeCh = nil
|
|
|
|
c.closeAckCh = nil
|
2023-08-09 15:37:25 +02:00
|
|
|
}
|
2024-01-24 11:34:44 +01:00
|
|
|
if c.nativeDbPtr != nil {
|
|
|
|
C.libsql_close(c.nativeDbPtr)
|
|
|
|
}
|
|
|
|
c.nativeDbPtr = nil
|
2023-08-08 16:24:00 +02:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-08-09 15:37:25 +02:00
|
|
|
func (c *Connector) Connect(ctx context.Context) (sqldriver.Conn, error) {
|
2023-08-08 16:24:00 +02:00
|
|
|
nativeConnPtr, err := libsqlConnect(c.nativeDbPtr)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return &conn{nativePtr: nativeConnPtr}, nil
|
|
|
|
}
|
|
|
|
|
2023-08-09 15:37:25 +02:00
|
|
|
func (c *Connector) Driver() sqldriver.Driver {
|
2023-08-08 16:24:00 +02:00
|
|
|
return driver{}
|
2023-07-22 14:35:37 +02:00
|
|
|
}
|
2023-06-30 09:59:25 +03:00
|
|
|
|
2023-08-04 15:20:27 +02:00
|
|
|
func libsqlError(message string, statusCode C.int, errMsg *C.char) error {
|
|
|
|
code := int(statusCode)
|
|
|
|
if errMsg != nil {
|
|
|
|
msg := C.GoString(errMsg)
|
|
|
|
C.libsql_free_string(errMsg)
|
|
|
|
return fmt.Errorf("%s\nerror code = %d: %v", message, code, msg)
|
|
|
|
} else {
|
|
|
|
return fmt.Errorf("%s\nerror code = %d", message, code)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-01-20 11:23:08 +01:00
|
|
|
func libsqlOpenLocal(dataSourceName string) (C.libsql_database_t, error) {
|
2023-06-30 09:59:25 +03:00
|
|
|
connectionString := C.CString(dataSourceName)
|
2023-07-22 14:35:37 +02:00
|
|
|
defer C.free(unsafe.Pointer(connectionString))
|
|
|
|
|
2023-08-04 15:20:27 +02:00
|
|
|
var db C.libsql_database_t
|
|
|
|
var errMsg *C.char
|
2024-01-20 11:23:08 +01:00
|
|
|
statusCode := C.libsql_open_file(connectionString, &db, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return nil, libsqlError(fmt.Sprint("failed to open local database ", dataSourceName), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
return db, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func libsqlOpenRemote(url, authToken string) (C.libsql_database_t, error) {
|
|
|
|
connectionString := C.CString(url)
|
|
|
|
defer C.free(unsafe.Pointer(connectionString))
|
|
|
|
authTokenNativeString := C.CString(authToken)
|
|
|
|
defer C.free(unsafe.Pointer(authTokenNativeString))
|
|
|
|
|
|
|
|
var db C.libsql_database_t
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_open_remote(connectionString, authTokenNativeString, &db, &errMsg)
|
2023-08-04 15:20:27 +02:00
|
|
|
if statusCode != 0 {
|
2024-01-20 11:23:08 +01:00
|
|
|
return nil, libsqlError(fmt.Sprint("failed to open remote database ", url), statusCode, errMsg)
|
2023-07-22 14:35:37 +02:00
|
|
|
}
|
|
|
|
return db, nil
|
|
|
|
}
|
|
|
|
|
2023-09-02 15:22:15 +02:00
|
|
|
func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) {
|
2023-08-09 15:37:25 +02:00
|
|
|
dbPathNativeString := C.CString(dbPath)
|
|
|
|
defer C.free(unsafe.Pointer(dbPathNativeString))
|
|
|
|
primaryUrlNativeString := C.CString(primaryUrl)
|
|
|
|
defer C.free(unsafe.Pointer(primaryUrlNativeString))
|
2023-09-02 15:22:15 +02:00
|
|
|
authTokenNativeString := C.CString(authToken)
|
|
|
|
defer C.free(unsafe.Pointer(authTokenNativeString))
|
2023-08-09 15:37:25 +02:00
|
|
|
|
|
|
|
var db C.libsql_database_t
|
|
|
|
var errMsg *C.char
|
2023-09-02 15:22:15 +02:00
|
|
|
statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, authTokenNativeString, &db, &errMsg)
|
2023-08-09 15:37:25 +02:00
|
|
|
if statusCode != 0 {
|
|
|
|
return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
return db, nil
|
|
|
|
}
|
|
|
|
|
2023-07-22 14:35:37 +02:00
|
|
|
func libsqlConnect(db C.libsql_database_t) (C.libsql_connection_t, error) {
|
2023-08-05 11:01:30 +02:00
|
|
|
var conn C.libsql_connection_t
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_connect(db, &conn, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return nil, libsqlError("failed to connect to database", statusCode, errMsg)
|
2023-07-22 14:35:37 +02:00
|
|
|
}
|
|
|
|
return conn, nil
|
|
|
|
}
|
2023-06-30 09:59:25 +03:00
|
|
|
|
2023-08-08 16:24:00 +02:00
|
|
|
type conn struct {
|
|
|
|
nativePtr C.libsql_connection_t
|
2023-07-22 14:35:37 +02:00
|
|
|
}
|
|
|
|
|
2023-08-08 16:24:00 +02:00
|
|
|
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
|
|
|
|
return c.PrepareContext(context.Background(), query)
|
2023-06-30 10:50:43 +03:00
|
|
|
}
|
|
|
|
|
2023-08-08 16:24:00 +02:00
|
|
|
func (c *conn) Begin() (sqldriver.Tx, error) {
|
|
|
|
return c.BeginTx(context.Background(), sqldriver.TxOptions{})
|
2023-06-30 10:50:43 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
func (c *conn) Close() error {
|
2023-08-08 16:24:00 +02:00
|
|
|
C.libsql_disconnect(c.nativePtr)
|
|
|
|
return nil
|
2023-06-30 10:50:43 +03:00
|
|
|
}
|
|
|
|
|
2024-01-29 15:27:42 +01:00
|
|
|
// ParamsInfo summarizes the bind parameters found in a single SQL statement.
type ParamsInfo struct {
	// NamedParameters holds the distinct named parameters, without their
	// ':', '@' or '$' prefix.
	NamedParameters []string
	// PositionalParametersCount is the number of anonymous "?" parameters.
	PositionalParametersCount int
}
|
|
|
|
|
|
|
|
// positionalParamRe matches a bind-parameter token that begins with '?' and
// captures the digits immediately following it. It is compiled once at package
// initialization instead of on every call.
var positionalParamRe = regexp.MustCompile(`^\?([0-9]*)`)

// isPositionalParameter reports whether the bind-parameter token param is a
// positional parameter, i.e. it starts with '?'.
//
// An anonymous "?" returns (true, nil). A numbered "?NNN" returns true together
// with an error, because this driver binds arguments by position only. Any
// other token returns (false, nil).
func isPositionalParameter(param string) (ok bool, err error) {
	match := positionalParamRe.FindStringSubmatch(param)
	if match == nil {
		return false, nil
	}
	if match[1] == "" {
		return true, nil
	}
	return true, fmt.Errorf("unsupported positional parameter. This driver does not accept positional parameters with indexes (like ?<number>)")
}
|
|
|
|
|
|
|
|
// removeParamPrefix strips the leading ':', '@' or '$' from a named
// bind-parameter token. It returns an error for any other first character, and
// also for an empty string. The empty case previously panicked on index 0.
func removeParamPrefix(paramName string) (string, error) {
	if paramName != "" {
		switch paramName[0] {
		case ':', '@', '$':
			return paramName[1:], nil
		}
	}
	return "", fmt.Errorf("all named parameters must start with ':', or '@' or '$'")
}
|
|
|
|
|
|
|
|
func extractParameters(stmt string) (nameParams []string, positionalParamsCount int, err error) {
|
|
|
|
statementStream := antlr.NewInputStream(stmt)
|
|
|
|
sqliteparser.NewSQLiteLexer(statementStream)
|
|
|
|
lexer := sqliteparser.NewSQLiteLexer(statementStream)
|
|
|
|
|
|
|
|
allTokens := lexer.GetAllTokens()
|
|
|
|
|
|
|
|
nameParamsSet := make(map[string]bool)
|
|
|
|
|
|
|
|
for _, token := range allTokens {
|
|
|
|
tokenType := token.GetTokenType()
|
|
|
|
if tokenType == sqliteparser.SQLiteLexerBIND_PARAMETER {
|
|
|
|
parameter := token.GetText()
|
|
|
|
|
|
|
|
isPositionalParameter, err := isPositionalParameter(parameter)
|
|
|
|
if err != nil {
|
|
|
|
return []string{}, 0, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if isPositionalParameter {
|
|
|
|
positionalParamsCount++
|
|
|
|
} else {
|
|
|
|
paramWithoutPrefix, err := removeParamPrefix(parameter)
|
|
|
|
if err != nil {
|
|
|
|
return []string{}, 0, err
|
|
|
|
} else {
|
|
|
|
nameParamsSet[paramWithoutPrefix] = true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
nameParams = make([]string, 0, len(nameParamsSet))
|
|
|
|
for k := range nameParamsSet {
|
|
|
|
nameParams = append(nameParams, k)
|
|
|
|
}
|
|
|
|
|
|
|
|
return nameParams, positionalParamsCount, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func parseStatement(sql string) ([]string, []ParamsInfo, error) {
|
|
|
|
stmts, _ := sqliteparserutils.SplitStatement(sql)
|
|
|
|
|
|
|
|
stmtsParams := make([]ParamsInfo, len(stmts))
|
|
|
|
for idx, stmt := range stmts {
|
|
|
|
nameParams, positionalParamsCount, err := extractParameters(stmt)
|
|
|
|
if err != nil {
|
|
|
|
return nil, nil, err
|
|
|
|
}
|
|
|
|
stmtsParams[idx] = ParamsInfo{nameParams, positionalParamsCount}
|
|
|
|
}
|
|
|
|
return stmts, stmtsParams, nil
|
|
|
|
}
|
|
|
|
|
2023-08-08 16:24:00 +02:00
|
|
|
func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
|
2024-01-29 15:27:42 +01:00
|
|
|
stmts, paramInfos, err := parseStatement(query)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if len(stmts) != 1 {
|
|
|
|
return nil, fmt.Errorf("only one statement is supported got %d", len(stmts))
|
|
|
|
}
|
|
|
|
numInput := -1
|
|
|
|
if len(paramInfos[0].NamedParameters) == 0 {
|
|
|
|
numInput = paramInfos[0].PositionalParametersCount
|
|
|
|
}
|
|
|
|
return &stmt{c, query, numInput}, nil
|
2023-06-30 10:50:43 +03:00
|
|
|
}
|
|
|
|
|
2023-08-08 16:24:00 +02:00
|
|
|
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
|
2024-01-25 14:44:37 +01:00
|
|
|
if opts.ReadOnly {
|
|
|
|
return nil, fmt.Errorf("read only transactions are not supported")
|
|
|
|
}
|
|
|
|
if opts.Isolation != sqldriver.IsolationLevel(sql.LevelDefault) {
|
|
|
|
return nil, fmt.Errorf("isolation level %d is not supported", opts.Isolation)
|
|
|
|
}
|
|
|
|
_, err := c.ExecContext(ctx, "BEGIN", nil)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return &tx{c}, nil
|
2023-06-30 09:59:25 +03:00
|
|
|
}
|
2023-07-31 16:58:49 +02:00
|
|
|
|
2024-01-25 11:05:40 +01:00
|
|
|
func (c *conn) executeNoArgs(query string) (C.libsql_rows_t, error) {
|
2023-07-31 16:58:49 +02:00
|
|
|
queryCString := C.CString(query)
|
|
|
|
defer C.free(unsafe.Pointer(queryCString))
|
|
|
|
|
2023-08-05 11:30:28 +02:00
|
|
|
var rows C.libsql_rows_t
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_execute(c.nativePtr, queryCString, &rows, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
return rows, nil
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|
|
|
|
|
2024-01-25 11:05:40 +01:00
|
|
|
func (c *conn) execute(query string, args []sqldriver.NamedValue) (C.libsql_rows_t, error) {
|
|
|
|
if len(args) == 0 {
|
|
|
|
return c.executeNoArgs(query)
|
|
|
|
}
|
|
|
|
queryCString := C.CString(query)
|
|
|
|
defer C.free(unsafe.Pointer(queryCString))
|
|
|
|
|
|
|
|
var stmt C.libsql_stmt_t
|
|
|
|
var errMsg *C.char
|
2024-01-29 14:21:15 +01:00
|
|
|
statusCode := C.libsql_prepare(c.nativePtr, queryCString, &stmt, &errMsg)
|
2024-01-25 11:05:40 +01:00
|
|
|
if statusCode != 0 {
|
|
|
|
return nil, libsqlError(fmt.Sprint("failed to prepare query ", query), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
defer C.libsql_free_stmt(stmt)
|
|
|
|
|
|
|
|
for _, arg := range args {
|
|
|
|
var errMsg *C.char
|
|
|
|
var statusCode C.int
|
|
|
|
idx := arg.Ordinal
|
|
|
|
switch arg.Value.(type) {
|
|
|
|
case int64:
|
|
|
|
statusCode = C.libsql_bind_int(stmt, C.int(idx), C.longlong(arg.Value.(int64)), &errMsg)
|
|
|
|
case float64:
|
|
|
|
statusCode = C.libsql_bind_float(stmt, C.int(idx), C.double(arg.Value.(float64)), &errMsg)
|
|
|
|
case []byte:
|
|
|
|
blob := arg.Value.([]byte)
|
|
|
|
nativeBlob := C.CBytes(blob)
|
|
|
|
statusCode = C.libsql_bind_blob(stmt, C.int(idx), (*C.uchar)(nativeBlob), C.int(len(blob)), &errMsg)
|
|
|
|
C.free(nativeBlob)
|
|
|
|
case string:
|
|
|
|
valueStr := C.CString(arg.Value.(string))
|
|
|
|
statusCode = C.libsql_bind_string(stmt, C.int(idx), valueStr, &errMsg)
|
|
|
|
C.free(unsafe.Pointer(valueStr))
|
|
|
|
case nil:
|
|
|
|
statusCode = C.libsql_bind_null(stmt, C.int(idx), &errMsg)
|
|
|
|
default:
|
|
|
|
return nil, fmt.Errorf("unsupported type %T", arg.Value)
|
|
|
|
}
|
|
|
|
if statusCode != 0 {
|
|
|
|
return nil, libsqlError(fmt.Sprintf("failed to bind argument no. %d with value %v and type %T", idx, arg.Value, arg.Value), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
var rows C.libsql_rows_t
|
2024-01-29 14:21:15 +01:00
|
|
|
statusCode = C.libsql_execute_stmt(stmt, &rows, &errMsg)
|
2024-01-25 11:05:40 +01:00
|
|
|
if statusCode != 0 {
|
|
|
|
return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
return rows, nil
|
|
|
|
}
|
|
|
|
|
2024-01-25 14:44:37 +01:00
|
|
|
// execResult implements database/sql/driver.Result with values captured from
// the connection right after a statement runs.
type execResult struct {
	id      int64 // rowid reported by libsql_last_insert_rowid
	changes int64 // row count reported by libsql_changes
}
|
|
|
|
|
|
|
|
func (r execResult) LastInsertId() (int64, error) {
|
|
|
|
return r.id, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (r execResult) RowsAffected() (int64, error) {
|
|
|
|
return r.changes, nil
|
|
|
|
}
|
|
|
|
|
2023-08-08 15:51:01 +02:00
|
|
|
func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
|
2024-01-25 11:05:40 +01:00
|
|
|
rows, err := c.execute(query, args)
|
2023-08-05 11:30:28 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2024-01-25 14:44:37 +01:00
|
|
|
id := int64(C.libsql_last_insert_rowid(c.nativePtr))
|
|
|
|
changes := int64(C.libsql_changes(c.nativePtr))
|
2023-07-31 16:58:49 +02:00
|
|
|
if rows != nil {
|
|
|
|
C.libsql_free_rows(rows)
|
|
|
|
}
|
2024-01-25 14:44:37 +01:00
|
|
|
return execResult{id, changes}, nil
|
|
|
|
}
|
|
|
|
|
2024-01-29 15:27:42 +01:00
|
|
|
type stmt struct {
|
|
|
|
conn *conn
|
|
|
|
sql string
|
|
|
|
numInput int
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *stmt) Close() error {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *stmt) NumInput() int {
|
|
|
|
return s.numInput
|
|
|
|
}
|
|
|
|
|
|
|
|
// convertToNamed adapts positional driver values to NamedValue form for the
// *Context methods. An empty or nil args yields nil.
//
// Ordinals are 1-based. database/sql/driver.NamedValue requires this, and it is
// what database/sql itself passes to ExecContext/QueryContext. The binding code
// forwards Ordinal unchanged as the bind index. The 0-based ordinals produced
// before this fix put every argument one slot too low.
func convertToNamed(args []sqldriver.Value) []sqldriver.NamedValue {
	if len(args) == 0 {
		return nil
	}
	result := make([]sqldriver.NamedValue, len(args))
	for i, v := range args {
		result[i] = sqldriver.NamedValue{Ordinal: i + 1, Value: v}
	}
	return result
}
|
|
|
|
|
|
|
|
func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
|
|
|
|
return s.ExecContext(context.Background(), convertToNamed(args))
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
|
|
|
|
return s.QueryContext(context.Background(), convertToNamed(args))
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) {
|
|
|
|
return s.conn.ExecContext(ctx, s.sql, args)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
|
|
|
|
return s.conn.QueryContext(ctx, s.sql, args)
|
|
|
|
}
|
|
|
|
|
2024-01-25 14:44:37 +01:00
|
|
|
type tx struct {
|
|
|
|
conn *conn
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t tx) Commit() error {
|
|
|
|
_, err := t.conn.ExecContext(context.Background(), "COMMIT", nil)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t tx) Rollback() error {
|
|
|
|
_, err := t.conn.ExecContext(context.Background(), "ROLLBACK", nil)
|
|
|
|
return err
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
// Column type codes as returned by libsql_column_type and dispatched on in
// rows.Next. The exported ALL_CAPS names are kept for API compatibility.
const (
	TYPE_INT   int = 1
	TYPE_FLOAT int = 2
	TYPE_TEXT  int = 3
	TYPE_BLOB  int = 4
	TYPE_NULL  int = 5
)
|
|
|
|
|
2023-08-05 11:55:52 +02:00
|
|
|
func newRows(nativePtr C.libsql_rows_t) (*rows, error) {
|
2023-08-05 11:30:28 +02:00
|
|
|
if nativePtr == nil {
|
2024-02-08 16:39:39 +01:00
|
|
|
return &rows{nil, nil}, nil
|
2023-08-05 11:30:28 +02:00
|
|
|
}
|
2023-07-31 16:58:49 +02:00
|
|
|
columnCount := int(C.libsql_column_count(nativePtr))
|
2024-01-20 11:23:08 +01:00
|
|
|
columns := make([]string, columnCount)
|
2023-07-31 16:58:49 +02:00
|
|
|
for i := 0; i < columnCount; i++ {
|
2023-08-05 11:55:52 +02:00
|
|
|
var ptr *C.char
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_column_name(nativePtr, C.int(i), &ptr, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return nil, libsqlError(fmt.Sprint("failed to get column name for index ", i), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
columns[i] = C.GoString(ptr)
|
|
|
|
C.libsql_free_string(ptr)
|
|
|
|
}
|
2024-02-08 16:39:39 +01:00
|
|
|
return &rows{nativePtr, columns}, nil
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
type rows struct {
|
|
|
|
nativePtr C.libsql_rows_t
|
2023-08-05 11:55:52 +02:00
|
|
|
columnNames []string
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func (r *rows) Columns() []string {
|
2023-08-05 11:55:52 +02:00
|
|
|
return r.columnNames
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func (r *rows) Close() error {
|
|
|
|
if r.nativePtr != nil {
|
|
|
|
C.libsql_free_rows(r.nativePtr)
|
|
|
|
r.nativePtr = nil
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-08-08 15:51:01 +02:00
|
|
|
func (r *rows) Next(dest []sqldriver.Value) error {
|
2023-07-31 16:58:49 +02:00
|
|
|
if r.nativePtr == nil {
|
|
|
|
return io.EOF
|
|
|
|
}
|
2023-08-05 15:10:14 +02:00
|
|
|
var row C.libsql_row_t
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_next_row(r.nativePtr, &row, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return libsqlError("failed to get next row", statusCode, errMsg)
|
|
|
|
}
|
2023-07-31 16:58:49 +02:00
|
|
|
if row == nil {
|
|
|
|
r.Close()
|
|
|
|
return io.EOF
|
|
|
|
}
|
|
|
|
defer C.libsql_free_row(row)
|
|
|
|
count := len(dest)
|
2024-01-20 11:23:08 +01:00
|
|
|
if count > len(r.columnNames) {
|
|
|
|
count = len(r.columnNames)
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|
|
|
|
for i := 0; i < count; i++ {
|
2024-02-08 16:39:39 +01:00
|
|
|
var columnType C.int
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_column_type(r.nativePtr, row, C.int(i), &columnType, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return libsqlError(fmt.Sprint("failed to get column type for index ", i), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
|
|
|
|
switch int(columnType) {
|
2023-07-31 16:58:49 +02:00
|
|
|
case TYPE_NULL:
|
|
|
|
dest[i] = nil
|
|
|
|
case TYPE_INT:
|
2023-08-05 15:36:32 +02:00
|
|
|
var value C.longlong
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_get_int(row, C.int(i), &value, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return libsqlError(fmt.Sprint("failed to get integer for column ", i), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
dest[i] = int64(value)
|
2023-07-31 16:58:49 +02:00
|
|
|
case TYPE_FLOAT:
|
2023-08-05 15:42:15 +02:00
|
|
|
var value C.double
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_get_float(row, C.int(i), &value, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return libsqlError(fmt.Sprint("failed to get float for column ", i), statusCode, errMsg)
|
|
|
|
}
|
|
|
|
dest[i] = float64(value)
|
2023-07-31 16:58:49 +02:00
|
|
|
case TYPE_BLOB:
|
2023-08-05 15:49:04 +02:00
|
|
|
var nativeBlob C.blob
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_get_blob(row, C.int(i), &nativeBlob, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return libsqlError(fmt.Sprint("failed to get blob for column ", i), statusCode, errMsg)
|
|
|
|
}
|
2023-07-31 16:58:49 +02:00
|
|
|
dest[i] = C.GoBytes(unsafe.Pointer(nativeBlob.ptr), C.int(nativeBlob.len))
|
|
|
|
C.libsql_free_blob(nativeBlob)
|
|
|
|
case TYPE_TEXT:
|
2023-08-05 15:27:53 +02:00
|
|
|
var ptr *C.char
|
|
|
|
var errMsg *C.char
|
|
|
|
statusCode := C.libsql_get_string(row, C.int(i), &ptr, &errMsg)
|
|
|
|
if statusCode != 0 {
|
|
|
|
return libsqlError(fmt.Sprint("failed to get string for column ", i), statusCode, errMsg)
|
|
|
|
}
|
2024-01-30 11:57:01 +01:00
|
|
|
str := C.GoString(ptr)
|
2023-07-31 16:58:49 +02:00
|
|
|
C.libsql_free_string(ptr)
|
2024-01-30 11:57:01 +01:00
|
|
|
for _, format := range []string{
|
|
|
|
"2006-01-02 15:04:05.999999999-07:00",
|
|
|
|
"2006-01-02T15:04:05.999999999-07:00",
|
|
|
|
"2006-01-02 15:04:05.999999999",
|
|
|
|
"2006-01-02T15:04:05.999999999",
|
|
|
|
"2006-01-02 15:04:05",
|
|
|
|
"2006-01-02T15:04:05",
|
|
|
|
"2006-01-02 15:04",
|
|
|
|
"2006-01-02T15:04",
|
|
|
|
"2006-01-02",
|
|
|
|
} {
|
|
|
|
if t, err := time.ParseInLocation(format, str, time.UTC); err == nil {
|
|
|
|
dest[i] = t
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
dest[i] = str
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-08-08 15:51:01 +02:00
|
|
|
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
|
2024-01-25 11:05:40 +01:00
|
|
|
rowsNativePtr, err := c.execute(query, args)
|
2023-08-05 11:30:28 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|
2023-08-05 11:55:52 +02:00
|
|
|
return newRows(rowsNativePtr)
|
2023-07-31 16:58:49 +02:00
|
|
|
}
|