0
0
mirror of https://github.com/tursodatabase/libsql.git synced 2025-05-17 13:36:57 +00:00
Files
libsql/bindings/go/libsql.go
2023-10-16 10:38:34 -07:00

363 lines
9.3 KiB
Go

//go:build cgo
// +build cgo
package libsql
/*
#cgo CFLAGS: -I../c/include
#cgo LDFLAGS: -L../../target/debug
#cgo LDFLAGS: -lsql_experimental
#cgo LDFLAGS: -L../../libsql-sqlite3/.libs
#cgo LDFLAGS: -lsqlite3
#cgo LDFLAGS: -lm
#include <libsql.h>
#include <stdlib.h>
*/
import "C"
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"io"
"time"
"unsafe"
)
// init makes the driver available to database/sql under the name "libsql",
// so sql.Open("libsql", path) resolves to this package's driver.
func init() {
	var d driver
	sql.Register("libsql", d)
}
// NewEmbeddedReplicaConnector returns a Connector for the database at dbPath.
//
// When primaryUrl is non-empty, dbPath is opened as a replica of primaryUrl
// (authenticated with authToken) and synced once before this function
// returns. No background syncing is started; call Connector.Sync to pull
// further changes. When primaryUrl is empty, dbPath is opened as a plain local
// database.
func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) {
	const manualSyncOnly = 0
	return openConnector(dbPath, primaryUrl, authToken, manualSyncOnly)
}
// NewEmbeddedReplicaConnectorWithAutoSync behaves like
// NewEmbeddedReplicaConnector but, when primaryUrl is non-empty, also keeps
// the replica up to date from a background goroutine that syncs every
// syncInterval until the Connector is closed. A syncInterval of 0 disables
// background syncing. Errors from background syncs cannot be returned to the
// caller; they are printed to standard output.
func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
	connector, err := openConnector(dbPath, primaryUrl, authToken, syncInterval)
	return connector, err
}
// driver is the sqldriver.Driver registered with database/sql as "libsql".
// It carries no state, so the zero value is used wherever one is needed.
type driver struct{}
// Open implements sqldriver.Driver. It opens dbPath as a local database via
// OpenConnector and returns a single new connection to it.
//
// NOTE(review): the Connector created here is neither retained nor closed, so
// its native database handle is never released by this path — confirm
// whether that is intentional.
func (d driver) Open(dbPath string) (sqldriver.Conn, error) {
	connector, err := d.OpenConnector(dbPath)
	if err == nil {
		return connector.Connect(context.Background())
	}
	return nil, err
}
// OpenConnector implements sqldriver.DriverContext; database/sql calls it from
// sql.Open. dbPath is opened as a plain local database: no primary URL or auth
// token is supplied, so no replication or background syncing is configured.
//
// On failure a literal nil Connector is returned. Returning openConnector's
// result directly would wrap a typed (*Connector)(nil) in a non-nil
// sqldriver.Connector interface, which compares unequal to nil.
func (d driver) OpenConnector(dbPath string) (sqldriver.Connector, error) {
	connector, err := openConnector(dbPath, "", "", 0)
	if err != nil {
		return nil, err
	}
	return connector, nil
}
// libsqlSync runs libsql_sync on the given native database handle (presumably
// pulling new changes from the primary into the local replica). A non-zero
// native status is converted into a Go error via libsqlError.
func libsqlSync(nativeDbPtr C.libsql_database_t) error {
	var errMsg *C.char
	if statusCode := C.libsql_sync(nativeDbPtr, &errMsg); statusCode != 0 {
		return libsqlError("failed to sync database ", statusCode, errMsg)
	}
	return nil
}
// openConnector opens the native database behind a Connector.
//
// With an empty primaryUrl, dbPath is opened as a plain local database and
// authToken and syncInterval are ignored. Otherwise dbPath is opened as a
// replica of primaryUrl (authenticated with authToken) and synced once before
// returning. If that initial sync fails, the native handle is closed and the
// error is returned.
//
// When syncInterval is positive, a background goroutine also syncs
// periodically until Connector.Close stops it. A zero or negative interval
// disables background syncing. A negative interval must not start the loop:
// its timer would fire immediately on every iteration and sync in a busy loop.
func openConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) {
	if primaryUrl == "" {
		nativeDbPtr, err := libsqlOpen(dbPath)
		if err != nil {
			return nil, err
		}
		return &Connector{nativeDbPtr: nativeDbPtr}, nil
	}

	nativeDbPtr, err := libsqlOpenWithSync(dbPath, primaryUrl, authToken)
	if err != nil {
		return nil, err
	}
	if err := libsqlSync(nativeDbPtr); err != nil {
		C.libsql_close(nativeDbPtr)
		return nil, err
	}

	connector := &Connector{nativeDbPtr: nativeDbPtr}
	if syncInterval > 0 {
		// Buffered so Close's stop request never blocks, even while a sync is
		// in progress.
		closeCh := make(chan struct{}, 1)
		closeAckCh := make(chan struct{}, 1)
		go runPeriodicSync(nativeDbPtr, syncInterval, closeCh, closeAckCh)
		connector.closeCh = closeCh
		connector.closeAckCh = closeAckCh
	}
	return connector, nil
}

