0
0
mirror of https://gitlab.com/cznic/sqlite.git synced 2025-04-27 23:07:44 +00:00
Files
go-sqlite/functest/func_test.go
Romain LE DISEZ 0a8da4bb8f fix error handling in test "QueryContext with context expiring"
Even when the query itself succeeds, the context can expire right after. In
this situation, iterating the rows will fail. The test needs to check
sql.Rows.Err() too.

Also, don't forget to close sql.Rows when done to avoid a memory leak.
2024-02-22 13:31:53 +00:00

873 lines
20 KiB
Go

// Copyright 2022 The Sqlite Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package functest // modernc.org/sqlite/functest
import (
"bytes"
"context"
"crypto/md5"
"database/sql"
"database/sql/driver"
"encoding/hex"
"errors"
"fmt"
"os"
"path"
"regexp"
"strings"
"testing"
"time"
sqlite3 "modernc.org/sqlite"
)
// finalCalled records whether sumFunction.Final has run. The test_sum
// aggregate registered in init hands its address to every sumFunction it
// creates, and withDB in TestRegisteredFunctions resets it to false before
// each test body.
var finalCalled bool
// sumFunction is the aggregate state behind the test_sum SQL function. It
// keeps a running int64 total that Step increases and WindowInverse
// decreases.
type sumFunction struct {
// sum is the running total of the int64 arguments seen so far.
sum int64
// finalCalled points at the flag that Final sets to true.
finalCalled *bool
}
// Step adds the first argument of the current row to the running sum.
//
// Only int64 arguments are accepted; any other type (including NULL, which
// arrives as nil) yields an error naming the offending Go type.
func (f *sumFunction) Step(ctx *sqlite3.FunctionContext, args []driver.Value) error {
	v, ok := args[0].(int64)
	if !ok {
		// The message reports the argument type: the failure is about the
		// input handed to the aggregate, not about a value it returned.
		return fmt.Errorf("expected argument to be int64, got: %T", args[0])
	}
	f.sum += v
	return nil
}
// WindowInverse subtracts the first argument from the running sum, undoing
// the effect of an earlier Step for the same value.
//
// Only int64 arguments are accepted; any other type (including NULL, which
// arrives as nil) yields an error naming the offending Go type.
func (f *sumFunction) WindowInverse(ctx *sqlite3.FunctionContext, args []driver.Value) error {
	v, ok := args[0].(int64)
	if !ok {
		// The message reports the argument type: the failure is about the
		// input handed to the aggregate, not about a value it returned.
		return fmt.Errorf("expected argument to be int64, got: %T", args[0])
	}
	f.sum -= v
	return nil
}
// WindowValue returns the running sum as it currently stands, without
// modifying it. It never fails.
func (f *sumFunction) WindowValue(ctx *sqlite3.FunctionContext) (driver.Value, error) {
	var current driver.Value = f.sum
	return current, nil
}
// Final marks the aggregate as finished by setting the flag that
// f.finalCalled points to; the tests read that flag to verify Final ran.
func (f *sumFunction) Final(ctx *sqlite3.FunctionContext) {
	done := f.finalCalled
	*done = true
}
// init registers the scalar and aggregate SQL functions that
// TestRegisteredFunctions exercises. Registration goes through the Must*
// helpers of the sqlite3 package.
func init() {
	// Zero-argument functions, one per result type the tests scan back.
	sqlite3.MustRegisterDeterministicScalarFunction(
		"test_int64",
		0,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			return int64(42), nil
		},
	)
	sqlite3.MustRegisterDeterministicScalarFunction(
		"test_float64",
		0,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			return float64(1e-2), nil
		},
	)
	sqlite3.MustRegisterDeterministicScalarFunction(
		"test_null",
		0,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			return nil, nil
		},
	)
	sqlite3.MustRegisterDeterministicScalarFunction(
		"test_error",
		0,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			return nil, errors.New("boom")
		},
	)
	sqlite3.MustRegisterDeterministicScalarFunction(
		"test_empty_byte_slice",
		0,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			return []byte{}, nil
		},
	)
	sqlite3.MustRegisterDeterministicScalarFunction(
		"test_nonempty_byte_slice",
		0,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			return []byte("abcdefg"), nil
		},
	)
	sqlite3.MustRegisterDeterministicScalarFunction(
		"test_empty_string",
		0,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			return "", nil
		},
	)
	sqlite3.MustRegisterDeterministicScalarFunction(
		"test_nonempty_string",
		0,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			return "abcdefg", nil
		},
	)

	// yesterday(unixSeconds) returns the instant 24 hours before the given
	// Unix timestamp.
	sqlite3.MustRegisterDeterministicScalarFunction(
		"yesterday",
		1,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			secs, ok := args[0].(int64)
			if !ok {
				return nil, fmt.Errorf("expected argument to be int64, got: %T", args[0])
			}
			return time.Unix(secs, 0).Add(-24 * time.Hour), nil
		},
	)

	// md5(x) returns the lowercase hex MD5 digest of a text or blob argument.
	sqlite3.MustRegisterDeterministicScalarFunction(
		"md5",
		1,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			var data []byte
			switch argTyped := args[0].(type) {
			case string:
				data = []byte(argTyped)
			case []byte:
				data = argTyped
			default:
				return nil, fmt.Errorf("expected argument to be a string, got: %T", argTyped)
			}
			digest := md5.Sum(data)
			return hex.EncodeToString(digest[:]), nil
		},
	)

	// regexp(pattern, text) backs the SQL "text REGEXP pattern" operator:
	// args[0] is used as the regular expression and args[1] as the subject.
	sqlite3.MustRegisterDeterministicScalarFunction(
		"regexp",
		2,
		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
			pattern, ok := args[0].(string)
			if !ok {
				return nil, errors.New("expected argv[0] to be text")
			}
			subject, ok := args[1].(string)
			if !ok {
				return nil, errors.New("expected argv[1] to be text")
			}
			matched, err := regexp.MatchString(pattern, subject)
			if err != nil {
				return nil, fmt.Errorf("bad regular expression: %q", err)
			}
			return matched, nil
		},
	)

	// test_sum is an aggregate whose state is a sumFunction wired to the
	// package-level finalCalled flag.
	sqlite3.MustRegisterFunction("test_sum", &sqlite3.FunctionImpl{
		NArgs:         1,
		Deterministic: true,
		MakeAggregate: func(ctx sqlite3.FunctionContext) (sqlite3.AggregateFunction, error) {
			return &sumFunction{finalCalled: &finalCalled}, nil
		},
	})

	// The next two aggregates deliberately fail at construction time: one
	// returns an error, the other returns a nil aggregate with no error.
	sqlite3.MustRegisterFunction("test_aggregate_error", &sqlite3.FunctionImpl{
		NArgs:         1,
		Deterministic: true,
		MakeAggregate: func(ctx sqlite3.FunctionContext) (sqlite3.AggregateFunction, error) {
			return nil, errors.New("boom")
		},
	})
	sqlite3.MustRegisterFunction("test_aggregate_null_pointer", &sqlite3.FunctionImpl{
		NArgs:         1,
		Deterministic: true,
		MakeAggregate: func(ctx sqlite3.FunctionContext) (sqlite3.AggregateFunction, error) {
			return nil, nil
		},
	})
}
// TestRegisteredFunctions exercises the SQL functions registered in init,
// plus serialize/deserialize, backup/restore and context cancellation on
// connections. Every subtest runs against a fresh in-memory database.
func TestRegisteredFunctions(t *testing.T) {
	// withDB opens a fresh in-memory database, resets finalCalled, runs test
	// and closes the database afterwards. Failures are reported on tt, the
	// subtest that is actually running, never on the parent test.
	withDB := func(tt *testing.T, test func(db *sql.DB)) {
		tt.Helper()
		db, err := sql.Open("sqlite", "file::memory:")
		if err != nil {
			tt.Fatalf("failed to open database: %v", err)
		}
		defer db.Close()
		finalCalled = false
		test(db)
	}

	// isCtxOrInterrupted reports whether err is one of the errors accepted
	// when a statement races with context cancellation.
	isCtxOrInterrupted := func(err error) bool {
		return errors.Is(err, context.Canceled) ||
			errors.Is(err, context.DeadlineExceeded) ||
			strings.Contains(err.Error(), "interrupted (9)")
	}

	t.Run("int64", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_int64()")
			var a int
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, 42; g != e {
				tt.Fatal(g, e)
			}
		})
	})
	t.Run("float64", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_float64()")
			var a float64
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, 1e-2; g != e {
				tt.Fatal(g, e)
			}
		})
	})
	t.Run("error", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			_, err := db.Query("select test_error()")
			if err == nil {
				tt.Fatal("expected error, got none")
			}
			if !strings.Contains(err.Error(), "boom") {
				tt.Fatal(err)
			}
		})
	})
	t.Run("empty_byte_slice", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_empty_byte_slice()")
			var a []byte
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if len(a) > 0 {
				tt.Fatal("expected empty byte slice")
			}
		})
	})
	t.Run("nonempty_byte_slice", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_nonempty_byte_slice()")
			var a []byte
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, []byte("abcdefg"); !bytes.Equal(g, e) {
				tt.Fatal(string(g), string(e))
			}
		})
	})
	t.Run("empty_string", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_empty_string()")
			var a string
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if len(a) > 0 {
				tt.Fatal("expected empty string")
			}
		})
	})
	t.Run("nonempty_string", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_nonempty_string()")
			var a string
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, "abcdefg"; g != e {
				tt.Fatal(g, e)
			}
		})
	})
	t.Run("null", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_null()")
			var a interface{}
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if a != nil {
				tt.Fatal("expected nil")
			}
		})
	})
	t.Run("dates", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select yesterday(unixepoch('2018-11-01'))")
			var a int64
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := time.Unix(a, 0), time.Date(2018, time.October, 31, 0, 0, 0, 0, time.UTC); !g.Equal(e) {
				tt.Fatal(g, e)
			}
		})
	})
	t.Run("md5", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select md5('abcdefg')")
			var a string
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, "7ac66c0f148de9519b8bd264312c4d64"; g != e {
				tt.Fatal(g, e)
			}
		})
	})
	t.Run("md5 with blob input", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b blob); insert into t values (?)", []byte("abcdefg")); err != nil {
				tt.Fatal(err)
			}
			row := db.QueryRow("select md5(b) from t")
			var a []byte
			if err := row.Scan(&a); err != nil {
				tt.Fatal(err)
			}
			if g, e := a, []byte("7ac66c0f148de9519b8bd264312c4d64"); !bytes.Equal(g, e) {
				tt.Fatal(string(g), string(e))
			}
		})
	})
	t.Run("regexp filter", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			t1 := "seafood"
			t2 := "fruit"
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", t1, t2); err != nil {
				tt.Fatal(err)
			}
			rows, err := db.Query("select * from t where b regexp 'foo.*'")
			if err != nil {
				tt.Fatal(err)
			}
			// Release the result set even when an assertion below fails.
			defer rows.Close()
			type rec struct {
				b string
			}
			var a []rec
			for rows.Next() {
				var r rec
				if err := rows.Scan(&r.b); err != nil {
					tt.Fatal(err)
				}
				a = append(a, r)
			}
			if err := rows.Err(); err != nil {
				tt.Fatal(err)
			}
			if g, e := len(a), 1; g != e {
				tt.Fatal(g, e)
			}
			if g, e := a[0].b, t1; g != e {
				tt.Fatal(g, e)
			}
		})
	})
	t.Run("regexp matches", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select 'seafood' regexp 'foo.*'")
			var r int
			if err := row.Scan(&r); err != nil {
				tt.Fatal(err)
			}
			if g, e := r, 1; g != e {
				tt.Fatal(g, e)
			}
		})
	})
	t.Run("regexp does not match", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select 'fruit' regexp 'foo.*'")
			var r int
			if err := row.Scan(&r); err != nil {
				tt.Fatal(err)
			}
			if g, e := r, 0; g != e {
				tt.Fatal(g, e)
			}
		})
	})
	t.Run("regexp errors on bad regexp", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			_, err := db.Query("select 'seafood' regexp 'a(b'")
			if err == nil {
				tt.Fatal(errors.New("expected error, got none"))
			}
		})
	})
	t.Run("regexp errors on bad first argument", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			_, err := db.Query("SELECT 1 REGEXP 'a(b'")
			if err == nil {
				tt.Fatal(errors.New("expected error, got none"))
			}
		})
	})
	t.Run("regexp errors on bad second argument", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			_, err := db.Query("SELECT 'seafood' REGEXP 1")
			if err == nil {
				tt.Fatal(errors.New("expected error, got none"))
			}
		})
	})
	t.Run("sumFunction type error", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_sum('foo');")
			err := row.Scan()
			if err == nil {
				tt.Fatal("expected error, got none")
			}
			// The error is expected to name the rejected Go type.
			if !strings.Contains(err.Error(), "string") {
				tt.Fatal(err)
			}
			if !finalCalled {
				tt.Error("xFinal not called")
			}
		})
	})
	t.Run("sumFunction multiple columns", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(a int64, b int64); insert into t values (1, 5), (2, 6), (3, 7), (4, 8)"); err != nil {
				tt.Fatal(err)
			}
			row := db.QueryRow("select test_sum(a), test_sum(b) from t;")
			var a, b int64
			var e int64 = 10
			var f int64 = 26
			if err := row.Scan(&a, &b); err != nil {
				tt.Fatal(err)
			}
			if a != e {
				tt.Fatal(a, e)
			}
			if b != f {
				tt.Fatal(b, f)
			}
			if !finalCalled {
				tt.Error("xFinal not called")
			}
		})
	})
	// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions
	t.Run("sumFunction as window function", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t3(x, y); insert into t3 values('a', 4), ('b', 5), ('c', 3), ('d', 8), ('e', 1);"); err != nil {
				tt.Fatal(err)
			}
			rows, err := db.Query("select x, test_sum(y) over (order by x rows between 1 preceding and 1 following) as sum_y from t3 order by x;")
			if err != nil {
				tt.Fatal(err)
			}
			defer rows.Close()
			type row struct {
				x    string
				sumY int64
			}
			got := make([]row, 0)
			for rows.Next() {
				var r row
				if err := rows.Scan(&r.x, &r.sumY); err != nil {
					tt.Fatal(err)
				}
				got = append(got, r)
			}
			// A failed iteration must not be mistaken for a short result.
			if err := rows.Err(); err != nil {
				tt.Fatal(err)
			}
			want := []row{
				{"a", 9},
				{"b", 12},
				{"c", 16},
				{"d", 12},
				{"e", 9},
			}
			if len(got) != len(want) {
				tt.Fatal(len(got), len(want))
			}
			for i, g := range got {
				if g != want[i] {
					tt.Fatal(i, g, want[i])
				}
			}
			if !finalCalled {
				tt.Error("xFinal not called")
			}
		})
	})
	t.Run("aggregate function cannot be created", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_aggregate_error(1);")
			err := row.Scan()
			if err == nil {
				tt.Fatal("expected error, got none")
			}
			if !strings.Contains(err.Error(), "boom") {
				tt.Fatal(err)
			}
		})
	})
	t.Run("null aggregate function pointer", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			row := db.QueryRow("select test_aggregate_null_pointer(1);")
			err := row.Scan()
			if err == nil {
				tt.Fatal("expected error, got none")
			}
			if !strings.Contains(err.Error(), "MakeAggregate function returned nil") {
				tt.Fatal(err)
			}
		})
	})
	t.Run("serialize and deserialize", func(tt *testing.T) {
		type serializer interface {
			Serialize() ([]byte, error)
			Deserialize([]byte) error
		}
		var buf []byte
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}
			// Serialize the DB into buf
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()
				err = conn.Raw(func(driverConn any) error {
					var err error
					buf, err = driverConn.(serializer).Serialize()
					return err
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()
		})
		// Deserialize will be executed on a new in-memory DB
		withDB(tt, func(db *sql.DB) {
			// Deserialize buf into the DB
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()
				err = conn.Raw(func(driverConn any) error {
					return driverConn.(serializer).Deserialize(buf)
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()
			// Note sqlite3_deserialize will disconnect all connections from
			// the database, so force to recreate a connection here. It means
			// all connections attached to the db object are now invalid. Do
			// not close it because it will be closed by withDB. It's a bit
			// weird to handle.
			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			// Check the table is complete
			row := conn.QueryRowContext(context.Background(), "select count(*) from t")
			var count int
			if err := row.Scan(&count); err != nil {
				tt.Fatal(err)
			}
			if count != 2 {
				tt.Fatal(count, 2)
			}
		})
	})
	t.Run("backup and restore", func(tt *testing.T) {
		type backuper interface {
			NewBackup(string) (*sqlite3.Backup, error)
			NewRestore(string) (*sqlite3.Backup, error)
		}
		tmpDir, err := os.MkdirTemp(os.TempDir(), "storetest_")
		if err != nil {
			tt.Fatal(err)
		}
		defer os.RemoveAll(tmpDir)
		tmpFile := path.Join(tmpDir, "test.db")
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}
			// Backup the DB to tmpFile
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()
				err = conn.Raw(func(driverConn any) error {
					bck, err := driverConn.(backuper).NewBackup(tmpFile)
					if err != nil {
						return err
					}
					// Step(-1) is repeated until it reports no more work.
					for more := true; more; {
						more, err = bck.Step(-1)
						if err != nil {
							return err
						}
					}
					return bck.Finish()
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()
		})
		// Restore will be executed on a new in-memory DB
		withDB(tt, func(db *sql.DB) {
			func() {
				conn, err := db.Conn(context.Background())
				if err != nil {
					tt.Fatal(err)
				}
				defer conn.Close()
				err = conn.Raw(func(driverConn any) error {
					bck, err := driverConn.(backuper).NewRestore(tmpFile)
					if err != nil {
						return err
					}
					for more := true; more; {
						more, err = bck.Step(-1)
						if err != nil {
							return err
						}
					}
					return bck.Finish()
				})
				if err != nil {
					tt.Fatal(err)
				}
			}()
			// Check the table is complete
			var count int
			row := db.QueryRow("select count(*) from t")
			if err := row.Scan(&count); err != nil {
				tt.Fatal(err)
			}
			if count != 2 {
				tt.Fatal(count, 2)
			}
		})
	})
	t.Run("QueryContext with context expired", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}
			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			defer conn.Close()
			ctx, cancel := context.WithCancel(context.Background())
			cancel()
			rows, err := conn.QueryContext(ctx, "select count(*) from t")
			if err == nil {
				// Close the unexpected result set before failing so the
				// deferred conn.Close is not left waiting on open rows.
				rows.Close()
				tt.Fatalf("unexpected success while context is canceled")
			}
		})
	})
	t.Run("QueryContext with context expiring", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}
			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			defer conn.Close()
			for try := 0; try < 1000; try++ {
				ctx, cancel := context.WithCancel(context.Background())
				// Cancel while the query loop below is still running.
				go func() {
					time.Sleep(100 * time.Millisecond)
					cancel()
				}()
				for start := time.Now(); time.Since(start) < 200*time.Millisecond; {
					// Each query runs in its own function so that the
					// deferred rows.Close fires once per iteration.
					func() {
						rows, err := conn.QueryContext(ctx, "select count(*) from t")
						if err != nil {
							if !isCtxOrInterrupted(err) {
								tt.Fatalf("unexpected error, was expected context or interrupted error: err=%v", err)
							}
							return
						}
						defer rows.Close()
						if rows.Next() {
							return
						}
						// The query succeeded but produced no row: that is
						// only acceptable when the context expired right
						// after the query started.
						rowsErr := rows.Err()
						switch {
						case rowsErr == nil:
							tt.Fatalf("success with no data (try=%d)", try)
						case !isCtxOrInterrupted(rowsErr):
							tt.Fatalf("unexpected error, was expected context or interrupted error: err=%v", rowsErr)
						}
					}()
				}
			}
		})
	})
	t.Run("ExecContext with context expired", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}
			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			defer conn.Close()
			ctx, cancel := context.WithCancel(context.Background())
			cancel()
			_, err = conn.ExecContext(ctx, "insert into t values (?)", "text3")
			if err == nil {
				tt.Fatalf("unexpected success while context is canceled")
			}
			if !isCtxOrInterrupted(err) {
				tt.Fatalf("unexpected error while context is canceled: expected context or interrupted, got %v", err)
			}
		})
	})
	t.Run("ExecContext with context expiring", func(tt *testing.T) {
		withDB(tt, func(db *sql.DB) {
			if _, err := db.Exec("create table t(b text); insert into t values (?), (?)", "text1", "text2"); err != nil {
				tt.Fatal(err)
			}
			conn, err := db.Conn(context.Background())
			if err != nil {
				tt.Fatal(err)
			}
			defer conn.Close()
			for try := 0; try < 1000; try++ {
				ctx, cancel := context.WithCancel(context.Background())
				// Cancel while the insert loop below is still running.
				go func() {
					time.Sleep(100 * time.Millisecond)
					cancel()
				}()
				for start := time.Now(); time.Since(start) < 200*time.Millisecond; {
					res, err := conn.ExecContext(ctx, "insert into t values (?)", "text3")
					if err != nil {
						if !isCtxOrInterrupted(err) {
							tt.Fatalf("unexpected error, was expected context or interrupted error: err=%v", err)
						}
						continue
					}
					// A successful insert must report exactly one row.
					rowsAffected, err := res.RowsAffected()
					if err != nil {
						tt.Fatal(err)
					}
					if rowsAffected != 1 {
						tt.Fatalf("success with no inserted data (try=%d)", try)
					}
				}
			}
		})
	})
}