0
0
mirror of https://github.com/tursodatabase/libsql.git synced 2025-05-14 23:52:29 +00:00
Files
libsql/bindings/go/libsql.go
Piotr Jastrzębski f3da77e705 Go bindings: separate query and exec (#1047)
Signed-off-by: Piotr Jastrzebski <piotr@chiselstrike.com>
2024-03-12 12:40:03 +00:00

764 lines
20 KiB
Go

//go:build cgo
// +build cgo
package libsql
/*
#cgo CFLAGS: -I../c/include
#cgo LDFLAGS: -L../../target/release
#cgo LDFLAGS: -lsql_experimental
#cgo LDFLAGS: -lm
#include <libsql.h>
#include <stdlib.h>
*/
import "C"
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"errors"
"fmt"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/libsql/sqlite-antlr4-parser/sqliteparser"
"github.com/libsql/sqlite-antlr4-parser/sqliteparserutils"
"io"
"net/url"
"regexp"
"strings"
"time"
"unsafe"
)
func init() {
sql.Register("libsql", driver{})
}
type config struct {
authToken *string
readYourWrites *bool
encryptionKey *string
syncInterval *time.Duration
}
type Option interface {
apply(*config) error
}
type option func(*config) error
func (o option) apply(c *config) error {
return o(c)
}
func WithAuthToken(authToken string) Option {
return option(func(o *config) error {
if o.authToken != nil {
return fmt.Errorf("authToken already set")
}
if authToken == "" {
return fmt.Errorf("authToken must not be empty")
}
o.authToken = &authToken
return nil
})
}
func WithReadYourWrites(readYourWrites bool) Option {
return option(func(o *config) error {
if o.readYourWrites != nil {
return fmt.Errorf("read your writes already set")
}
o.readYourWrites = &readYourWrites
return nil
})
}
func WithEncryption(key string) Option {
return option(func(o *config) error {
if o.encryptionKey != nil {
return fmt.Errorf("encryption key already set")
}
if key == "" {
return fmt.Errorf("encryption key must not be empty")
}
o.encryptionKey = &key
return nil
})
}
func WithSyncInterval(interval time.Duration) Option {
return option(func(o *config) error {
if o.syncInterval != nil {
return fmt.Errorf("sync interval already set")
}
o.syncInterval = &interval
return nil
})
}
func NewEmbeddedReplicaConnector(dbPath string, primaryUrl string, opts ...Option) (*Connector, error) {
var config config
errs := make([]error, 0, len(opts))
for _, opt := range opts {
if err := opt.apply(&config); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return nil, errors.Join(errs...)
}
authToken := ""
if config.authToken != nil {
authToken = *config.authToken
}
readYourWrites := true
if config.readYourWrites != nil {
readYourWrites = *config.readYourWrites
}
encryptionKey := ""
if config.encryptionKey != nil {
encryptionKey = *config.encryptionKey
}
syncInterval := time.Duration(0)
if config.syncInterval != nil {
syncInterval = *config.syncInterval
}
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, readYourWrites, encryptionKey, syncInterval)
}
type driver struct{}
func (d driver) Open(dbAddress string) (sqldriver.Conn, error) {
connector, err := d.OpenConnector(dbAddress)
if err != nil {
return nil, err
}
return connector.Connect(context.Background())
}
func (d driver) OpenConnector(dbAddress string) (sqldriver.Connector, error) {
if strings.HasPrefix(dbAddress, ":memory:") {
return openLocalConnector(dbAddress)
}
u, err := url.Parse(dbAddress)
if err != nil {
return nil, err
}
switch u.Scheme {
case "file":
return openLocalConnector(dbAddress)
case "http":
fallthrough
case "https":
fallthrough
case "libsql":
authToken := u.Query().Get("authToken")
u.RawQuery = ""
return openRemoteConnector(u.String(), authToken)
}
return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file:, https:// or http://", u.Scheme)
}
func libsqlSync(nativeDbPtr C.libsql_database_t) error {
var errMsg *C.char
statusCode := C.libsql_sync(nativeDbPtr, &errMsg)
if statusCode != 0 {
return libsqlError("failed to sync database ", statusCode, errMsg)
}
return nil
}
func openLocalConnector(dbPath string) (*Connector, error) {
nativeDbPtr, err := libsqlOpenLocal(dbPath)
if err != nil {
return nil, err
}
return &Connector{nativeDbPtr: nativeDbPtr}, nil
}
func openRemoteConnector(primaryUrl, authToken string) (*Connector, error) {
nativeDbPtr, err := libsqlOpenRemote(primaryUrl, authToken)
if err != nil {
return nil, err
}
return &Connector{nativeDbPtr: nativeDbPtr}, nil
}
func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, readYourWrites bool, encryptionKey string, syncInterval time.Duration) (*Connector, error) {
var closeCh chan struct{}
var closeAckCh chan struct{}
nativeDbPtr, err := libsqlOpenWithSync(dbPath, primaryUrl, authToken, readYourWrites, encryptionKey)
if err != nil {
return nil, err
}
if err := libsqlSync(nativeDbPtr); err != nil {
C.libsql_close(nativeDbPtr)
return nil, err
}
if syncInterval != 0 {
closeCh = make(chan struct{}, 1)
closeAckCh = make(chan struct{}, 1)
go func() {
for {
timerCh := make(chan struct{}, 1)
go func() {
time.Sleep(syncInterval)
timerCh <- struct{}{}
}()
select {
case <-closeCh:
closeAckCh <- struct{}{}
return
case <-timerCh:
if err := libsqlSync(nativeDbPtr); err != nil {
fmt.Println(err)
}
}
}
}()
}
if err != nil {
return nil, err
}
return &Connector{nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil
}
type Connector struct {
nativeDbPtr C.libsql_database_t
closeCh chan<- struct{}
closeAckCh <-chan struct{}
}
func (c *Connector) Sync() error {
return libsqlSync(c.nativeDbPtr)
}
func (c *Connector) Close() error {
if c.closeCh != nil {
c.closeCh <- struct{}{}
<-c.closeAckCh
c.closeCh = nil
c.closeAckCh = nil
}
if c.nativeDbPtr != nil {
C.libsql_close(c.nativeDbPtr)
}
c.nativeDbPtr = nil
return nil
}
func (c *Connector) Connect(ctx context.Context) (sqldriver.Conn, error) {
nativeConnPtr, err := libsqlConnect(c.nativeDbPtr)
if err != nil {
return nil, err
}
return &conn{nativePtr: nativeConnPtr}, nil
}
func (c *Connector) Driver() sqldriver.Driver {
return driver{}
}
func libsqlError(message string, statusCode C.int, errMsg *C.char) error {
code := int(statusCode)
if errMsg != nil {
msg := C.GoString(errMsg)
C.libsql_free_string(errMsg)
return fmt.Errorf("%s\nerror code = %d: %v", message, code, msg)
} else {
return fmt.Errorf("%s\nerror code = %d", message, code)
}
}
func libsqlOpenLocal(dataSourceName string) (C.libsql_database_t, error) {
connectionString := C.CString(dataSourceName)
defer C.free(unsafe.Pointer(connectionString))
var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_file(connectionString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to open local database ", dataSourceName), statusCode, errMsg)
}
return db, nil
}
func libsqlOpenRemote(url, authToken string) (C.libsql_database_t, error) {
connectionString := C.CString(url)
defer C.free(unsafe.Pointer(connectionString))
authTokenNativeString := C.CString(authToken)
defer C.free(unsafe.Pointer(authTokenNativeString))
var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_remote(connectionString, authTokenNativeString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to open remote database ", url), statusCode, errMsg)
}
return db, nil
}
func libsqlOpenWithSync(dbPath, primaryUrl, authToken string, readYourWrites bool, encryptionKey string) (C.libsql_database_t, error) {
dbPathNativeString := C.CString(dbPath)
defer C.free(unsafe.Pointer(dbPathNativeString))
primaryUrlNativeString := C.CString(primaryUrl)
defer C.free(unsafe.Pointer(primaryUrlNativeString))
authTokenNativeString := C.CString(authToken)
defer C.free(unsafe.Pointer(authTokenNativeString))
var readYourWritesNative C.char = 0
if readYourWrites {
readYourWritesNative = 1
}
var encrytionKeyNativeString *C.char
if encryptionKey != "" {
encrytionKeyNativeString = C.CString(encryptionKey)
defer C.free(unsafe.Pointer(encrytionKeyNativeString))
}
var db C.libsql_database_t
var errMsg *C.char
statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, authTokenNativeString, readYourWritesNative, encrytionKeyNativeString, &db, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), statusCode, errMsg)
}
return db, nil
}
func libsqlConnect(db C.libsql_database_t) (C.libsql_connection_t, error) {
var conn C.libsql_connection_t
var errMsg *C.char
statusCode := C.libsql_connect(db, &conn, &errMsg)
if statusCode != 0 {
return nil, libsqlError("failed to connect to database", statusCode, errMsg)
}
return conn, nil
}
type conn struct {
nativePtr C.libsql_connection_t
}
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
func (c *conn) Begin() (sqldriver.Tx, error) {
return c.BeginTx(context.Background(), sqldriver.TxOptions{})
}
func (c *conn) Close() error {
C.libsql_disconnect(c.nativePtr)
return nil
}
type ParamsInfo struct {
NamedParameters []string
PositionalParametersCount int
}
func isPositionalParameter(param string) (ok bool, err error) {
re := regexp.MustCompile(`\?([0-9]*).*`)
match := re.FindSubmatch([]byte(param))
if match == nil {
return false, nil
}
posS := string(match[1])
if posS == "" {
return true, nil
}
return true, fmt.Errorf("unsuppoted positional parameter. This driver does not accept positional parameters with indexes (like ?<number>)")
}
func removeParamPrefix(paramName string) (string, error) {
if paramName[0] == ':' || paramName[0] == '@' || paramName[0] == '$' {
return paramName[1:], nil
}
return "", fmt.Errorf("all named parameters must start with ':', or '@' or '$'")
}
func extractParameters(stmt string) (nameParams []string, positionalParamsCount int, err error) {
statementStream := antlr.NewInputStream(stmt)
sqliteparser.NewSQLiteLexer(statementStream)
lexer := sqliteparser.NewSQLiteLexer(statementStream)
allTokens := lexer.GetAllTokens()
nameParamsSet := make(map[string]bool)
for _, token := range allTokens {
tokenType := token.GetTokenType()
if tokenType == sqliteparser.SQLiteLexerBIND_PARAMETER {
parameter := token.GetText()
isPositionalParameter, err := isPositionalParameter(parameter)
if err != nil {
return []string{}, 0, err
}
if isPositionalParameter {
positionalParamsCount++
} else {
paramWithoutPrefix, err := removeParamPrefix(parameter)
if err != nil {
return []string{}, 0, err
} else {
nameParamsSet[paramWithoutPrefix] = true
}
}
}
}
nameParams = make([]string, 0, len(nameParamsSet))
for k := range nameParamsSet {
nameParams = append(nameParams, k)
}
return nameParams, positionalParamsCount, nil
}
func parseStatement(sql string) ([]string, []ParamsInfo, error) {
stmts, _ := sqliteparserutils.SplitStatement(sql)
stmtsParams := make([]ParamsInfo, len(stmts))
for idx, stmt := range stmts {
nameParams, positionalParamsCount, err := extractParameters(stmt)
if err != nil {
return nil, nil, err
}
stmtsParams[idx] = ParamsInfo{nameParams, positionalParamsCount}
}
return stmts, stmtsParams, nil
}
func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
stmts, paramInfos, err := parseStatement(query)
if err != nil {
return nil, err
}
if len(stmts) != 1 {
return nil, fmt.Errorf("only one statement is supported got %d", len(stmts))
}
numInput := -1
if len(paramInfos[0].NamedParameters) == 0 {
numInput = paramInfos[0].PositionalParametersCount
}
return &stmt{c, query, numInput}, nil
}
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
if opts.ReadOnly {
return nil, fmt.Errorf("read only transactions are not supported")
}
if opts.Isolation != sqldriver.IsolationLevel(sql.LevelDefault) {
return nil, fmt.Errorf("isolation level %d is not supported", opts.Isolation)
}
_, err := c.ExecContext(ctx, "BEGIN", nil)
if err != nil {
return nil, err
}
return &tx{c}, nil
}
func (c *conn) executeNoArgs(query string, exec bool) (C.libsql_rows_t, error) {
queryCString := C.CString(query)
defer C.free(unsafe.Pointer(queryCString))
var rows C.libsql_rows_t
var errMsg *C.char
var statusCode C.int
if exec {
statusCode = C.libsql_execute(c.nativePtr, queryCString, &errMsg)
} else {
statusCode = C.libsql_query(c.nativePtr, queryCString, &rows, &errMsg)
}
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
}
return rows, nil
}
func (c *conn) execute(query string, args []sqldriver.NamedValue, exec bool) (C.libsql_rows_t, error) {
if len(args) == 0 {
return c.executeNoArgs(query, exec)
}
queryCString := C.CString(query)
defer C.free(unsafe.Pointer(queryCString))
var stmt C.libsql_stmt_t
var errMsg *C.char
statusCode := C.libsql_prepare(c.nativePtr, queryCString, &stmt, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to prepare query ", query), statusCode, errMsg)
}
defer C.libsql_free_stmt(stmt)
for _, arg := range args {
var errMsg *C.char
var statusCode C.int
idx := arg.Ordinal
switch arg.Value.(type) {
case int64:
statusCode = C.libsql_bind_int(stmt, C.int(idx), C.longlong(arg.Value.(int64)), &errMsg)
case float64:
statusCode = C.libsql_bind_float(stmt, C.int(idx), C.double(arg.Value.(float64)), &errMsg)
case []byte:
blob := arg.Value.([]byte)
nativeBlob := C.CBytes(blob)
statusCode = C.libsql_bind_blob(stmt, C.int(idx), (*C.uchar)(nativeBlob), C.int(len(blob)), &errMsg)
C.free(nativeBlob)
case string:
valueStr := C.CString(arg.Value.(string))
statusCode = C.libsql_bind_string(stmt, C.int(idx), valueStr, &errMsg)
C.free(unsafe.Pointer(valueStr))
case nil:
statusCode = C.libsql_bind_null(stmt, C.int(idx), &errMsg)
case bool:
var valueInt int
if arg.Value.(bool) {
valueInt = 1
} else {
valueInt = 0
}
statusCode = C.libsql_bind_int(stmt, C.int(idx), C.longlong(valueInt), &errMsg)
default:
return nil, fmt.Errorf("unsupported type %T", arg.Value)
}
if statusCode != 0 {
return nil, libsqlError(fmt.Sprintf("failed to bind argument no. %d with value %v and type %T", idx, arg.Value, arg.Value), statusCode, errMsg)
}
}
var rows C.libsql_rows_t
if exec {
statusCode = C.libsql_execute_stmt(stmt, &errMsg)
} else {
statusCode = C.libsql_query_stmt(stmt, &rows, &errMsg)
}
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
}
return rows, nil
}
type execResult struct {
id int64
changes int64
}
func (r execResult) LastInsertId() (int64, error) {
return r.id, nil
}
func (r execResult) RowsAffected() (int64, error) {
return r.changes, nil
}
func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
rows, err := c.execute(query, args, true)
if err != nil {
return nil, err
}
id := int64(C.libsql_last_insert_rowid(c.nativePtr))
changes := int64(C.libsql_changes(c.nativePtr))
if rows != nil {
C.libsql_free_rows(rows)
}
return execResult{id, changes}, nil
}
type stmt struct {
conn *conn
sql string
numInput int
}
func (s *stmt) Close() error {
return nil
}
func (s *stmt) NumInput() int {
return s.numInput
}
func convertToNamed(args []sqldriver.Value) []sqldriver.NamedValue {
if len(args) == 0 {
return nil
}
result := make([]sqldriver.NamedValue, 0, len(args))
for idx := range args {
result = append(result, sqldriver.NamedValue{Ordinal: idx, Value: args[idx]})
}
return result
}
func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
return s.ExecContext(context.Background(), convertToNamed(args))
}
func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
return s.QueryContext(context.Background(), convertToNamed(args))
}
func (s *stmt) ExecContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Result, error) {
return s.conn.ExecContext(ctx, s.sql, args)
}
func (s *stmt) QueryContext(ctx context.Context, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
return s.conn.QueryContext(ctx, s.sql, args)
}
type tx struct {
conn *conn
}
func (t tx) Commit() error {
_, err := t.conn.ExecContext(context.Background(), "COMMIT", nil)
return err
}
func (t tx) Rollback() error {
_, err := t.conn.ExecContext(context.Background(), "ROLLBACK", nil)
return err
}
const (
TYPE_INT int = iota + 1
TYPE_FLOAT
TYPE_TEXT
TYPE_BLOB
TYPE_NULL
)
func newRows(nativePtr C.libsql_rows_t) (*rows, error) {
if nativePtr == nil {
return &rows{nil, nil}, nil
}
columnCount := int(C.libsql_column_count(nativePtr))
columns := make([]string, columnCount)
for i := 0; i < columnCount; i++ {
var ptr *C.char
var errMsg *C.char
statusCode := C.libsql_column_name(nativePtr, C.int(i), &ptr, &errMsg)
if statusCode != 0 {
return nil, libsqlError(fmt.Sprint("failed to get column name for index ", i), statusCode, errMsg)
}
columns[i] = C.GoString(ptr)
C.libsql_free_string(ptr)
}
return &rows{nativePtr, columns}, nil
}
type rows struct {
nativePtr C.libsql_rows_t
columnNames []string
}
func (r *rows) Columns() []string {
return r.columnNames
}
func (r *rows) Close() error {
if r.nativePtr != nil {
C.libsql_free_rows(r.nativePtr)
r.nativePtr = nil
}
return nil
}
func (r *rows) Next(dest []sqldriver.Value) error {
if r.nativePtr == nil {
return io.EOF
}
var row C.libsql_row_t
var errMsg *C.char
statusCode := C.libsql_next_row(r.nativePtr, &row, &errMsg)
if statusCode != 0 {
return libsqlError("failed to get next row", statusCode, errMsg)
}
if row == nil {
r.Close()
return io.EOF
}
defer C.libsql_free_row(row)
count := len(dest)
if count > len(r.columnNames) {
count = len(r.columnNames)
}
for i := 0; i < count; i++ {
var columnType C.int
var errMsg *C.char
statusCode := C.libsql_column_type(r.nativePtr, row, C.int(i), &columnType, &errMsg)
if statusCode != 0 {
return libsqlError(fmt.Sprint("failed to get column type for index ", i), statusCode, errMsg)
}
switch int(columnType) {
case TYPE_NULL:
dest[i] = nil
case TYPE_INT:
var value C.longlong
var errMsg *C.char
statusCode := C.libsql_get_int(row, C.int(i), &value, &errMsg)
if statusCode != 0 {
return libsqlError(fmt.Sprint("failed to get integer for column ", i), statusCode, errMsg)
}
dest[i] = int64(value)
case TYPE_FLOAT:
var value C.double
var errMsg *C.char
statusCode := C.libsql_get_float(row, C.int(i), &value, &errMsg)
if statusCode != 0 {
return libsqlError(fmt.Sprint("failed to get float for column ", i), statusCode, errMsg)
}
dest[i] = float64(value)
case TYPE_BLOB:
var nativeBlob C.blob
var errMsg *C.char
statusCode := C.libsql_get_blob(row, C.int(i), &nativeBlob, &errMsg)
if statusCode != 0 {
return libsqlError(fmt.Sprint("failed to get blob for column ", i), statusCode, errMsg)
}
dest[i] = C.GoBytes(unsafe.Pointer(nativeBlob.ptr), C.int(nativeBlob.len))
C.libsql_free_blob(nativeBlob)
case TYPE_TEXT:
var ptr *C.char
var errMsg *C.char
statusCode := C.libsql_get_string(row, C.int(i), &ptr, &errMsg)
if statusCode != 0 {
return libsqlError(fmt.Sprint("failed to get string for column ", i), statusCode, errMsg)
}
str := C.GoString(ptr)
C.libsql_free_string(ptr)
for _, format := range []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02T15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05",
"2006-01-02 15:04",
"2006-01-02T15:04",
"2006-01-02",
} {
if t, err := time.ParseInLocation(format, str, time.UTC); err == nil {
dest[i] = t
return nil
}
}
dest[i] = str
}
}
return nil
}
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
rowsNativePtr, err := c.execute(query, args, false)
if err != nil {
return nil, err
}
return newRows(rowsNativePtr)
}