mirror of
https://github.com/tursodatabase/libsql.git
synced 2025-09-20 23:49:47 +00:00
Merge pull request #1603 from tursodatabase/vector-search-more-tests
add more tests for vector feature
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3459,6 +3459,7 @@ dependencies = [
|
||||
"libsql_replication",
|
||||
"parking_lot",
|
||||
"pprof",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
|
@@ -230,6 +230,26 @@ do_execsql_test vector-vacuum {
|
||||
SELECT COUNT(*) FROM t_vacuum_idx_shadow;
|
||||
} {2 2}
|
||||
|
||||
do_execsql_test vector-many-columns {
|
||||
CREATE TABLE t_many ( i INTEGER PRIMARY KEY, e1 FLOAT32(2), e2 FLOAT32(2) );
|
||||
CREATE INDEX t_many_1_idx ON t_many(libsql_vector_idx(e1));
|
||||
CREATE INDEX t_many_2_idx ON t_many(libsql_vector_idx(e2));
|
||||
INSERT INTO t_many VALUES (1, vector('[1,1]'), vector('[-1,-1]')), (2, vector('[-1,-1]'), vector('[1,1]'));
|
||||
SELECT * FROM vector_top_k('t_many_1_idx', vector('[1,1]'), 2);
|
||||
SELECT * FROM vector_top_k('t_many_2_idx', vector('[1,1]'), 2);
|
||||
} {1 2 2 1}
|
||||
|
||||
do_execsql_test vector-transaction {
|
||||
CREATE TABLE t_transaction ( i INTEGER PRIMARY KEY, e FLOAT32(2) );
|
||||
CREATE INDEX t_transaction_idx ON t_transaction(libsql_vector_idx(e));
|
||||
INSERT INTO t_transaction VALUES (1, vector('[1,2]')), (2, vector('[3,4]'));
|
||||
BEGIN;
|
||||
INSERT INTO t_transaction VALUES (3, vector('[4,5]')), (4, vector('[5,6]'));
|
||||
SELECT * FROM vector_top_k('t_transaction_idx', vector('[4,5]'), 2);
|
||||
ROLLBACK;
|
||||
SELECT * FROM vector_top_k('t_transaction_idx', vector('[1,2]'), 2);
|
||||
} {3 4 1 2}
|
||||
|
||||
proc error_messages {sql} {
|
||||
set ret ""
|
||||
catch {
|
||||
|
@@ -50,6 +50,7 @@ tokio = { version = "1.29.1", features = ["full"] }
|
||||
tokio-test = "0.4"
|
||||
tracing-subscriber = "0.3"
|
||||
tempfile = { version = "3.7.0" }
|
||||
rand = "0.8.5"
|
||||
|
||||
[features]
|
||||
default = ["core", "replication", "remote"]
|
||||
|
@@ -6,6 +6,9 @@ use libsql::{
|
||||
params::{IntoParams, IntoValue},
|
||||
Connection, Database, Value,
|
||||
};
|
||||
use rand::distributions::Uniform;
|
||||
use rand::prelude::*;
|
||||
use std::collections::HashSet;
|
||||
|
||||
async fn setup() -> Connection {
|
||||
let db = Database::open(":memory:").unwrap();
|
||||
@@ -650,3 +653,102 @@ async fn deserialize_row() {
|
||||
assert_eq!(data.status, Status::Draft);
|
||||
assert_eq!(data.wrapper, Wrapper(Status::Published));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
// fuzz test can be run explicitly with following command:
|
||||
// cargo test vector_fuzz_test -- --nocapture --include-ignored
|
||||
async fn vector_fuzz_test() {
|
||||
let mut global_rng = rand::thread_rng();
|
||||
for attempt in 0..10000 {
|
||||
let seed = global_rng.next_u64();
|
||||
|
||||
let mut rng =
|
||||
rand::rngs::StdRng::from_seed(unsafe { std::mem::transmute([seed, seed, seed, seed]) });
|
||||
let db = Database::open(":memory:").unwrap();
|
||||
let conn = db.connect().unwrap();
|
||||
let dim = rng.gen_range(1..=1536);
|
||||
let operations = rng.gen_range(1..128);
|
||||
println!(
|
||||
"============== ATTEMPT {} (seed {}u64, dim {}, operations {}) ================",
|
||||
attempt, seed, dim, operations
|
||||
);
|
||||
|
||||
let _ = conn
|
||||
.execute(
|
||||
&format!(
|
||||
"CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) )",
|
||||
dim
|
||||
),
|
||||
(),
|
||||
)
|
||||
.await;
|
||||
// println!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) );", dim);
|
||||
let _ = conn
|
||||
.execute(
|
||||
"CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );",
|
||||
(),
|
||||
)
|
||||
.await;
|
||||
// println!("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );");
|
||||
|
||||
let mut next_id = 1;
|
||||
let mut alive = HashSet::new();
|
||||
let uniform = Uniform::new(-1.0, 1.0);
|
||||
for _ in 0..operations {
|
||||
let operation = rng.gen_range(0..4);
|
||||
let vector: Vec<f32> = (0..dim).map(|_| rng.sample(uniform)).collect();
|
||||
let vector_str = format!(
|
||||
"[{}]",
|
||||
vector
|
||||
.iter()
|
||||
.map(|x| format!("{}", x))
|
||||
.collect::<Vec<String>>()
|
||||
.join(",")
|
||||
);
|
||||
if operation == 0 {
|
||||
// println!("INSERT INTO users VALUES ({}, vector('{}') );", next_id, vector_str);
|
||||
conn.execute(
|
||||
"INSERT INTO users VALUES (?, vector(?) )",
|
||||
libsql::params![next_id, vector_str],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
alive.insert(next_id);
|
||||
next_id += 1;
|
||||
} else if operation == 1 {
|
||||
let id = rng.gen_range(0..next_id);
|
||||
// println!("DELETE FROM users WHERE id = {};", id);
|
||||
conn.execute("DELETE FROM users WHERE id = ?", libsql::params![id])
|
||||
.await
|
||||
.unwrap();
|
||||
alive.remove(&id);
|
||||
} else if operation == 2 && !alive.is_empty() {
|
||||
let id = alive.iter().collect::<Vec<_>>()[rng.gen_range(0..alive.len())];
|
||||
// println!("UPDATE users SET v = vector('{}') WHERE id = {};", vector_str, id);
|
||||
conn.execute(
|
||||
"UPDATE users SET v = vector(?) WHERE id = ?",
|
||||
libsql::params![vector_str, id],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
} else if operation == 3 {
|
||||
let k = rng.gen_range(1..200);
|
||||
// println!("SELECT * FROM vector_top_k('users_idx', '{}', {});", vector_str, k);
|
||||
let result = conn
|
||||
.query(
|
||||
"SELECT * FROM vector_top_k('users_idx', ?, ?)",
|
||||
libsql::params![vector_str, k],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let count = result.into_stream().count().await;
|
||||
assert!(count <= alive.len());
|
||||
if alive.len() > 0 {
|
||||
assert!(count > 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = conn.execute("REINDEX users;", ()).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user