// runPeriodicSync syncs nativeDbPtr every interval until a value arrives on
// closeCh, then acknowledges on closeAckCh and returns.
//
// The timer is re-armed only after a sync finishes, so there is always
// `interval` between the end of one sync and the start of the next. A single
// timer is reused instead of spawning a sleeper goroutine per iteration.
// Errors have no caller to return to, so they are printed to standard output.
func runPeriodicSync(nativeDbPtr C.libsql_database_t, interval time.Duration, closeCh <-chan struct{}, closeAckCh chan<- struct{}) {
	timer := time.NewTimer(interval)
	defer timer.Stop()
	for {
		select {
		case <-closeCh:
			closeAckCh <- struct{}{}
			return
		case <-timer.C:
			if err := libsqlSync(nativeDbPtr); err != nil {
				fmt.Println(err)
			}
			// timer.C was just drained, so Reset is safe here.
			timer.Reset(interval)
		}
	}
}
// Connector is a sqldriver.Connector backed by one native libsql database
// handle. It is created by NewEmbeddedReplicaConnector,
// NewEmbeddedReplicaConnectorWithAutoSync, or the driver's OpenConnector.
type Connector struct {
	// nativeDbPtr is the native database handle. Connect opens connections
	// from it and Close releases it.
	nativeDbPtr C.libsql_database_t

	// closeCh and closeAckCh are set only when a background sync goroutine
	// was started. Close sends on closeCh to stop that goroutine and waits on
	// closeAckCh for its acknowledgement before closing the native handle.
	closeCh    chan<- struct{}
	closeAckCh <-chan struct{}
}
// Sync runs an on-demand sync of the connector's database and returns any
// error reported by the native library.
func (c *Connector) Sync() error {
	err := libsqlSync(c.nativeDbPtr)
	return err
}
// Close stops the background sync goroutine, if one was started, and then
// closes the native database handle. It always returns nil.
//
// Close is idempotent. Since Go 1.17, sql.DB.Close also closes a Connector
// that implements io.Closer, so a caller that closes the Connector too would
// call this twice. Without the guards below, a second call would block
// forever waiting for an acknowledgement from a goroutine that has already
// exited, and would close the native handle twice.
// Close is not safe for concurrent use.
func (c *Connector) Close() error {
	if c.closeCh != nil {
		c.closeCh <- struct{}{}
		<-c.closeAckCh
		c.closeCh = nil
		c.closeAckCh = nil
	}
	if c.nativeDbPtr != nil {
		C.libsql_close(c.nativeDbPtr)
		c.nativeDbPtr = nil
	}
	return nil
}
// Connect implements sqldriver.Connector by opening a new native connection
// to the connector's database. ctx is not consulted.
func (c *Connector) Connect(ctx context.Context) (sqldriver.Conn, error) {
	var native C.libsql_connection_t
	var err error
	if native, err = libsqlConnect(c.nativeDbPtr); err != nil {
		return nil, err
	}
	return &conn{nativePtr: native}, nil
}
// Driver implements sqldriver.Connector and returns the package's stateless
// driver value.
func (c *Connector) Driver() sqldriver.Driver {
	var d driver
	return d
}
// libsqlError builds a Go error describing a failed native call.
//
// message describes the failed operation and statusCode is the native status.
// errMsg is the optional native error string. When errMsg is non-nil, its text
// is copied into the returned error and errMsg is freed with
// libsql_free_string, so callers must not use errMsg afterwards.
func libsqlError(message string, statusCode C.int, errMsg *C.char) error {
	code := int(statusCode)
	if errMsg == nil {
		return fmt.Errorf("%s\nerror code = %d", message, code)
	}
	nativeMsg := C.GoString(errMsg)
	C.libsql_free_string(errMsg)
	return fmt.Errorf("%s\nerror code = %d: %v", message, code, nativeMsg)
}
// libsqlOpen opens dataSourceName with libsql_open_ext and returns the native
// database handle. The temporary C copy of the name is freed before returning.
func libsqlOpen(dataSourceName string) (C.libsql_database_t, error) {
	cName := C.CString(dataSourceName)
	defer C.free(unsafe.Pointer(cName))

	var (
		db     C.libsql_database_t
		errMsg *C.char
	)
	if statusCode := C.libsql_open_ext(cName, &db, &errMsg); statusCode != 0 {
		return nil, libsqlError("failed to open database "+dataSourceName, statusCode, errMsg)
	}
	return db, nil
}
// libsqlOpenWithSync opens dbPath with libsql_open_sync, configured to
// replicate from primaryUrl using authToken, and returns the native database
// handle. The temporary C strings are freed before returning. The error
// message names dbPath and primaryUrl but deliberately not authToken.
func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) {
	cPath := C.CString(dbPath)
	cURL := C.CString(primaryUrl)
	cToken := C.CString(authToken)
	defer func() {
		C.free(unsafe.Pointer(cToken))
		C.free(unsafe.Pointer(cURL))
		C.free(unsafe.Pointer(cPath))
	}()

	var (
		db     C.libsql_database_t
		errMsg *C.char
	)
	statusCode := C.libsql_open_sync(cPath, cURL, cToken, &db, &errMsg)
	if statusCode == 0 {
		return db, nil
	}
	return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), statusCode, errMsg)
}
// libsqlConnect opens a new native connection on db with libsql_connect.
func libsqlConnect(db C.libsql_database_t) (C.libsql_connection_t, error) {
	var (
		connection C.libsql_connection_t
		errMsg     *C.char
	)
	if statusCode := C.libsql_connect(db, &connection, &errMsg); statusCode != 0 {
		return nil, libsqlError("failed to connect to database", statusCode, errMsg)
	}
	return connection, nil
}
// conn is the sqldriver.Conn implementation. It wraps one native libsql
// connection, which Connector.Connect creates and conn.Close releases with
// libsql_disconnect.
type conn struct {
	// nativePtr is the native connection handle used for every query.
	nativePtr C.libsql_connection_t
}
// Prepare implements sqldriver.Conn by delegating to PrepareContext with a
// background context.
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
	ctx := context.Background()
	return c.PrepareContext(ctx, query)
}
// Begin implements sqldriver.Conn by delegating to BeginTx with a background
// context and default transaction options.
func (c *conn) Begin() (sqldriver.Tx, error) {
	var opts sqldriver.TxOptions
	return c.BeginTx(context.Background(), opts)
}
// Close releases the native connection with libsql_disconnect. It always
// returns nil.
func (c *conn) Close() error {
	nativeConn := c.nativePtr
	C.libsql_disconnect(nativeConn)
	return nil
}
// PrepareContext is not implemented yet; it always returns an error and no
// statement.
func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) {
	err := fmt.Errorf("prepare() is not implemented")
	return nil, err
}
// BeginTx is not implemented yet; it always returns an error and no
// transaction.
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
	err := fmt.Errorf("begin() is not implemented")
	return nil, err
}
// execute runs query on the native connection with libsql_execute and returns
// the native result handle. The handle may be nil, and callers check for that.
// A non-nil handle is owned by the caller and must be freed with
// libsql_free_rows.
func (c *conn) execute(query string) (C.libsql_rows_t, error) {
	cQuery := C.CString(query)
	defer C.free(unsafe.Pointer(cQuery))

	var (
		result C.libsql_rows_t
		errMsg *C.char
	)
	statusCode := C.libsql_execute(c.nativePtr, cQuery, &result, &errMsg)
	if statusCode == 0 {
		return result, nil
	}
	return nil, libsqlError("failed to execute query "+query, statusCode, errMsg)
}
// ExecContext executes query for its side effects and discards any rows it
// produces. ctx is not consulted.
//
// Parameter binding is not implemented yet, so a non-empty args is rejected.
// Running the statement with its placeholders unbound would silently do
// something other than what the caller asked for.
//
// The returned Result is never nil. database/sql forwards
// Result.LastInsertId and Result.RowsAffected to it, and a nil Result would
// make those calls panic. The native calls used here expose neither value, so
// both methods report an error.
func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) {
	if len(args) > 0 {
		return nil, fmt.Errorf("query arguments are not supported yet")
	}
	rows, err := c.execute(query)
	if err != nil {
		return nil, err
	}
	if rows != nil {
		C.libsql_free_rows(rows)
	}
	return execResult{}, nil
}

