0
0
mirror of https://github.com/tursodatabase/libsql.git synced 2024-12-15 07:29:41 +00:00

83 lines
3.1 KiB
Python

import sys
import numpy as np
def recall_uniform(dim, n, q):
    """Print a SQL script that loads uniform random vectors for a recall run.

    The script creates a ``data`` table with a ``libsql_vector_idx`` index on
    its ``emb`` column and a ``queries`` table, then fills both inside a
    single transaction. Every vector component is drawn with
    ``np.random.uniform(-1, 1)``, i.e. from the half-open interval [-1, 1).

    Args:
        dim: Number of components per vector. Converted with ``int()``, so
            command-line strings are accepted.
        n: Number of rows inserted into ``data``; ids run from 0 to n - 1.
        q: Number of rows inserted into ``queries``.

    Returns:
        None. The script is written to standard output.
    """
    n = int(n)
    q = int(q)
    dim = int(dim)

    print(f'CREATE TABLE data ( id INTEGER PRIMARY KEY, emb FLOAT32({dim}) );')
    print('CREATE INDEX data_idx ON data( libsql_vector_idx(emb) );')
    print(f'CREATE TABLE queries ( emb FLOAT32({dim}) );')
    print('BEGIN TRANSACTION;')

    for i in range(n):
        components = ','.join(map(str, np.random.uniform(-1, 1, size=dim)))
        print(f"INSERT INTO data VALUES ({i}, vector('[{components}]'));")

    for _ in range(q):
        components = ','.join(map(str, np.random.uniform(-1, 1, size=dim)))
        print(f"INSERT INTO queries VALUES (vector('[{components}]'));")

    print('COMMIT;')
    # Trailing '---' line labels the statements above; presumably consumed by
    # the benchmark runner as a phase separator -- confirm against the runner.
    print('---insert everything')
def recall_normal(dim, n, q):
    """Print a SQL script that loads random vectors into unindexed tables.

    The script creates ``data`` and ``queries`` tables (no vector index),
    then fills both inside a single transaction. Vectors are inserted as
    plain text literals of the form ``'[c1,c2,...]'``.

    Every generated vector has exactly ``dim`` components so that it matches
    the ``FLOAT32({dim})`` column declaration (an earlier revision always
    generated 64 components regardless of ``dim``).

    Args:
        dim: Number of components per vector. Converted with ``int()``, so
            command-line strings are accepted.
        n: Number of rows inserted into ``data``; ids run from 0 to n - 1.
        q: Number of rows inserted into ``queries``.

    Returns:
        None. The script is written to standard output.
    """
    n = int(n)
    q = int(q)
    dim = int(dim)

    print(f'CREATE TABLE data ( id INTEGER PRIMARY KEY, emb FLOAT32({dim}) );')
    print(f'CREATE TABLE queries ( emb FLOAT32({dim}) );')
    print('BEGIN TRANSACTION;')

    # NOTE(review): despite the function name, samples are drawn from uniform
    # distributions -- [0, 1) for data and [-1, 1) for queries. Confirm whether
    # np.random.normal (and a shared range) was intended before relying on
    # this workload for recall numbers.
    for i in range(n):
        components = ','.join(map(str, np.random.uniform(size=dim)))
        print(f"INSERT INTO data VALUES ({i}, '[{components}]');")

    for _ in range(q):
        components = ','.join(map(str, np.random.uniform(-1, 1, size=dim)))
        print(f"INSERT INTO queries VALUES ('[{components}]');")

    print('COMMIT;')
    # Trailing '---' line labels the statements above; presumably consumed by
    # the benchmark runner as a phase separator -- confirm against the runner.
    print('---insert everything')
def no_vectors(n, q, dim=64):
    """Print a SQL script with plain TEXT payloads and primary-key lookups.

    The script enables WAL journaling, creates table ``x`` with a TEXT
    ``value`` column, inserts ``n`` rows whose value is the text form of a
    random vector (``'[c1,c2,...]'``), and then issues ``q`` point lookups by
    a randomly chosen id. No vector column type or vector function is used.

    Args:
        n: Number of rows to insert; ids run from 0 to n - 1. Converted with
            ``int()``, so command-line strings are accepted.
        q: Number of ``SELECT`` statements to emit.
        dim: Number of components in each text payload. Defaults to 64.

    Returns:
        None. The script is written to standard output.
    """
    n = int(n)
    q = int(q)
    dim = int(dim)

    print('PRAGMA journal_mode=WAL;')
    print('CREATE TABLE x ( id INTEGER PRIMARY KEY, value TEXT );')

    for i in range(n):
        components = ','.join(map(str, np.random.uniform(-1, 1, size=dim)))
        print(f"INSERT INTO x VALUES ({i}, '[{components}]');")
    print('---inserts')

    # np.random.randint(0) raises ValueError, so with an empty table the
    # lookups fall back to id 0 (which simply matches no row).
    id_bound = max(n, 1)
    for _ in range(q):
        print(f'SELECT id, value FROM x WHERE id = {np.random.randint(id_bound)};')
    print('---search')
