0
0
mirror of https://github.com/tursodatabase/libsql.git synced 2025-06-21 13:23:59 +00:00
Files
libsql/bindings/go/libsql.go

653 lines
17 KiB
Go
Raw Permalink Normal View History

2023-06-30 10:50:43 +03:00
//go:build cgo
// +build cgo
2023-06-30 09:59:25 +03:00
package libsql
/*
#cgo CFLAGS: -I../c/include
libsql_server,bottomless: add encryption support (#928) * namespace,replication: add LogFile encryption Anything that uses our LogFile format can now be encrypted on-disk. Tested locally by seeing that `wallog` file contains garbage and no sensible plaintext strings can be extracted from it. * test fixups * libsql-ffi: add libsql_generate_initial_vector and... ... libsql_generate_aes256_key to make them reachable from Rust. * connection: expose additional encryption symbols * libsql-server: derive aes256 from user passphrase properly And by properly, I mean calling back to SQLite3MultipleCiphers' code. * replication: rename Encryptor to FrameEncryptor Encryptor sounds a little too generic for this specific use case. * replication: add snapshot encryption It uses the same mechanism as wallog encryption, now abstracted away to libsql-replication crate to be reused. * replication: add an encryption feature for compilation * cargo fmt pass * fix remaining SnapshotFile::open calls in tests * logger: add an encryption test * replication: use a single buffer for encryption Ideally we could even encrypt in place, but WalPage is also used in snapshots and it's buffered, and that makes it exceptionally annoying to explain to the borrow checker. * bottomless: restore with libsql_replication::injector ... instead of the transaction page cache. That gives us free encryption, since the injector is encryption-aware. This patch doesn't hook encryption_key parameter yet, it will come in the next patch. * bottomless: pass the encryption key in options For WAL restoration, but also to be able to encrypt data that gets sent to S3. * bottomless: inherit encryption key from db config if not specified * libsql-sys: add db_change_counter() The helper function calls the underlying C API to extract 4 bytes from offset 24 of the database header and return it. It's the database change counter, which we can use to compare two databases and decide which one is newer than the other. 
* bottomless: use sqlite API to read database metadata With encryption enabled, we can no longer just go ahead and read data from given offsets, we must go through the VFS layer instead. Fortunately, we can just open a database connection and ask for all the metadata we need. * libsql-sys: make db change counter actually read from the db file * bottomless: treat change counter == 1 as a new database ... which it is, after setting the journal mode. Otherwise we decide too eagerly that the local database is the source of truth. * libsql-server: fix a local embedded replica test rebase conflict with encryption * bottomless-cli: allow passing the encryption key * replication: rebase new test to the new api * snapshots: do not try to decrypt headers They are not encrypted, so we shouldn't attempt to decrypt the data. * logger: restore encrypted frames during recovery Instead of decrypting and encrypting back, we just copy encrypted frames as is during the recovery process, saves IO. * compaction: clear unused encryption_key parameter It wasn't used since for compaction we only need headers, which are unencrypted. * replication: switch to FrameBorrowed::new_zeroed Following MarinPostma's suggestion. Co-authored-by: Marin Postma <postma.marin@protonmail.com> * replication: rebase chores, fixing parameters * libsql-replication: use page_mut() to decrypt data in-place * rustfmt * bottomless: use 0 for disabling autocheckpoint ... instead of u32::MAX. Effectively it's similar, but 0 is the correct choice. * rustfmt * libsql-server: make cbc, aes optional for encryption only * post-rebase fixes * libsql-replication: suppress warnings when no encryption * libsql: add encryption support for local databases * libsql: add bytes dependency for encryption * libsql-ffi: build libsqlite3mc without debug symbols Technically it should just depend on cargo build mode, but that's left for a follow-up. * bindings: an attempt to compile bindings with releasemode ... 
partially to save space, but also to make them faster. --------- Co-authored-by: Marin Postma <postma.marin@protonmail.com>
2024-02-09 15:27:39 +01:00
#cgo LDFLAGS: -L../../target/release
#cgo LDFLAGS: -lsql_experimental
2023-10-16 10:15:49 -07:00
#cgo LDFLAGS: -L../../libsql-sqlite3/.libs
#cgo LDFLAGS: -lsqlite3
2023-06-30 10:50:43 +03:00
#cgo LDFLAGS: -lm
2023-06-30 09:59:25 +03:00
#include <libsql.h>
#include <stdlib.h>
2023-06-30 09:59:25 +03:00
*/
import "C"
import (
"context"
2023-06-30 09:59:25 +03:00
"database/sql"
sqldriver "database/sql/driver"
2023-06-30 10:50:43 +03:00
"fmt"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/libsql/sqlite-antlr4-parser/sqliteparser"
"github.com/libsql/sqlite-antlr4-parser/sqliteparserutils"
"io"
"net/url"
"regexp"
"strings"
"time"
"unsafe"
2023-06-30 09:59:25 +03:00
)
func init() {
sql.Register("libsql", driver{})
2023-06-30 09:59:25 +03:00
}
// NewEmbeddedReplicaConnector opens an embedded replica stored at dbPath that
// replicates from primaryUrl, authenticating with authToken. It passes a sync
// interval of zero to openEmbeddedReplicaConnector, so no periodic background
// sync is started.
func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
	const noAutoSync time.Duration = 0
	return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, noAutoSync)
}
// NewEmbeddedReplicaConnectorWithAutoSync opens an embedded replica stored at
// dbPath that replicates from primaryUrl, and forwards syncInterval to
// openEmbeddedReplicaConnector as the period of the background sync.
func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
	connector, err := openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, syncInterval)
	return connector, err
}
// driver is the stateless database/sql driver type that init registers under
// the name "libsql".
type driver struct{}
// Open builds a connector for dbAddress and immediately opens one connection
// from it using a background context.
func (d driver) Open(dbAddress string) (sqldriver.Conn, error) {
	connector, openErr := d.OpenConnector(dbAddress)
	if openErr == nil {
		return connector.Connect(context.Background())
	}
	return nil, openErr
}
// OpenConnector picks a connector implementation based on dbAddress.
//
// Addresses starting with ":memory:" and "file:" URLs are opened as local
// databases. "http", "https" and "libsql" URLs are opened as remote databases;
// the "authToken" query parameter is extracted and the whole query string is
// removed from the URL before it is handed to the remote opener. Any other
// scheme is rejected.
func (d driver) OpenConnector(dbAddress string) (sqldriver.Connector, error) {
	// Handled before URL parsing: the in-memory address is passed through as is.
	if strings.HasPrefix(dbAddress, ":memory:") {
		return openLocalConnector(dbAddress)
	}

	parsed, parseErr := url.Parse(dbAddress)
	if parseErr != nil {
		return nil, parseErr
	}

	switch parsed.Scheme {
	case "file":
		return openLocalConnector(dbAddress)
	case "http", "https", "libsql":
		token := parsed.Query().Get("authToken")
		parsed.RawQuery = ""
		return openRemoteConnector(parsed.String(), token)
	}
	return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file:, https:// or http://", parsed.Scheme)
}
// libsqlSync calls libsql_sync on the given native database handle and
// converts a non-zero status code into a Go error.
func libsqlSync(nativeDbPtr C.libsql_database_t) error {
	var errMsg *C.char
	if status := C.libsql_sync(nativeDbPtr, &errMsg); status != 0 {
		return libsqlError("failed to sync database ", status, errMsg)
	}
	return nil
}
// openLocalConnector opens the local database at dbPath and wraps the
// resulting native handle in a Connector without any background sync.
func openLocalConnector(dbPath string) (*Connector, error) {
	db, err := libsqlOpenLocal(dbPath)
	if err == nil {
		return &Connector{nativeDbPtr: db}, nil
	}
	return nil, err
}
// openRemoteConnector opens the remote database at primaryUrl using authToken
// and wraps the resulting native handle in a Connector without any background
// sync.
func openRemoteConnector(primaryUrl, authToken string) (*Connector, error) {
	db, err := libsqlOpenRemote(primaryUrl, authToken)
	if err == nil {
		return &Connector{nativeDbPtr: db}, nil
	}
	return nil, err
}
// openEmbeddedReplicaConnector opens an embedded replica at dbPath that
// replicates from primaryUrl, performs one initial sync, and returns a
// Connector for it.
//
// If the initial sync fails the native database handle is closed and the error
// is returned. When syncInterval is non-zero a background goroutine re-syncs
// the database every syncInterval (measured from the end of the previous sync)
// until Connector.Close signals it to stop; sync errors in that goroutine are
// printed to standard output and otherwise ignored. When syncInterval is zero
// no goroutine is started and the returned Connector has nil close channels.
func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
	nativeDbPtr, err := libsqlOpenWithSync(dbPath, primaryUrl, authToken)
	if err != nil {
		return nil, err
	}
	if err := libsqlSync(nativeDbPtr); err != nil {
		C.libsql_close(nativeDbPtr)
		return nil, err
	}
	if syncInterval == 0 {
		return &Connector{nativeDbPtr: nativeDbPtr}, nil
	}

	// Both channels are buffered so neither side blocks on the handshake.
	closeCh := make(chan struct{}, 1)
	closeAckCh := make(chan struct{}, 1)
	go func() {
		// A single reusable timer replaces the per-iteration sleeper goroutine,
		// so nothing is left running after the close request is acknowledged.
		timer := time.NewTimer(syncInterval)
		defer timer.Stop()
		for {
			select {
			case <-closeCh:
				closeAckCh <- struct{}{}
				return
			case <-timer.C:
				if err := libsqlSync(nativeDbPtr); err != nil {
					fmt.Println(err)
				}
				// The timer channel was just drained, so Reset is safe here.
				timer.Reset(syncInterval)
			}
		}
	}()
	return &Connector{nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil
}
// Connector holds a native libsql database handle and hands out connections
// to it.
type Connector struct {
	// nativeDbPtr is the native database handle; Close sets it to nil.
	nativeDbPtr C.libsql_database_t
	// closeCh, when non-nil, is used by Close to ask the background sync
	// goroutine to stop.
	closeCh chan<- struct{}
	// closeAckCh is where Close waits for the background sync goroutine to
	// acknowledge the stop request.
	closeAckCh <-chan struct{}
}
// Sync runs libsqlSync on the connector's native database handle and returns
// its error, if any.
func (c *Connector) Sync() error {
	db := c.nativeDbPtr
	return libsqlSync(db)
}
// Close stops the background sync goroutine (if one was started), waits for
// its acknowledgement, and then closes the native database handle. It always
// returns nil, and the cleared fields make a repeated call a no-op.
func (c *Connector) Close() error {
	if stop := c.closeCh; stop != nil {
		stop <- struct{}{}
		<-c.closeAckCh
		c.closeCh, c.closeAckCh = nil, nil
	}
	if db := c.nativeDbPtr; db != nil {
		C.libsql_close(db)
	}
	c.nativeDbPtr = nil
	return nil
}
// Connect opens a new native connection on the connector's database handle
// and wraps it in a conn. The ctx argument is not consulted.
func (c *Connector) Connect(ctx context.Context) (sqldriver.Conn, error) {
	native, err := libsqlConnect(c.nativeDbPtr)
	if err == nil {
		return &conn{nativePtr: native}, nil
	}
	return nil, err
}
// Driver returns a fresh value of this package's driver type.
func (c *Connector) Driver() sqldriver.Driver {
	var d driver
	return d
}
2023-06-30 09:59:25 +03:00
// libsqlError builds a Go error from a message, a native status code and an
// optional native error string. When errMsg is non-nil its text is appended to
// the error and the native string is released with libsql_free_string.
func libsqlError(message string, statusCode C.int, errMsg *C.char) error {
	code := int(statusCode)
	if errMsg == nil {
		return fmt.Errorf("%s\nerror code = %d", message, code)
	}
	detail := C.GoString(errMsg)
	C.libsql_free_string(errMsg)
	return fmt.Errorf("%s\nerror code = %d: %v", message, code, detail)
}
func libsqlOpenLocal(dataSourceName string) (C.libsql_database_t, error) {
2023-06-30 09:59:25 +03:00
connectionString := C.CString(dataSourceName)
defer C.free(unsafe.Pointer(connectionString))
var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_file(connectionString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to open local database ", dataSourceName), statusCode, errMsg)
}
return db, nil
}
// libsqlOpenRemote opens a remote database through libsql_open_remote using
// the given URL and auth token, and returns the native database handle or an
// error describing the failure.
func libsqlOpenRemote(url, authToken string) (C.libsql_database_t, error) {
	cURL := C.CString(url)
	defer C.free(unsafe.Pointer(cURL))
	cToken := C.CString(authToken)
	defer C.free(unsafe.Pointer(cToken))

	var (
		db     C.libsql_database_t
		errMsg *C.char
	)
	if status := C.libsql_open_remote(cURL, cToken, &db, &errMsg); status != 0 {
		return nil, libsqlError(fmt.Sprint("failed to open remote database ", url), status, errMsg)
	}
	return db, nil
}
// libsqlOpenWithSync opens a database through libsql_open_sync from a local
// path, a primary URL and an auth token, and returns the native database
// handle or an error describing the failure.
func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) {
	cPath := C.CString(dbPath)
	defer C.free(unsafe.Pointer(cPath))
	cPrimary := C.CString(primaryUrl)
	defer C.free(unsafe.Pointer(cPrimary))
	cToken := C.CString(authToken)
	defer C.free(unsafe.Pointer(cToken))

	var (
		db     C.libsql_database_t
		errMsg *C.char
	)
	if status := C.libsql_open_sync(cPath, cPrimary, cToken, &db, &errMsg); status != 0 {
		return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), status, errMsg)
	}
	return db, nil
}
// libsqlConnect opens a native connection on db through libsql_connect and
// returns the connection handle or an error describing the failure.
func libsqlConnect(db C.libsql_database_t) (C.libsql_connection_t, error) {
	var (
		native C.libsql_connection_t
		errMsg *C.char
	)
	if status := C.libsql_connect(db, &native, &errMsg); status != 0 {
		return nil, libsqlError("failed to connect to database", status, errMsg)
	}
	return native, nil
}
2023-06-30 09:59:25 +03:00
// conn is a single database connection backed by a native libsql connection
// handle.
type conn struct {
	// nativePtr is the native connection handle passed to the libsql C calls.
	nativePtr C.libsql_connection_t
}
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
2023-06-30 10:50:43 +03:00
}
func (c *conn) Begin() (sqldriver.Tx, error) {
return c.BeginTx(context.Background(), sqldriver.TxOptions{})
2023-06-30 10:50:43 +03:00
}
func (c *conn) Close() error {
C.libsql_disconnect(c.nativePtr)
return nil
2023-06-30 10:50:43 +03:00
}
// ParamsInfo describes the bind parameters found in one SQL statement.
type ParamsInfo struct {
	// NamedParameters lists the distinct named parameters with their prefix
	// character removed.
	NamedParameters []string
	// PositionalParametersCount is the number of positional ("?") parameters.
	PositionalParametersCount int
}
// positionalParamPattern matches a bind parameter that begins with '?' and
// captures the digits, if any, that immediately follow it. It is anchored so
// that a '?' appearing later in a named parameter is not mistaken for a
// positional one, and compiled once rather than on every call.
var positionalParamPattern = regexp.MustCompile(`^\?([0-9]*)`)