// execResult is the sqldriver.Result returned by conn.ExecContext. Neither the
// last insert id nor the affected-row count is available yet, so both
// accessors return an error.
type execResult struct{}

var _ sqldriver.Result = execResult{}

// LastInsertId always fails; the value is not exposed by the driver yet.
func (execResult) LastInsertId() (int64, error) {
	return 0, fmt.Errorf("LastInsertId() is not implemented")
}

// RowsAffected always fails; the value is not exposed by the driver yet.
func (execResult) RowsAffected() (int64, error) {
	return 0, fmt.Errorf("RowsAffected() is not implemented")
}
// Column type codes as returned by libsql_column_type and interpreted by
// rows.Next. The numeric values must match the native library — presumably
// they mirror its value-type enumeration; confirm against libsql.h before
// changing them.
const (
	TYPE_INT   int = 1
	TYPE_FLOAT int = 2
	TYPE_TEXT  int = 3
	TYPE_BLOB  int = 4
	TYPE_NULL  int = 5
)
// newRows wraps a native result set in a *rows and reads every column's type
// code and name up front.
//
// A nil nativePtr yields an empty *rows with no columns; its Next reports
// io.EOF straight away. (A nil handle presumably means the statement
// produced no result set — confirm against libsql_execute.)
//
// If reading column metadata fails, nativePtr is freed here before the error
// is returned. The caller receives no *rows in that case, so it has no other
// way to release the handle, and the result set would otherwise leak.
func newRows(nativePtr C.libsql_rows_t) (*rows, error) {
	if nativePtr == nil {
		return &rows{}, nil
	}
	columnTypes, columnNames, err := readColumnMetadata(nativePtr)
	if err != nil {
		C.libsql_free_rows(nativePtr)
		return nil, err
	}
	return &rows{nativePtr: nativePtr, columnTypes: columnTypes, columnNames: columnNames}, nil
}