def bruteforce(dim, n, q):
    """Print a SQL script that searches vectors without a vector index.

    The script enables WAL journaling, creates table ``x`` with a
    ``FLOAT32(dim)`` ``embedding`` column, inserts ``n`` random vectors, and
    then emits ``q`` queries that order the whole table by
    ``vector_distance_cos`` to a random query vector and keep the top row.
    Every vector component is drawn with ``np.random.uniform(-1, 1)``.

    Args:
        dim: Number of components per vector. Converted with ``int()``, so
            command-line strings are accepted.
        n: Number of rows to insert; ids run from 0 to n - 1.
        q: Number of ``SELECT`` statements to emit.

    Returns:
        None. The script is written to standard output.
    """
    dim = int(dim)
    n = int(n)
    q = int(q)

    print('PRAGMA journal_mode=WAL;')
    print(f'CREATE TABLE x ( id INTEGER PRIMARY KEY, embedding FLOAT32({dim}) );')

    for i in range(n):
        components = ','.join(map(str, np.random.uniform(-1, 1, size=dim)))
        print(f"INSERT INTO x VALUES ({i}, vector('[{components}]'));")
    print('---inserts')

    for _ in range(q):
        components = ','.join(map(str, np.random.uniform(-1, 1, size=dim)))
        print(
            "SELECT id FROM x ORDER BY vector_distance_cos(embedding, "
            f"vector('[{components}]')) LIMIT 1;"
        )
    print('---search')
def diskann(dim, n, q):
    """Print a SQL script that searches vectors through a vector index.

    The script enables WAL journaling, creates table ``x`` with a
    ``FLOAT32(dim)`` ``embedding`` column and a ``libsql_vector_idx`` index
    named ``x_idx`` on it, inserts ``n`` random vectors, and then emits ``q``
    ``vector_top_k('x_idx', ..., 1)`` queries with random query vectors.
    Every vector component is drawn with ``np.random.uniform(-1, 1)``.

    Args:
        dim: Number of components per vector. Converted with ``int()``, so
            command-line strings are accepted.
        n: Number of rows to insert; ids run from 0 to n - 1.
        q: Number of ``SELECT`` statements to emit.

    Returns:
        None. The script is written to standard output.
    """
    dim = int(dim)
    n = int(n)
    q = int(q)

    print('PRAGMA journal_mode=WAL;')
    print(f'CREATE TABLE x ( id INTEGER PRIMARY KEY, embedding FLOAT32({dim}) );')
    # The index is created before the inserts, so the emitted insert phase
    # runs against an already-indexed table.
    print('CREATE INDEX x_idx ON x( libsql_vector_idx(embedding) );')

    for i in range(n):
        components = ','.join(map(str, np.random.uniform(-1, 1, size=dim)))
        print(f"INSERT INTO x VALUES ({i}, vector('[{components}]'));")
    print('---inserts')

    for _ in range(q):
        components = ','.join(map(str, np.random.uniform(-1, 1, size=dim)))
        print(f"SELECT id FROM vector_top_k('x_idx', vector('[{components}]'), 1);")
    print('---search')
if __name__ == '__main__':
    # Command-line entry point: `<script> <workload> [args...]`.
    # The first argument names one of the workload functions defined in this
    # module; the remaining arguments are passed to it positionally as strings
    # (each workload converts them with int()).
    #
    # Only public callables defined in this module are accepted, so a missing
    # or unknown name produces a usage message instead of an IndexError /
    # KeyError traceback, and imported modules such as `sys` cannot be picked.
    # list(...) snapshots globals() so the dict is not iterated while changing.
    workloads = {
        name: obj
        for name, obj in list(globals().items())
        if callable(obj)
        and getattr(obj, '__module__', None) == __name__
        and not name.startswith('_')
    }
    if len(sys.argv) < 2 or sys.argv[1] not in workloads:
        sys.exit(f"usage: {sys.argv[0]} <{'|'.join(sorted(workloads))}> [args...]")
    workloads[sys.argv[1]](*sys.argv[2:])