//go:build cgo // +build cgo package libsql /* #cgo CFLAGS: -I../c/include #cgo LDFLAGS: -L../../target/debug #cgo LDFLAGS: -lsql_experimental #cgo LDFLAGS: -L../../libsql-sqlite3/.libs #cgo LDFLAGS: -lsqlite3 #cgo LDFLAGS: -lm #include <libsql.h> #include <stdlib.h> */ import "C" import ( "context" "database/sql" sqldriver "database/sql/driver" "fmt" "io" "time" "unsafe" ) func init() { sql.Register("libsql", driver{}) } func NewEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string) (*Connector, error) { return openConnector(dbPath, primaryUrl, authToken, 0) } func NewEmbeddedReplicaConnectorWithAutoSync(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) { return openConnector(dbPath, primaryUrl, authToken, syncInterval) } type driver struct{} func (d driver) Open(dbPath string) (sqldriver.Conn, error) { connector, err := d.OpenConnector(dbPath) if err != nil { return nil, err } return connector.Connect(context.Background()) } func (d driver) OpenConnector(dbPath string) (sqldriver.Connector, error) { return openConnector(dbPath, "", "", 0) } func libsqlSync(nativeDbPtr C.libsql_database_t) error { var errMsg *C.char statusCode := C.libsql_sync(nativeDbPtr, &errMsg) if statusCode != 0 { return libsqlError("failed to sync database ", statusCode, errMsg) } return nil } func openConnector(dbPath, primaryUrl, authToken string, syncInterval time.Duration) (*Connector, error) { var nativeDbPtr C.libsql_database_t var err error var closeCh chan struct{} var closeAckCh chan struct{} if primaryUrl != "" { nativeDbPtr, err = libsqlOpenWithSync(dbPath, primaryUrl, authToken) if err != nil { return nil, err } if err := libsqlSync(nativeDbPtr); err != nil { C.libsql_close(nativeDbPtr) return nil, err } if syncInterval != 0 { closeCh = make(chan struct{}, 1) closeAckCh = make(chan struct{}, 1) go func() { for { timerCh := make(chan struct{}, 1) go func() { time.Sleep(syncInterval) timerCh <- struct{}{} }() select { case <-closeCh: closeAckCh <- struct{}{} return case <-timerCh: if err := libsqlSync(nativeDbPtr); err != nil { fmt.Println(err) } } } }() } } else { nativeDbPtr, err = libsqlOpen(dbPath) } if err != nil { return nil, err } return &Connector{nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil } type Connector struct { nativeDbPtr C.libsql_database_t closeCh chan<- struct{} closeAckCh <-chan struct{} } func (c *Connector) Sync() error { return libsqlSync(c.nativeDbPtr) } func (c *Connector) Close() error { if c.closeCh != nil { c.closeCh <- struct{}{} <-c.closeAckCh } C.libsql_close(c.nativeDbPtr) return nil } func (c *Connector) Connect(ctx context.Context) (sqldriver.Conn, error) { nativeConnPtr, err := libsqlConnect(c.nativeDbPtr) if err != nil { return nil, err } return &conn{nativePtr: nativeConnPtr}, nil } func (c *Connector) Driver() sqldriver.Driver { return driver{} } func libsqlError(message string, statusCode C.int, errMsg *C.char) error { code := int(statusCode) if errMsg != nil { msg := C.GoString(errMsg) C.libsql_free_string(errMsg) return fmt.Errorf("%s\nerror code = %d: %v", message, code, msg) } else { return fmt.Errorf("%s\nerror code = %d", message, code) } } func libsqlOpen(dataSourceName string) (C.libsql_database_t, error) { connectionString := C.CString(dataSourceName) defer C.free(unsafe.Pointer(connectionString)) var db C.libsql_database_t var errMsg *C.char statusCode := C.libsql_open_ext(connectionString, &db, &errMsg) if statusCode != 0 { return nil, libsqlError(fmt.Sprint("failed to open database ", dataSourceName), statusCode, errMsg) } return db, nil } func libsqlOpenWithSync(dbPath, primaryUrl, authToken string) (C.libsql_database_t, error) { dbPathNativeString := C.CString(dbPath) defer C.free(unsafe.Pointer(dbPathNativeString)) primaryUrlNativeString := C.CString(primaryUrl) defer C.free(unsafe.Pointer(primaryUrlNativeString)) authTokenNativeString := C.CString(authToken) defer C.free(unsafe.Pointer(authTokenNativeString)) var db C.libsql_database_t var errMsg *C.char statusCode := C.libsql_open_sync(dbPathNativeString, primaryUrlNativeString, authTokenNativeString, &db, &errMsg) if statusCode != 0 { return nil, libsqlError(fmt.Sprintf("failed to open database %s %s", dbPath, primaryUrl), statusCode, errMsg) } return db, nil } func libsqlConnect(db C.libsql_database_t) (C.libsql_connection_t, error) { var conn C.libsql_connection_t var errMsg *C.char statusCode := C.libsql_connect(db, &conn, &errMsg) if statusCode != 0 { return nil, libsqlError("failed to connect to database", statusCode, errMsg) } return conn, nil } type conn struct { nativePtr C.libsql_connection_t } func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { return c.PrepareContext(context.Background(), query) } func (c *conn) Begin() (sqldriver.Tx, error) { return c.BeginTx(context.Background(), sqldriver.TxOptions{}) } func (c *conn) Close() error { C.libsql_disconnect(c.nativePtr) return nil } func (c *conn) PrepareContext(ctx context.Context, query string) (sqldriver.Stmt, error) { return nil, fmt.Errorf("prepare() is not implemented") } func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) { return nil, fmt.Errorf("begin() is not implemented") } func (c *conn) execute(query string) (C.libsql_rows_t, error) { queryCString := C.CString(query) defer C.free(unsafe.Pointer(queryCString)) var rows C.libsql_rows_t var errMsg *C.char statusCode := C.libsql_execute(c.nativePtr, queryCString, &rows, &errMsg) if statusCode != 0 { return nil, libsqlError(fmt.Sprint("failed to execute query ", query), statusCode, errMsg) } return rows, nil } func (c *conn) ExecContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Result, error) { rows, err := c.execute(query) if err != nil { return nil, err } if rows != nil { C.libsql_free_rows(rows) } return nil, nil } const ( TYPE_INT int = iota + 1 TYPE_FLOAT TYPE_TEXT TYPE_BLOB TYPE_NULL ) func newRows(nativePtr C.libsql_rows_t) (*rows, error) { if nativePtr == nil { return &rows{nil, nil, nil}, nil } columnCount := int(C.libsql_column_count(nativePtr)) columnTypes := make([]int, columnCount) for i := 0; i < columnCount; i++ { var columnType C.int var errMsg *C.char statusCode := C.libsql_column_type(nativePtr, C.int(i), &columnType, &errMsg) if statusCode != 0 { return nil, libsqlError(fmt.Sprint("failed to get column type for index ", i), statusCode, errMsg) } columnTypes[i] = int(columnType) } columns := make([]string, len(columnTypes)) for i := 0; i < len(columnTypes); i++ { var ptr *C.char var errMsg *C.char statusCode := C.libsql_column_name(nativePtr, C.int(i), &ptr, &errMsg) if statusCode != 0 { return nil, libsqlError(fmt.Sprint("failed to get column name for index ", i), statusCode, errMsg) } columns[i] = C.GoString(ptr) C.libsql_free_string(ptr) } return &rows{nativePtr, columnTypes, columns}, nil } type rows struct { nativePtr C.libsql_rows_t columnTypes []int columnNames []string } func (r *rows) Columns() []string { return r.columnNames } func (r *rows) Close() error { if r.nativePtr != nil { C.libsql_free_rows(r.nativePtr) r.nativePtr = nil } return nil } func (r *rows) Next(dest []sqldriver.Value) error { if r.nativePtr == nil { return io.EOF } var row C.libsql_row_t var errMsg *C.char statusCode := C.libsql_next_row(r.nativePtr, &row, &errMsg) if statusCode != 0 { return libsqlError("failed to get next row", statusCode, errMsg) } if row == nil { r.Close() return io.EOF } defer C.libsql_free_row(row) count := len(dest) if count > len(r.columnTypes) { count = len(r.columnTypes) } for i := 0; i < count; i++ { switch r.columnTypes[i] { case TYPE_NULL: dest[i] = nil case TYPE_INT: var value C.longlong var errMsg *C.char statusCode := C.libsql_get_int(row, C.int(i), &value, &errMsg) if statusCode != 0 { return libsqlError(fmt.Sprint("failed to get integer for column ", i), statusCode, errMsg) } dest[i] = int64(value) case TYPE_FLOAT: var value C.double var errMsg *C.char statusCode := C.libsql_get_float(row, C.int(i), &value, &errMsg) if statusCode != 0 { return libsqlError(fmt.Sprint("failed to get float for column ", i), statusCode, errMsg) } dest[i] = float64(value) case TYPE_BLOB: var nativeBlob C.blob var errMsg *C.char statusCode := C.libsql_get_blob(row, C.int(i), &nativeBlob, &errMsg) if statusCode != 0 { return libsqlError(fmt.Sprint("failed to get blob for column ", i), statusCode, errMsg) } dest[i] = C.GoBytes(unsafe.Pointer(nativeBlob.ptr), C.int(nativeBlob.len)) C.libsql_free_blob(nativeBlob) case TYPE_TEXT: var ptr *C.char var errMsg *C.char statusCode := C.libsql_get_string(row, C.int(i), &ptr, &errMsg) if statusCode != 0 { return libsqlError(fmt.Sprint("failed to get string for column ", i), statusCode, errMsg) } dest[i] = C.GoString(ptr) C.libsql_free_string(ptr) } } return nil } func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) { rowsNativePtr, err := c.execute(query) if err != nil { return nil, err } return newRows(rowsNativePtr) }