// readColumnMetadata returns the TYPE_* code and the name of every column of
// nativePtr. All types are read before any names, so the first failing native
// call determines which error is reported.
func readColumnMetadata(nativePtr C.libsql_rows_t) ([]int, []string, error) {
	columnCount := int(C.libsql_column_count(nativePtr))

	columnTypes := make([]int, columnCount)
	for i := range columnTypes {
		var columnType C.int
		var errMsg *C.char
		if statusCode := C.libsql_column_type(nativePtr, C.int(i), &columnType, &errMsg); statusCode != 0 {
			return nil, nil, libsqlError(fmt.Sprint("failed to get column type for index ", i), statusCode, errMsg)
		}
		columnTypes[i] = int(columnType)
	}

	columnNames := make([]string, columnCount)
	for i := range columnNames {
		var ptr *C.char
		var errMsg *C.char
		if statusCode := C.libsql_column_name(nativePtr, C.int(i), &ptr, &errMsg); statusCode != 0 {
			return nil, nil, libsqlError(fmt.Sprint("failed to get column name for index ", i), statusCode, errMsg)
		}
		// Copy the name into Go memory, then release the native string.
		columnNames[i] = C.GoString(ptr)
		C.libsql_free_string(ptr)
	}
	return columnTypes, columnNames, nil
}
// rows is the sqldriver.Rows implementation over a native libsql result set.
type rows struct {
	// nativePtr is the native result set. It is nil when newRows was given a
	// nil handle, and it is set to nil once Close has freed it.
	nativePtr C.libsql_rows_t

	// columnTypes holds one TYPE_* code per column and columnNames the
	// matching column names; both are filled once by newRows.
	columnTypes []int
	columnNames []string
}
// Columns implements sqldriver.Rows, returning the column names captured when
// the result set was opened.
func (r *rows) Columns() []string {
	names := r.columnNames
	return names
}
// Close frees the native result set, if it is still held, and clears the
// handle so that repeated calls do nothing. It always returns nil.
func (r *rows) Close() error {
	if r.nativePtr == nil {
		return nil
	}
	C.libsql_free_rows(r.nativePtr)
	r.nativePtr = nil
	return nil
}
// Next implements sqldriver.Rows. It fetches the next native row and writes
// up to len(dest) column values into dest, decoding each according to the
// column's TYPE_* code. When there are no more rows, it frees the native
// result set and returns io.EOF.
//
// An unrecognized column type code is reported as an error. Silently skipping
// that column would leave dest[i] untouched, and database/sql reuses the dest
// slice across rows, so the previous row's value would be returned in its
// place.
func (r *rows) Next(dest []sqldriver.Value) error {
	if r.nativePtr == nil {
		return io.EOF
	}
	var row C.libsql_row_t
	var errMsg *C.char
	statusCode := C.libsql_next_row(r.nativePtr, &row, &errMsg)
	if statusCode != 0 {
		return libsqlError("failed to get next row", statusCode, errMsg)
	}
	if row == nil {
		// End of the result set: release the native handle eagerly.
		r.Close()
		return io.EOF
	}
	defer C.libsql_free_row(row)

	count := len(dest)
	if count > len(r.columnTypes) {
		count = len(r.columnTypes)
	}
	for i := 0; i < count; i++ {
		value, err := readColumnValue(row, i, r.columnTypes[i])
		if err != nil {
			return err
		}
		dest[i] = value
	}
	return nil
}

