mirror of
https://github.com/tursodatabase/libsql.git
synced 2025-05-17 15:56:55 +00:00
363 lines
9.3 KiB
Go
363 lines
9.3 KiB
Go
//go:build cgo
|
|
// +build cgo
|
|
|
|
package libsql
|
|
|
|
/*
|
|
#cgo CFLAGS: -I../c/include
|
|
#cgo LDFLAGS: -L../../target/debug
|
|
#cgo LDFLAGS: -lsql_experimental
|
|
#cgo LDFLAGS: -L../../libsql-sqlite3/.libs
|
|
#cgo LDFLAGS: -lsqlite3
|
|
#cgo LDFLAGS: -lm
|
|
#include <libsql.h>
|
|
#include <stdlib.h>
|
|
*/
|
|
import "C"
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
sqldriver "database/sql/driver"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
"unsafe"
|
|
)
|
|
|
|
func init() {
|
|
sql.Register("libsql", driver{})
|
|
}
|
|
|
|
func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
|
|
return openConnector(dbPath, primaryUrl, authToken, 0)
|
|
}
|
|
|
|
func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
|
|
return openConnector(dbPath, primaryUrl, authToken, syncInterval)
|
|
}
|
|
|
|
type driver struct{}
|
|
|
|
func (d driver) Open(dbPath string) (sqldriver.Conn, error) {
|
|
connector, err := d.OpenConnector(dbPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return connector.Connect(context.Background())
|
|
}
|
|
|
|
func (d driver) OpenConnector(dbPath string) (sqldriver.Connector, error) {
|
|
return openConnector(dbPath, "", "", 0)
|
|
}
|
|
|
|
func libsqlSync(nativeDbPtr C.libsql_database_t) error {
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_sync(nativeDbPtr, &errMsg)
|
|
if statusCode != 0 {
|
|
return libsqlError("failed to sync database ", statusCode, errMsg)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func openConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
|
|
var nativeDbPtr C.libsql_database_t
|
|
var err error
|
|
var closeCh chan struct{}
|
|
var closeAckCh chan struct{}
|
|
if primaryUrl != "" {
|
|
nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl, authToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := libsqlSync(nativeDbPtr); err != nil {
|
|
C.libsql_close(nativeDbPtr)
|
|
return nil, err
|
|
}
|
|
if syncInterval != 0 {
|
|
closeCh = make(chan struct{}, 1)
|
|
closeAckCh = make(chan struct{}, 1)
|
|
go func() {
|
|
for {
|
|
timerCh := make(chan struct{}, 1)
|
|
go func() {
|
|
time.Sleep(syncInterval)
|
|
timerCh <- struct{}{}
|
|
}()
|
|
select {
|
|
case <-closeCh:
|
|
closeAckCh <- struct{}{}
|
|
return
|
|
case <-timerCh:
|
|
if err := libsqlSync(nativeDbPtr); err != nil {
|
|
fmt.Println(err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
} else {
|
|
nativeDbPtr, err = libsqlOpen(dbPath)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Connector{nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil
|
|
}
|
|
|
|
type Connector struct {
|
|
nativeDbPtr C.libsql_database_t
|
|
closeCh chan<- struct{}
|
|
closeAckCh <-chan struct{}
|
|
}
|
|
|
|
func (c *Connector) Sync() error {
|
|
return libsqlSync(c.nativeDbPtr)
|
|
}
|
|
|
|
func (c *Connector) Close() error {
|
|
if c.closeCh != nil {
|
|
c.closeCh <- struct{}{}
|
|
<-c.closeAckCh
|
|
}
|
|
C.libsql_close(c.nativeDbPtr)
|
|
return nil
|
|
}
|
|
|
|
func (c *Connector) Connect(ctx context.Context) (sqldriver.Conn, error) {
|
|
nativeConnPtr, err := libsqlConnect(c.nativeDbPtr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &conn{nativePtr: nativeConnPtr}, nil
|
|
}
|
|
|
|
func (c *Connector) Driver() sqldriver.Driver {
|
|
return driver{}
|
|
}
|
|
|
|
func libsqlError(message string, statusCode C.int, errMsg *C.char) error {
|
|
code := int(statusCode)
|
|
if errMsg != nil {
|
|
msg := C.GoString(errMsg)
|
|
C.libsql_free_string(errMsg)
|
|
return fmt.Errorf("%s\nerror code = %d: %v", message, code, msg)
|
|
} else {
|
|
return fmt.Errorf("%s\nerror code = %d", message, code)
|
|
}
|
|
}
|
|
|
|
func libsqlOpen(dataSourceName string) (C.libsql_database_t, error) {
|
|
connectionString := C.CString(dataSourceName)
|
|
defer C.free(unsafe.Pointer(connectionString))
|
|
|
|
var db C.libsql_database_t
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_open_ext(connectionString, &db, &errMsg)
|
|
if statusCode != 0 {
|
|
return nil, libsqlError(fmt.Sprint("failed to open database ", dataSourceName), statusCode, errMsg)
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) {
|
|
dbPathNativeString := C.CString(dbPath)
|
|
defer C.free(unsafe.Pointer(dbPathNativeString))
|
|
primaryUrlNativeString := C.CString(primaryUrl)
|
|
defer C.free(unsafe.Pointer(primaryUrlNativeString))
|
|
authTokenNativeString := C.CString(authToken)
|
|
defer C.free(unsafe.Pointer(authTokenNativeString))
|
|
|
|
var db C.libsql_database_t
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, authTokenNativeString, &db, &errMsg)
|
|
if statusCode != 0 {
|
|
return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), statusCode, errMsg)
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
func libsqlConnect(db C.libsql_database_t) (C.libsql_connection_t, error) {
|
|
var conn C.libsql_connection_t
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_connect(db, &conn, &errMsg)
|
|
if statusCode != 0 {
|
|
return nil, libsqlError("failed to connect to database", statusCode, errMsg)
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
type conn struct {
|
|
nativePtr C.libsql_connection_t
|
|
}
|
|
|
|
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
|
|
return c.PrepareContext(context.Background(), query)
|
|
}
|
|
|
|
func (c *conn) Begin() (sqldriver.Tx, error) {
|
|
return c.BeginTx(context.Background(), sqldriver.TxOptions{})
|
|
}
|
|
|
|
func (c *conn) Close() error {
|
|
C.libsql_disconnect(c.nativePtr)
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
|
|
return nil, fmt.Errorf("prepare() is not implemented")
|
|
}
|
|
|
|
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
|
|
return nil, fmt.Errorf("begin() is not implemented")
|
|
}
|
|
|
|
func (c *conn) execute(query string) (C.libsql_rows_t, error) {
|
|
queryCString := C.CString(query)
|
|
defer C.free(unsafe.Pointer(queryCString))
|
|
|
|
var rows C.libsql_rows_t
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_execute(c.nativePtr, queryCString, &rows, &errMsg)
|
|
if statusCode != 0 {
|
|
return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg)
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
|
|
rows, err := c.execute(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if rows != nil {
|
|
C.libsql_free_rows(rows)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// Column type codes stored in rows.columnTypes and switched on in rows.Next.
// The values run from 1 (TYPE_INT) to 5 (TYPE_NULL); zero is deliberately
// skipped.
const (
	_ int = iota
	TYPE_INT
	TYPE_FLOAT
	TYPE_TEXT
	TYPE_BLOB
	TYPE_NULL
)
|
|
|
|
func newRows(nativePtr C.libsql_rows_t) (*rows, error) {
|
|
if nativePtr == nil {
|
|
return &rows{nil, nil, nil}, nil
|
|
}
|
|
columnCount := int(C.libsql_column_count(nativePtr))
|
|
columnTypes := make([]int, columnCount)
|
|
for i := 0; i < columnCount; i++ {
|
|
var columnType C.int
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_column_type(nativePtr, C.int(i), &columnType, &errMsg)
|
|
if statusCode != 0 {
|
|
return nil, libsqlError(fmt.Sprint("failed to get column type for index ", i), statusCode, errMsg)
|
|
}
|
|
columnTypes[i] = int(columnType)
|
|
}
|
|
columns := make([]string, len(columnTypes))
|
|
for i := 0; i < len(columnTypes); i++ {
|
|
var ptr *C.char
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_column_name(nativePtr, C.int(i), &ptr, &errMsg)
|
|
if statusCode != 0 {
|
|
return nil, libsqlError(fmt.Sprint("failed to get column name for index ", i), statusCode, errMsg)
|
|
}
|
|
columns[i] = C.GoString(ptr)
|
|
C.libsql_free_string(ptr)
|
|
}
|
|
return &rows{nativePtr, columnTypes, columns}, nil
|
|
}
|
|
|
|
type rows struct {
|
|
nativePtr C.libsql_rows_t
|
|
columnTypes []int
|
|
columnNames []string
|
|
}
|
|
|
|
func (r *rows) Columns() []string {
|
|
return r.columnNames
|
|
}
|
|
|
|
func (r *rows) Close() error {
|
|
if r.nativePtr != nil {
|
|
C.libsql_free_rows(r.nativePtr)
|
|
r.nativePtr = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *rows) Next(dest []sqldriver.Value) error {
|
|
if r.nativePtr == nil {
|
|
return io.EOF
|
|
}
|
|
var row C.libsql_row_t
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_next_row(r.nativePtr, &row, &errMsg)
|
|
if statusCode != 0 {
|
|
return libsqlError("failed to get next row", statusCode, errMsg)
|
|
}
|
|
if row == nil {
|
|
r.Close()
|
|
return io.EOF
|
|
}
|
|
defer C.libsql_free_row(row)
|
|
count := len(dest)
|
|
if count > len(r.columnTypes) {
|
|
count = len(r.columnTypes)
|
|
}
|
|
for i := 0; i < count; i++ {
|
|
switch r.columnTypes[i] {
|
|
case TYPE_NULL:
|
|
dest[i] = nil
|
|
case TYPE_INT:
|
|
var value C.longlong
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_get_int(row, C.int(i), &value, &errMsg)
|
|
if statusCode != 0 {
|
|
return libsqlError(fmt.Sprint("failed to get integer for column ", i), statusCode, errMsg)
|
|
}
|
|
dest[i] = int64(value)
|
|
case TYPE_FLOAT:
|
|
var value C.double
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_get_float(row, C.int(i), &value, &errMsg)
|
|
if statusCode != 0 {
|
|
return libsqlError(fmt.Sprint("failed to get float for column ", i), statusCode, errMsg)
|
|
}
|
|
dest[i] = float64(value)
|
|
case TYPE_BLOB:
|
|
var nativeBlob C.blob
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_get_blob(row, C.int(i), &nativeBlob, &errMsg)
|
|
if statusCode != 0 {
|
|
return libsqlError(fmt.Sprint("failed to get blob for column ", i), statusCode, errMsg)
|
|
}
|
|
dest[i] = C.GoBytes(unsafe.Pointer(nativeBlob.ptr), C.int(nativeBlob.len))
|
|
C.libsql_free_blob(nativeBlob)
|
|
case TYPE_TEXT:
|
|
var ptr *C.char
|
|
var errMsg *C.char
|
|
statusCode := C.libsql_get_string(row, C.int(i), &ptr, &errMsg)
|
|
if statusCode != 0 {
|
|
return libsqlError(fmt.Sprint("failed to get string for column ", i), statusCode, errMsg)
|
|
}
|
|
dest[i] = C.GoString(ptr)
|
|
C.libsql_free_string(ptr)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
|
|
rowsNativePtr, err := c.execute(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return newRows(rowsNativePtr)
|
|
}
|