// isPositionalParameter reports whether param is a positional bind parameter,
// i.e. whether it starts with '?'.
//
// A bare "?" yields (true, nil). A numbered form such as "?3" yields true
// together with an error, because indexed positional parameters are not
// supported by this driver. Anything that does not start with '?' yields
// (false, nil).
func isPositionalParameter(param string) (ok bool, err error) {
	match := positionalParamPattern.FindStringSubmatch(param)
	if match == nil {
		return false, nil
	}
	if match[1] == "" {
		return true, nil
	}
	return true, fmt.Errorf("unsupported positional parameter. This driver does not accept positional parameters with indexes (like ?<number>)")
}
// removeParamPrefix strips the leading ':', '@' or '$' from a named bind
// parameter and returns the remainder. It returns an error when paramName is
// empty or does not start with one of those characters.
func removeParamPrefix(paramName string) (string, error) {
	// Guard the index below: an empty name has no first byte to inspect.
	if paramName != "" {
		switch paramName[0] {
		case ':', '@', '$':
			return paramName[1:], nil
		}
	}
	return "", fmt.Errorf("all named parameters must start with ':', or '@' or '$'")
}
func extractParameters(stmt string) (nameParams []string, positionalParamsCount int, err error) {
statementStream := antlr.NewInputStream(stmt)
sqliteparser.NewSQLiteLexer(statementStream)
lexer := sqliteparser.NewSQLiteLexer(statementStream)
allTokens := lexer.GetAllTokens()
nameParamsSet := make(map[string]bool)
for _, token := range allTokens {
tokenType := token.GetTokenType()
if tokenType == sqliteparser.SQLiteLexerBIND_PARAMETER {
parameter := token.GetText()
isPositionalParameter, err := isPositionalParameter(parameter)
if err != nil {
return []string{}, 0, err
}
if isPositionalParameter {
positionalParamsCount++
} else {
paramWithoutPrefix, err := removeParamPrefix(parameter)
if err != nil {
return []string{}, 0, err
} else {
nameParamsSet[paramWithoutPrefix] = true
}
}
}
}
nameParams = make([]string, 0, len(nameParamsSet))
for k := range nameParamsSet {
nameParams = append(nameParams, k)
}
return nameParams, positionalParamsCount, nil
}
// parseStatement splits sql into individual statements and extracts the bind
// parameter information of each one. The two returned slices are index
// aligned. The second result of SplitStatement is discarded.
func parseStatement(sql string) ([]string, []ParamsInfo, error) {
	stmts, _ := sqliteparserutils.SplitStatement(sql)
	infos := make([]ParamsInfo, 0, len(stmts))
	for _, single := range stmts {
		named, positional, err := extractParameters(single)
		if err != nil {
			return nil, nil, err
		}
		infos = append(infos, ParamsInfo{
			NamedParameters:           named,
			PositionalParametersCount: positional,
		})
	}
	return stmts, infos, nil
}
func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
stmts, paramInfos, err := parseStatement(query)
if err != nil {
return nil, err
}
if len(stmts) != 1 {
return nil, fmt.Errorf("only one statement is supported got %d", len(stmts))
}
numInput := -1
if len(paramInfos[0].NamedParameters) == 0 {
numInput = paramInfos[0].PositionalParametersCount
}
return &stmt{c, query, numInput}, nil
2023-06-30 10:50:43 +03:00
}
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
if opts.ReadOnly {
return nil, fmt.Errorf("read only transactions are not supported")
}
if opts.Isolation != sqldriver.IsolationLevel(sql.LevelDefault) {
return nil, fmt.Errorf("isolation level %d is not supported", opts.Isolation)
}
_, err := c.ExecContext(ctx, "BEGIN", nil)
if err != nil {
return nil, err
}
return &tx{c}, nil
2023-06-30 09:59:25 +03:00
}
// executeNoArgs runs query directly through libsql_execute, without preparing
// a statement or binding arguments, and returns the native rows handle.
func (c *conn) executeNoArgs(query string) (C.libsql_rows_t, error) {
	cQuery := C.CString(query)
	defer C.free(unsafe.Pointer(cQuery))

	var (
		result C.libsql_rows_t
		errMsg *C.char
	)
	if status := C.libsql_execute(c.nativePtr, cQuery, &result, &errMsg); status != 0 {
		return nil, libsqlError(fmt.Sprint("failed to execute query ", query), status, errMsg)
	}
	return result, nil
}
// execute runs query with the given arguments and returns the native rows
// handle.
//
// With no arguments the query is executed directly. Otherwise the query is
// prepared, each argument is bound at the position given by its Ordinal (the
// argument's Name is not consulted), and the prepared statement is executed
// and then freed.
//
// Supported argument types are the driver.Value set: int64, float64, bool,
// []byte, string, time.Time and nil. A bool is bound as the integer 0 or 1. A
// time.Time is bound as text using the first layout that rows.Next recognizes
// when reading text columns back. Any other type yields an error.
func (c *conn) execute(query string, args []sqldriver.NamedValue) (C.libsql_rows_t, error) {
	if len(args) == 0 {
		return c.executeNoArgs(query)
	}
	queryCString := C.CString(query)
	defer C.free(unsafe.Pointer(queryCString))

	var stmt C.libsql_stmt_t
	var errMsg *C.char
	statusCode := C.libsql_prepare(c.nativePtr, queryCString, &stmt, &errMsg)
	if statusCode != 0 {
		return nil, libsqlError(fmt.Sprint("failed to prepare query ", query), statusCode, errMsg)
	}
	defer C.libsql_free_stmt(stmt)

	for _, arg := range args {
		var bindErrMsg *C.char
		var bindStatus C.int
		idx := C.int(arg.Ordinal)
		switch v := arg.Value.(type) {
		case int64:
			bindStatus = C.libsql_bind_int(stmt, idx, C.longlong(v), &bindErrMsg)
		case bool:
			var asInt C.longlong
			if v {
				asInt = 1
			}
			bindStatus = C.libsql_bind_int(stmt, idx, asInt, &bindErrMsg)
		case float64:
			bindStatus = C.libsql_bind_float(stmt, idx, C.double(v), &bindErrMsg)
		case []byte:
			// The blob is copied into C memory for the duration of the call.
			nativeBlob := C.CBytes(v)
			bindStatus = C.libsql_bind_blob(stmt, idx, (*C.uchar)(nativeBlob), C.int(len(v)), &bindErrMsg)
			C.free(nativeBlob)
		case string:
			valueStr := C.CString(v)
			bindStatus = C.libsql_bind_string(stmt, idx, valueStr, &bindErrMsg)
			C.free(unsafe.Pointer(valueStr))
		case time.Time:
			valueStr := C.CString(v.Format("2006-01-02 15:04:05.999999999-07:00"))
			bindStatus = C.libsql_bind_string(stmt, idx, valueStr, &bindErrMsg)
			C.free(unsafe.Pointer(valueStr))
		case nil:
			bindStatus = C.libsql_bind_null(stmt, idx, &bindErrMsg)
		default:
			return nil, fmt.Errorf("unsupported type %T", arg.Value)
		}
		if bindStatus != 0 {
			return nil, libsqlError(fmt.Sprintf("failed to bind argument no. %d with value %v and type %T", arg.Ordinal, arg.Value, arg.Value), bindStatus, bindErrMsg)
		}
	}

	var rows C.libsql_rows_t
	statusCode = C.libsql_execute_stmt(stmt, &rows, &errMsg)
	if statusCode != 0 {
		return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
	}
	return rows, nil
}
// execResult is the result of an executed statement: the last inserted row id
// and the number of changed rows, both captured right after execution.
type execResult struct {
	id      int64
	changes int64
}