// readColumnValue decodes column i of row according to columnType, a TYPE_*
// code. It yields nil for TYPE_NULL, int64, float64, string, or a []byte copy.
// Native strings and blobs are copied into Go memory and then freed.
func readColumnValue(row C.libsql_row_t, i int, columnType int) (sqldriver.Value, error) {
	index := C.int(i)
	var errMsg *C.char
	switch columnType {
	case TYPE_NULL:
		return nil, nil
	case TYPE_INT:
		var value C.longlong
		if statusCode := C.libsql_get_int(row, index, &value, &errMsg); statusCode != 0 {
			return nil, libsqlError(fmt.Sprint("failed to get integer for column ", i), statusCode, errMsg)
		}
		return int64(value), nil
	case TYPE_FLOAT:
		var value C.double
		if statusCode := C.libsql_get_float(row, index, &value, &errMsg); statusCode != 0 {
			return nil, libsqlError(fmt.Sprint("failed to get float for column ", i), statusCode, errMsg)
		}
		return float64(value), nil
	case TYPE_BLOB:
		var nativeBlob C.blob
		if statusCode := C.libsql_get_blob(row, index, &nativeBlob, &errMsg); statusCode != 0 {
			return nil, libsqlError(fmt.Sprint("failed to get blob for column ", i), statusCode, errMsg)
		}
		value := C.GoBytes(unsafe.Pointer(nativeBlob.ptr), C.int(nativeBlob.len))
		C.libsql_free_blob(nativeBlob)
		return value, nil
	case TYPE_TEXT:
		var ptr *C.char
		if statusCode := C.libsql_get_string(row, index, &ptr, &errMsg); statusCode != 0 {
			return nil, libsqlError(fmt.Sprint("failed to get string for column ", i), statusCode, errMsg)
		}
		value := C.GoString(ptr)
		C.libsql_free_string(ptr)
		return value, nil
	default:
		return nil, fmt.Errorf("unsupported type %d for column %d", columnType, i)
	}
}
// QueryContext executes query and returns its result set, which the caller
// must Close. ctx is not consulted.
//
// Parameter binding is not implemented yet, so a non-empty args is rejected.
// Running the query with its placeholders unbound would silently return
// results for a different query than the one the caller asked for.
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
	if len(args) > 0 {
		return nil, fmt.Errorf("query arguments are not supported yet")
	}
	rowsNativePtr, err := c.execute(query)
	if err != nil {
		return nil, err
	}
	return newRows(rowsNativePtr)
}