1
0
mirror of https://gitlab.com/cznic/sqlite.git synced 2024-09-20 01:20:51 +00:00
go-sqlite/func_test.go
2024-07-22 08:07:14 +02:00

964 lines
22 KiB
Go

// Copyright 2022 The Sqlite Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sqlite // import "modernc.org/sqlite"
import (
"bytes"
"context"
"crypto/md5"
"database/sql"
"database/sql/driver"
"encoding/hex"
"errors"
"fmt"
"io"
"os"
"path"
"regexp"
"strings"
"testing"
"time"
)
// finalCalled records whether sumFunction.Final has run. init hands its
// address to every sumFunction it creates for "test_sum", and the withDB
// helper in TestRegisteredFunctions resets it to false before each test body.
var finalCalled bool
// sumFunction is the aggregate state behind the "test_sum" SQL function that
// init registers. It keeps a running int64 total and reports, through a shared
// flag, when Final has been invoked.
type sumFunction struct {
// sum is the running total: Step adds to it and WindowInverse subtracts.
sum int64
// finalCalled points at the flag that Final sets to true.
finalCalled *bool
}
// Step folds args[0] into the running sum.
//
// The argument must be an int64. Any other dynamic type (a nil value
// included) is rejected with an error that names the Go type that was
// received, so callers can tell which kind of value was passed. ctx is unused.
func (f *sumFunction) Step(ctx *FunctionContext, args []driver.Value) error {
	v, ok := args[0].(int64)
	if !ok {
		// The failure is about the argument, not about a return value, so say so.
		return fmt.Errorf("test_sum: expected int64 argument, got %T", args[0])
	}
	f.sum += v
	return nil
}
// WindowInverse removes args[0] from the running sum, undoing the effect of
// an earlier Step call with the same value.
//
// The argument must be an int64. Any other dynamic type (a nil value
// included) is rejected with an error that names the Go type that was
// received. ctx is unused.
func (f *sumFunction) WindowInverse(ctx *FunctionContext, args []driver.Value) error {
	v, ok := args[0].(int64)
	if !ok {
		// The failure is about the argument, not about a return value, so say so.
		return fmt.Errorf("test_sum: expected int64 argument, got %T", args[0])
	}
	f.sum -= v
	return nil
}
// WindowValue reports the current running sum without modifying it. It never
// returns an error, and ctx is unused.
func (f *sumFunction) WindowValue(ctx *FunctionContext) (driver.Value, error) {
	var current driver.Value = f.sum
	return current, nil
}
// Final marks the aggregate as finished by setting the flag that the
// finalCalled field points to. ctx is unused.
func (f *sumFunction) Final(ctx *FunctionContext) {
	done := f.finalCalled
	*done = true
}
// init registers the SQL functions that TestRegisteredFunctions relies on:
//
//   - test_int64, test_float64, test_null, test_error, test_empty_byte_slice,
//     test_nonempty_byte_slice, test_empty_string, test_nonempty_string:
//     zero-argument scalars that each return one fixed value (or error).
//   - yesterday(unixSeconds): the given instant minus 24 hours.
//   - md5(textOrBlob): hex-encoded MD5 digest of the argument.
//   - regexp(pattern, subject): backs the SQL REGEXP operator.
//   - test_sum, test_aggregate_error, test_aggregate_null_pointer:
//     one-argument aggregates covering the success path, a failing
//     constructor, and a constructor that returns a nil aggregate.
func init() {
	MustRegisterDeterministicScalarFunction(
		"test_int64",
		0,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return int64(42), nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"test_float64",
		0,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return float64(1e-2), nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"test_null",
		0,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return nil, nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"test_error",
		0,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return nil, errors.New("boom")
		},
	)

	MustRegisterDeterministicScalarFunction(
		"test_empty_byte_slice",
		0,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return []byte{}, nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"test_nonempty_byte_slice",
		0,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return []byte("abcdefg"), nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"test_empty_string",
		0,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return "", nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"test_nonempty_string",
		0,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			return "abcdefg", nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"yesterday",
		1,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			secs, ok := args[0].(int64)
			if !ok {
				return nil, fmt.Errorf("expected argument to be int64, got: %T", args[0])
			}
			return time.Unix(secs, 0).Add(-24 * time.Hour), nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"md5",
		1,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			var data []byte
			switch v := args[0].(type) {
			case string:
				data = []byte(v)
			case []byte:
				data = v
			default:
				return nil, fmt.Errorf("expected argument to be a string or []byte, got: %T", v)
			}
			digest := md5.Sum(data)
			return hex.EncodeToString(digest[:]), nil
		},
	)

	MustRegisterDeterministicScalarFunction(
		"regexp",
		2,
		func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) {
			// The tests write "subject REGEXP pattern" and expect the
			// pattern to arrive as args[0] and the subject as args[1].
			pattern, ok := args[0].(string)
			if !ok {
				return nil, errors.New("expected argv[0] to be text")
			}
			subject, ok := args[1].(string)
			if !ok {
				return nil, errors.New("expected argv[1] to be text")
			}

			matched, err := regexp.MatchString(pattern, subject)
			if err != nil {
				return nil, fmt.Errorf("bad regular expression: %w", err)
			}
			return matched, nil
		},
	)

	// test_sum shares the package-level finalCalled flag with the tests so
	// they can observe that Final ran.
	MustRegisterFunction("test_sum", &FunctionImpl{
		NArgs:         1,
		Deterministic: true,
		MakeAggregate: func(ctx FunctionContext) (AggregateFunction, error) {
			return &sumFunction{finalCalled: &finalCalled}, nil
		},
	})

	MustRegisterFunction("test_aggregate_error", &FunctionImpl{
		NArgs:         1,
		Deterministic: true,
		MakeAggregate: func(ctx FunctionContext) (AggregateFunction, error) {
			return nil, errors.New("boom")
		},
	})

	MustRegisterFunction("test_aggregate_null_pointer", &FunctionImpl{
		NArgs:         1,
		Deterministic: true,
		MakeAggregate: func(ctx FunctionContext) (AggregateFunction, error) {
			return nil, nil
		},
	})
}
// TestRegisteredFunctions exercises the SQL functions registered in init
// (scalars, aggregates and window functions) and, using the same per-subtest
// in-memory database helper, the connection-level Serialize/Deserialize and
// backup/restore APIs plus context cancellation of QueryContext/ExecContext.
//
// Every subtest reports failures on its own *testing.T (tt): FailNow-style
// methods must be called on the test that is actually running, never on the
// parent t.
func TestRegisteredFunctions(t *testing.T) {
	// withDB opens a fresh in-memory database, resets the package-level
	// finalCalled flag, runs test against the database and closes it
	// afterwards. Open failures are reported on tt.
	withDB := func(tt *testing.T, test func(db *sql.DB)) {
		tt.Helper()

		db, err := sql.Open("sqlite", "file::memory:")
		if err != nil {
			tt.Fatalf("failed to open database: %v", err)
		}
		defer db.Close()

		finalCalled = false
		test(db)
	}

	// isCanceled reports whether a non-nil err is one of the errors the
	// context subtests accept when a statement is aborted because its
	// context ended.
	isCanceled := func(err error) bool {
		return errors.Is(err, context.Canceled) ||
			errors.Is(err, context.DeadlineExceeded) ||
			strings.Contains(err.Error(), "interrupted (9)")
	}

	t.Run("int64", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_int64()")

			var a int
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, 42; g != e {
				tt.Fatal(g, e)
			}
		})
	})

	t.Run("float64", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_float64()")

			var a float64
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, 1e-2; g != e {
				tt.Fatal(g, e)
			}
		})
	})

	t.Run("error", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			_, err := db.Query("select test_error()")
			if err == nil {
				tt.Fatal("expected error, got none")
			}
			if !strings.Contains(err.Error(), "boom") {
				tt.Fatal(err)
			}
		})
	})

	t.Run("empty_byte_slice", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_empty_byte_slice()")

			var a []byte
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if len(a) > 0 {
				tt.Fatal("expected empty byte slice")
			}
		})
	})

	t.Run("nonempty_byte_slice", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_nonempty_byte_slice()")

			var a []byte
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, []byte("abcdefg"); !bytes.Equal(g, e) {
				tt.Fatal(string(g), string(e))
			}
		})
	})

	t.Run("empty_string", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_empty_string()")

			var a string
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if len(a) > 0 {
				tt.Fatal("expected empty string")
			}
		})
	})

	t.Run("nonempty_string", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_nonempty_string()")

			var a string
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, "abcdefg"; g != e {
				tt.Fatal(g, e)
			}
		})
	})

	t.Run("null", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_null()")

			var a any
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if a != nil {
				tt.Fatal("expected nil")
			}
		})
	})

	t.Run("dates", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select yesterday(unixepoch('2018-11-01'))")

			var a int64
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := time.Unix(a, 0), time.Date(2018, time.October, 31, 0, 0, 0, 0, time.UTC); !g.Equal(e) {
				tt.Fatal(g, e)
			}
		})
	})

	t.Run("md5", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select md5('abcdefg')")

			var a string
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, "7ac66c0f148de9519b8bd264312c4d64"; g != e {
				tt.Fatal(g, e)
			}
		})
	})

	t.Run("md5 with blob input", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b blob); insert into t values (?)", []byte("abcdefg")); err != nil {
				tt.Fatal(err)
			}

			row := db.QueryRow("select md5(b) from t")

			var a []byte
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, []byte("7ac66c0f148de9519b8bd264312c4d64"); !bytes.Equal(g, e) {
				tt.Fatal(string(g), string(e))
			}
		})
	})

	t.Run("regexp filter", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			t1 := "seafood"
			t2 := "fruit"
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", t1, t2); err != nil {
				tt.Fatal(err)
			}

			rows, err := db.Query("select * from t where b regexp 'foo.*'")
			if err != nil {
				tt.Fatal(err)
			}
			defer rows.Close()

			type rec struct {
				b string
			}
			var a []rec
			for rows.Next() {
				var r rec
				if err := rows.Scan(&r.b); err != nil {
					tt.Fatal(err)
				}
				a = append(a, r)
			}
			if err := rows.Err(); err != nil {
				tt.Fatal(err)
			}

			if g, e := len(a), 1; g != e {
				tt.Fatal(g, e)
			}
			if g, e := a[0].b, t1; g != e {
				tt.Fatal(g, e)
			}
		})
	})

	t.Run("regexp matches", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select 'seafood' regexp 'foo.*'")

			var r int
			if err := row.Scan(&r); err != nil {
				tt.Fatal(err)
			}
			if g, e := r, 1; g != e {
				tt.Fatal(g, e)
			}
		})
	})

	t.Run("regexp does not match", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select 'fruit' regexp 'foo.*'")

			var r int
			if err := row.Scan(&r); err != nil {
				tt.Fatal(err)
			}
			if g, e := r, 0; g != e {
				tt.Fatal(g, e)
			}
		})
	})

	t.Run("regexp errors on bad regexp", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			_, err := db.Query("select 'seafood' regexp 'a(b'")
			if err == nil {
				tt.Fatal("expected error, got none")
			}
		})
	})

	t.Run("regexp errors on bad first argument", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			_, err := db.Query("SELECT 1 REGEXP 'a(b'")
			if err == nil {
				tt.Fatal("expected error, got none")
			}
		})
	})

	t.Run("regexp errors on bad second argument", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			_, err := db.Query("SELECT 'seafood' REGEXP 1")
			if err == nil {
				tt.Fatal("expected error, got none")
			}
		})
	})

	t.Run("sumFunction type error", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_sum('foo');")

			err := row.Scan()
			if err == nil {
				tt.Fatal("expected error, got none")
			}
			// The error must mention the Go type of the rejected argument.
			if !strings.Contains(err.Error(), "string") {
				tt.Fatal(err)
			}
			if !finalCalled {
				tt.Error("xFinal not called")
			}
		})
	})

	t.Run("sumFunction multiple columns", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(a int64, b int64); insert into t values (1, 5), (2, 6), (3, 7), (4, 8)"); err != nil {
				tt.Fatal(err)
			}

			row := db.QueryRow("select test_sum(a), test_sum(b) from t;")

			var a, b int64
			var e int64 = 10
			var f int64 = 26
			if err := row.Scan(&a, &b); err != nil {
				tt.Fatal(err)
			}
			if a != e {
				tt.Fatal(a, e)
			}
			if b != f {
				tt.Fatal(b, f)
			}
			if !finalCalled {
				tt.Error("xFinal not called")
			}
		})
	})

	// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions
	t.Run("sumFunction as window function", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t3(x, y); insert into t3 values('a', 4), ('b', 5), ('c', 3), ('d', 8), ('e', 1);"); err != nil {
				tt.Fatal(err)
			}

			rows, err := db.Query("select x, test_sum(y) over (order by x rows between 1 preceding and 1 following) as sum_y from t3 order by x;")
			if err != nil {
				tt.Fatal(err)
			}
			defer rows.Close()

			type row struct {
				x    string
				sumY int64
			}

			got := make([]row, 0)
			for rows.Next() {
				var r row
				if err := rows.Scan(&r.x, &r.sumY); err != nil {
					tt.Fatal(err)
				}
				got = append(got, r)
			}
			if err := rows.Err(); err != nil {
				tt.Fatal(err)
			}

			want := []row{
				{"a", 9},
				{"b", 12},
				{"c", 16},
				{"d", 12},
				{"e", 9},
			}

			if len(got) != len(want) {
				tt.Fatal(len(got), len(want))
			}
			for i, g := range got {
				if g != want[i] {
					tt.Fatal(i, g, want[i])
				}
			}
			if !finalCalled {
				tt.Error("xFinal not called")
			}
		})
	})

	t.Run("aggregate function cannot be created", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_aggregate_error(1);")

			err := row.Scan()
			if err == nil {
				tt.Fatal("expected error, got none")
			}
			if !strings.Contains(err.Error(), "boom") {
				tt.Fatal(err)
			}
		})
	})

	t.Run("null aggregate function pointer", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_aggregate_null_pointer(1);")

			err := row.Scan()
			if err == nil {
				tt.Fatal("expected error, got none")
			}
			if !strings.Contains(err.Error(), "MakeAggregate function returned nil") {
				tt.Fatal(err)
			}
		})
	})

	t.Run("serialize and deserialize", func(tt *testing.T) {
		type serializer interface {
			Serialize() ([]byte, error)
			Deserialize([]byte) error
		}

		var buf []byte

		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}

			// Serialize the DB into buf
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()

				err = conn.Raw(func(driverConn any) error {
					var err error
					buf, err = driverConn.(serializer).Serialize()
					return err
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()
		})

		// Deserialize will be executed on a new in-memory DB
		withDB(tt, func(db *sql.DB) {
			// Deserialize buf into the DB
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()

				err = conn.Raw(func(driverConn any) error {
					return driverConn.(serializer).Deserialize(buf)
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()

			// Note sqlite3_deserialize will disconnect all connections from
			// the database, so force to recreate a connection here. It means
			// all connections attached to the db object are now invalid. Do
			// not close it because it will be closed by withDB. It's a bit
			// weird to handle.
			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}

			// Check the table is complete
			row := conn.QueryRowContext(context.Background(), "select count(*) from t")

			var count int
			if err := row.Scan(&count); err != nil {
				tt.Fatal(err)
			}
			if count != 2 {
				tt.Fatal(count, 2)
			}
		})
	})

	t.Run("backup and restore", func(tt *testing.T) {
		type backuper interface {
			NewBackup(string) (*Backup, error)
			NewRestore(string) (*Backup, error)
		}

		tmpDir, err := os.MkdirTemp(os.TempDir(), "storetest_")
		if err != nil {
			tt.Fatal(err)
		}
		defer os.RemoveAll(tmpDir)
		tmpFile := path.Join(tmpDir, "test.db")

		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}

			// Backup the DB to tmpFile
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()

				err = conn.Raw(func(driverConn any) error {
					bck, err := driverConn.(backuper).NewBackup(tmpFile)
					if err != nil {
						return err
					}
					// Step(-1) is repeated until the backup reports that
					// nothing more remains.
					for more := true; more; {
						more, err = bck.Step(-1)
						if err != nil {
							return err
						}
					}
					return bck.Finish()
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()
		})

		// Restore will be executed on a new in-memory DB
		withDB(tt, func(db *sql.DB) {
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()

				err = conn.Raw(func(driverConn any) error {
					bck, err := driverConn.(backuper).NewRestore(tmpFile)
					if err != nil {
						return err
					}
					for more := true; more; {
						more, err = bck.Step(-1)
						if err != nil {
							return err
						}
					}
					return bck.Finish()
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()

			// Check the table is complete
			var count int
			row := db.QueryRow("select count(*) from t")
			if err := row.Scan(&count); err != nil {
				tt.Fatal(err)
			}
			if count != 2 {
				tt.Fatal(count, 2)
			}
		})
	})

	t.Run("backup, commit and close", func(tt *testing.T) {
		type backuper interface {
			NewBackup(string) (*Backup, error)
			NewRestore(string) (*Backup, error)
		}

		var inMemDB driver.Conn

		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}

			// Backup the DB to an in-memory instance
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()

				err = conn.Raw(func(driverConn any) error {
					bck, err := driverConn.(backuper).NewBackup(":memory:")
					if err != nil {
						return err
					}
					for more := true; more; {
						more, err = bck.Step(-1)
						if err != nil {
							return err
						}
					}
					// Commit hands back the destination connection
					// instead of closing it.
					inMemDB, err = bck.Commit()
					return err
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()
		})

		if inMemDB == nil {
			tt.Fatal("in-memory database must not be nil")
		}

		querier, ok := (inMemDB).(interface {
			Query(query string, args []driver.Value) (dr driver.Rows, err error)
		})
		if !ok {
			tt.Fatalf("in-memory DB does not implement a Query method: %T", inMemDB)
		}

		// Check the table is complete
		rows, err := querier.Query("select count(*) from t", nil)
		if err != nil {
			tt.Fatal(err)
		}

		// driver.Rows.Next fills the destination slice with the row's values.
		values := make([]driver.Value, 1)
		if err = rows.Next(values); err != nil {
			tt.Fatal(err)
		}
		if count, ok := values[0].(int64); !ok || count != 2 {
			tt.Fatalf("unexpected row count: got %v (%T), want 2", values[0], values[0])
		}

		// Exactly one row is expected.
		if err = rows.Next(values); !errors.Is(err, io.EOF) {
			tt.Fatal(err)
		}

		if err = rows.Close(); err != nil {
			tt.Fatal(err)
		}
		if err = inMemDB.Close(); err != nil {
			tt.Fatal(err)
		}
	})

	t.Run("QueryContext with context expired", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}

			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			defer conn.Close()

			ctx, cancel := context.WithCancel(context.Background())
			cancel()

			_, err = conn.QueryContext(ctx, "select count(*) from t")
			if err == nil {
				tt.Fatalf("unexpected success while context is canceled")
			}
		})
	})

	t.Run("QueryContext with context expiring", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}

			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			defer conn.Close()

			for try := 0; try < 1000; try++ {
				// Cancel the context while queries are being issued.
				ctx, cancel := context.WithCancel(context.Background())
				go func() {
					time.Sleep(100 * time.Millisecond)
					cancel()
				}()

				for start := time.Now(); time.Since(start) < 200*time.Millisecond; {
					func() {
						rows, err := conn.QueryContext(ctx, "select count(*) from t")
						if rows != nil {
							defer rows.Close()
						}

						if err != nil {
							if !isCanceled(err) {
								tt.Fatalf("unexpected error, was expected context or interrupted error: err=%v", err)
							}
							return
						}

						if rows.Next() {
							return
						}
						// No row: only acceptable when the query was cut
						// short by the context.
						rowsErr := rows.Err()
						if rowsErr == nil {
							tt.Fatalf("success with no data (try=%d)", try)
						}
						if !isCanceled(rowsErr) {
							tt.Fatalf("unexpected error, was expected context or interrupted error: err=%v", rowsErr)
						}
					}()
				}
			}
		})
	})

	t.Run("ExecContext with context expired", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}

			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			defer conn.Close()

			ctx, cancel := context.WithCancel(context.Background())
			cancel()

			_, err = conn.ExecContext(ctx, "insert into t values (?)", "text3")
			if err == nil {
				tt.Fatalf("unexpected success while context is canceled")
			}
			if !isCanceled(err) {
				tt.Fatalf("unexpected error while context is canceled: expected context or interrupted, got %v", err)
			}
		})
	})

	t.Run("ExecContext with context expiring", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}

			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			defer conn.Close()

			for try := 0; try < 1000; try++ {
				// Cancel the context while inserts are being issued.
				ctx, cancel := context.WithCancel(context.Background())
				go func() {
					time.Sleep(100 * time.Millisecond)
					cancel()
				}()

				for start := time.Now(); time.Since(start) < 200*time.Millisecond; {
					res, err := conn.ExecContext(ctx, "insert into t values (?)", "text3")
					if err != nil {
						if !isCanceled(err) {
							tt.Fatalf("unexpected error, was expected context or interrupted error: err=%v", err)
						}
						continue
					}

					rowsAffected, err := res.RowsAffected()
					if err != nil {
						tt.Fatal(err)
					}
					if rowsAffected != 1 {
						tt.Fatalf("success with no inserted data (try=%d)", try)
					}
				}
			}
		})
	})
}