// LastInsertId returns the stored row id; the error is always nil.
func (r execResult) LastInsertId() (int64, error) {
	lastID := r.id
	return lastID, nil
}

// RowsAffected returns the stored change count; the error is always nil.
func (r execResult) RowsAffected() (int64, error) {
	affected := r.changes
	return affected, nil
}
// ExecContext executes query with args and reports the last insert row id and
// change count read from the connection afterwards. Any rows handle produced
// by the execution is freed. The ctx argument is not consulted.
func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
	native, err := c.execute(query, args)
	if err != nil {
		return nil, err
	}
	// Read both counters before the rows handle is released.
	result := execResult{
		id:      int64(C.libsql_last_insert_rowid(c.nativePtr)),
		changes: int64(C.libsql_changes(c.nativePtr)),
	}
	if native != nil {
		C.libsql_free_rows(native)
	}
	return result, nil
}
// stmt is a prepared statement handle. It only records the SQL text and the
// owning connection; the query is run through the connection on each
// Exec/Query call.
type stmt struct {
	// conn is the connection the statement executes on.
	conn *conn
	// sql is the statement text as passed to PrepareContext.
	sql string
	// numInput is the value reported by NumInput.
	numInput int
}
// Close is a no-op and always returns nil.
func (s *stmt) Close() error {
	var err error
	return err
}
// NumInput returns the parameter count recorded when the statement was
// prepared.
func (s *stmt) NumInput() int {
	n := s.numInput
	return n
}
// convertToNamed converts positional driver values into NamedValues for the
// context-aware code paths. It returns nil for an empty input.
//
// Ordinals are one-based, matching the driver.NamedValue contract and the
// ordinals database/sql itself supplies; execute passes the ordinal straight
// through as the bind index.
func convertToNamed(args []sqldriver.Value) []sqldriver.NamedValue {
	if len(args) == 0 {
		return nil
	}
	result := make([]sqldriver.NamedValue, len(args))
	for idx, value := range args {
		result[idx] = sqldriver.NamedValue{Ordinal: idx + 1, Value: value}
	}
	return result
}
// Exec is the context-free variant of ExecContext; it converts args to named
// values and delegates with a background context.
func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
	named := convertToNamed(args)
	return s.ExecContext(context.Background(), named)
}
// Query is the context-free variant of QueryContext; it converts args to named
// values and delegates with a background context.
func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
	named := convertToNamed(args)
	return s.QueryContext(context.Background(), named)
}
// ExecContext executes the statement's SQL text on its connection with args.
func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) {
	owner, query := s.conn, s.sql
	return owner.ExecContext(ctx, query, args)
}
// QueryContext runs the statement's SQL text as a query on its connection
// with args.
func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
	owner, query := s.conn, s.sql
	return owner.QueryContext(ctx, query, args)
}
// tx is a transaction handle; Commit and Rollback issue the corresponding SQL
// statement on the connection it wraps.
type tx struct {
	// conn is the connection the transaction was started on.
	conn *conn
}
// Commit executes "COMMIT" on the underlying connection using a background
// context and returns any execution error.
func (t tx) Commit() error {
	if _, err := t.conn.ExecContext(context.Background(), "COMMIT", nil); err != nil {
		return err
	}
	return nil
}
// Rollback executes "ROLLBACK" on the underlying connection using a background
// context and returns any execution error.
func (t tx) Rollback() error {
	if _, err := t.conn.ExecContext(context.Background(), "ROLLBACK", nil); err != nil {
		return err
	}
	return nil
}
// Column type codes. rows.Next compares these against the value reported by
// libsql_column_type to decide how to read each column.
const (
	TYPE_INT   int = 1
	TYPE_FLOAT int = 2
	TYPE_TEXT  int = 3
	TYPE_BLOB  int = 4
	TYPE_NULL  int = 5
)
// newRows wraps a native rows handle in a rows value, eagerly reading all
// column names from it.
//
// A nil handle yields an empty rows value. On success the returned rows owns
// nativePtr and releases it in Close. If reading a column name fails,
// nativePtr is freed here before the error is returned, because the caller
// never receives a rows value it could close.
func newRows(nativePtr C.libsql_rows_t) (*rows, error) {
	if nativePtr == nil {
		return &rows{nil, nil}, nil
	}
	columnCount := int(C.libsql_column_count(nativePtr))
	columns := make([]string, columnCount)
	for i := range columns {
		var ptr *C.char
		var errMsg *C.char
		statusCode := C.libsql_column_name(nativePtr, C.int(i), &ptr, &errMsg)
		if statusCode != 0 {
			C.libsql_free_rows(nativePtr)
			return nil, libsqlError(fmt.Sprint("failed to get column name for index ", i), statusCode, errMsg)
		}
		columns[i] = C.GoString(ptr)
		C.libsql_free_string(ptr)
	}
	return &rows{nativePtr: nativePtr, columnNames: columns}, nil
}
// rows iterates over a native libsql result set.
type rows struct {
	// nativePtr is the native rows handle; nil once closed or when the query
	// produced no result set.
	nativePtr C.libsql_rows_t
	// columnNames holds the column names read when the rows value was created.
	columnNames []string
}
// Columns returns the column names captured when the rows value was created.
func (r *rows) Columns() []string {
	names := r.columnNames
	return names
}
// Close frees the native rows handle if it is still held and clears it, so
// that repeated calls are harmless. It always returns nil.
func (r *rows) Close() error {
	if r.nativePtr == nil {
		return nil
	}
	C.libsql_free_rows(r.nativePtr)
	r.nativePtr = nil
	return nil
}
// textTimeLayouts lists the layouts tried, in order, when deciding whether a
// text column holds a timestamp.
var textTimeLayouts = []string{
	"2006-01-02 15:04:05.999999999-07:00",
	"2006-01-02T15:04:05.999999999-07:00",
	"2006-01-02 15:04:05.999999999",
	"2006-01-02T15:04:05.999999999",
	"2006-01-02 15:04:05",
	"2006-01-02T15:04:05",
	"2006-01-02 15:04",
	"2006-01-02T15:04",
	"2006-01-02",
}

// parseTextTime tries each layout in textTimeLayouts and returns the first
// successful parse; layouts without a zone are interpreted as UTC.
func parseTextTime(str string) (time.Time, bool) {
	for _, layout := range textTimeLayouts {
		if t, err := time.ParseInLocation(layout, str, time.UTC); err == nil {
			return t, true
		}
	}
	return time.Time{}, false
}

// Next fetches the next row and stores its column values in dest.
//
// It returns io.EOF when the rows value has no native handle or when the
// result set is exhausted; in the latter case the rows value is closed. At
// most min(len(dest), number of columns) values are written. NULL columns
// become nil, integers int64, floats float64, blobs a copied []byte, and text
// either a time.Time (when it matches one of the known timestamp layouts) or a
// string. A column whose reported type matches none of the known type codes
// leaves its dest slot untouched.
func (r *rows) Next(dest []sqldriver.Value) error {
	if r.nativePtr == nil {
		return io.EOF
	}
	var row C.libsql_row_t
	var errMsg *C.char
	statusCode := C.libsql_next_row(r.nativePtr, &row, &errMsg)
	if statusCode != 0 {
		return libsqlError("failed to get next row", statusCode, errMsg)
	}
	if row == nil {
		r.Close()
		return io.EOF
	}
	defer C.libsql_free_row(row)

	count := len(dest)
	if count > len(r.columnNames) {
		count = len(r.columnNames)
	}
	for i := 0; i < count; i++ {
		var columnType C.int
		var errMsg *C.char
		statusCode := C.libsql_column_type(r.nativePtr, row, C.int(i), &columnType, &errMsg)
		if statusCode != 0 {
			return libsqlError(fmt.Sprint("failed to get column type for index ", i), statusCode, errMsg)
		}
		switch int(columnType) {
		case TYPE_NULL:
			dest[i] = nil
		case TYPE_INT:
			var value C.longlong
			statusCode := C.libsql_get_int(row, C.int(i), &value, &errMsg)
			if statusCode != 0 {
				return libsqlError(fmt.Sprint("failed to get integer for column ", i), statusCode, errMsg)
			}
			dest[i] = int64(value)
		case TYPE_FLOAT:
			var value C.double
			statusCode := C.libsql_get_float(row, C.int(i), &value, &errMsg)
			if statusCode != 0 {
				return libsqlError(fmt.Sprint("failed to get float for column ", i), statusCode, errMsg)
			}
			dest[i] = float64(value)
		case TYPE_BLOB:
			var nativeBlob C.blob
			statusCode := C.libsql_get_blob(row, C.int(i), &nativeBlob, &errMsg)
			if statusCode != 0 {
				return libsqlError(fmt.Sprint("failed to get blob for column ", i), statusCode, errMsg)
			}
			// Copy into Go memory before the native blob is released.
			dest[i] = C.GoBytes(unsafe.Pointer(nativeBlob.ptr), C.int(nativeBlob.len))
			C.libsql_free_blob(nativeBlob)
		case TYPE_TEXT:
			var ptr *C.char
			statusCode := C.libsql_get_string(row, C.int(i), &ptr, &errMsg)
			if statusCode != 0 {
				return libsqlError(fmt.Sprint("failed to get string for column ", i), statusCode, errMsg)
			}
			str := C.GoString(ptr)
			C.libsql_free_string(ptr)
			// Only this column is decided here; the loop must go on to the
			// remaining columns whether or not the text parsed as a time.
			if t, ok := parseTextTime(str); ok {
				dest[i] = t
			} else {
				dest[i] = str
			}
		}
	}
	return nil
}
// QueryContext executes query with args and wraps the resulting native rows
// handle in a rows value. The ctx argument is not consulted.
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
	native, err := c.execute(query, args)
	if err == nil {
		return newRows(native)
	}
	return nil, err
}