0
0
mirror of https://github.com/tursodatabase/libsql.git synced 2024-12-15 07:29:41 +00:00
libsql/libsql-sqlite3/test/libsql_vector.test
2024-08-22 12:32:49 +04:00

242 lines
15 KiB
Plaintext

# 2024-07-04
#
# Copyright 2024 the libSQL authors
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
#***********************************************************************
# This file implements regression tests for the libSQL library. The
# focus of this file is vector search.
# Resolve the directory that holds this script and load the shared test
# harness (tester.tcl) from that same directory.
set testdir [file dirname $argv0]
source $testdir/tester.tcl
# NOTE(review): testprefix is presumably consumed by tester.tcl when naming
# tests -- confirm against the harness.
set testprefix vector
# Create table t1 with a single column declared as FLOAT32(3) and insert three
# 3-component vectors at rowids 1, 2 and 3. The expected result is empty: none
# of these statements is expected to return rows.
do_execsql_test vector-1-inserts {
CREATE TABLE t1( xv FLOAT32(3) );
INSERT INTO t1(rowid,xv) VALUES(1, vector('[1,2,3]'));
INSERT INTO t1(rowid,xv) VALUES(2, vector('[2,3,4]'));
INSERT INTO t1(rowid,xv) VALUES(3, vector('[5,6,7]'));
} {}
# Exercise the vector functions on valid inputs. The expected values in the
# second braced list are positional: one value per SELECT, in statement order
# (37 statements, 37 values), so statements and results must be edited together.
#
# Layout of the statements:
#   - parsing/extraction: vector()/vector_extract() on text and blob inputs,
#     including empty vectors and extra whitespace;
#   - hex() of vector(), vector32() and vector64() for the same literal;
#   - vector_extract() on raw blobs with and without a trailing byte
#     (NOTE(review): the trailing 01/02 byte looks like a type tag -- confirm);
#   - vector_distance_cos() on text vectors and on vector1bit inputs;
#   - vector_distance_cos() in groups of four (vector8, vector16, vectorb16,
#     vector32) over the same inputs; each group maps to one expected row;
#   - vector_distance_l2() for vector, vector8, vector16 and vectorb16.
#
# The vector16 case with +/-1000000 components is expected to yield an empty
# value ({}) -- presumably a NULL caused by half-precision overflow; verify.
do_execsql_test vector-1-func-valid {
SELECT vector_extract(vector('[]'));
SELECT vector_extract(vector(x''));
SELECT vector_extract(vector(' [ 1 , 2 , 3 ] '));
SELECT vector_extract(vector('[-1000000000000000000]'));
SELECT hex(vector('[1.10101010101010101010101010]'));
SELECT hex(vector32('[1.10101010101010101010101010]'));
SELECT hex(vector64('[1.10101010101010101010101010]'));
SELECT vector_extract(x'E6ED8C3F');
SELECT vector_extract(x'F37686C4BC9DF13F02');
SELECT vector_extract(vector(x'F37686C4BC9DF13F01'));
SELECT vector_distance_cos('[1,1]', '[1,1]');
SELECT vector_distance_cos('[1,1]', '[-1,-1]');
SELECT vector_distance_cos('[1,1]', '[-1,1]');
SELECT vector_distance_cos('[1,2]', '[2,1]');
SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[-5,4]'));
SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[20,4]'));
SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[20,-2]'));
SELECT vector_distance_cos(vector8('[10,-10]'), vector8('[10,-10]'));
SELECT vector_distance_cos(vector16('[10,-10]'), vector16('[10,-10]'));
SELECT vector_distance_cos(vectorb16('[10,-10]'), vectorb16('[10,-10]'));
SELECT vector_distance_cos(vector32('[10,-10]'), vector32('[10,-10]'));
SELECT vector_distance_cos(vector8('[-21,-31,0,2,2.1,2.2,105]'), vector8('[-20,-30,0,1,1.1,1.2,100]'));
SELECT vector_distance_cos(vector16('[-21,-31,0,2,2.1,2.2,105]'), vector16('[-20,-30,0,1,1.1,1.2,100]'));
SELECT vector_distance_cos(vectorb16('[-21,-31,0,2,2.1,2.2,105]'), vectorb16('[-20,-30,0,1,1.1,1.2,100]'));
SELECT vector_distance_cos(vector32('[-21,-31,0,2,2.1,2.2,105]'), vector32('[-20,-30,0,1,1.1,1.2,100]'));
SELECT vector_distance_cos(vector8('[-20,-30,0,1,1.1,1.2,100]'), vector8('[-20,-30,0,1,1.1,1.2,10000]'));
SELECT vector_distance_cos(vector16('[-20,-30,0,1,1.1,1.2,100]'), vector16('[-20,-30,0,1,1.1,1.2,10000]'));
SELECT vector_distance_cos(vectorb16('[-20,-30,0,1,1.1,1.2,100]'), vectorb16('[-20,-30,0,1,1.1,1.2,10000]'));
SELECT vector_distance_cos(vector32('[-20,-30,0,1,1.1,1.2,100]'), vector32('[-20,-30,0,1,1.1,1.2,10000]'));
SELECT vector_distance_cos(vector8('[-1000000,1000000]'), vector8('[1000000,-1000000]'));
SELECT vector_distance_cos(vector16('[-1000000,1000000]'), vector16('[1000000,-1000000]'));
SELECT vector_distance_cos(vectorb16('[-1000000,1000000]'), vectorb16('[1000000,-1000000]'));
SELECT vector_distance_cos(vector32('[-1000000,1000000]'), vector32('[1000000,-1000000]'));
SELECT vector_distance_l2(vector('[1,2,2,3,4,1,5]'), vector('[2,3,1,-1,2,4,5]'));
SELECT vector_distance_l2(vector8('[1,2,2,3,4,1,5]'), vector8('[2,3,1,-1,2,4,5]'));
SELECT vector_distance_l2(vector16('[1,2,2,3,4,1,5]'), vector16('[2,3,1,-1,2,4,5]'));
SELECT vector_distance_l2(vectorb16('[1,2,2,3,4,1,5]'), vectorb16('[2,3,1,-1,2,4,5]'));
} {
{[]}
{[]}
{[1,2,3]}
{[-1e+18]}
{E6ED8C3F}
{E6ED8C3F}
{F37686C4BC9DF13F02}
{[1.10101]}
{[1.10101]}
{[-1075.72,1.88763]}
{0.0}
{2.0}
{1.0}
{0.200000002980232}
{2.0}
{1.0}
{0.0}
{-6.10352568486405e-09} {0.0} {0.0} {0.0}
{0.000111237335659098} {0.000117182018584572} {0.000116735325718764} {0.000117244853754528}
{0.0576796568930149} {0.0582110174000263} {0.0582080148160458} {0.0582110174000263}
{2.0} {} {2.0} {2.0}
{5.65685415267944} {5.65413522720337} {5.65685415267944} {5.65685415267944}
}
# Empty-vector conversions: hex() of an empty vector32 is expected to be the
# empty string, while converting that empty vector to vector64 is expected to
# produce the single byte 02.
do_execsql_test vector-1-conversion-simple {
SELECT hex(vector32('[]'));
SELECT hex(vector64(vector32('[]')));
} {
{}
02
}
# Convert the same 7-component literal to vector32 from each source type
# (vector1bit, vector32, vector64, vector16, vectorb16, in that order). Each
# statement returns two columns -- the extracted text and the hex of the
# blob -- and each expected row below lists them in that order.
do_execsql_test vector-1-conversion-to-f32 {
SELECT vector_extract(vector32(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector32(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector32(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector32(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector32(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector32(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector32(vector16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector32(vector16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector32(vectorb16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector32(vectorb16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
} {
{[-1,-1,1,-1,1,-1,1]} 000080BF000080BF0000803F000080BF0000803F000080BF0000803F
{[-1e-06,0,Inf,-1e+10,1e-10,0,1.5]} BD3786B5000000000000807FF90215D0FFE6DB2E000000000000C03F
{[-1e-06,0,Inf,-1e+10,1e-10,0,1.5]} BD3786B5000000000000807FF90215D0FFE6DB2E000000000000C03F
{[-1.01328e-06,0,Inf,-Inf,0,0,1.5]} 000088B5000000000000807F000080FF00000000000000000000C03F
{[-9.98378e-07,0,Inf,-9.99922e+09,9.95897e-11,0,1.5]} 000086B5000000000000807F000015D00000DB2E000000000000C03F
}
# Convert the same 7-component literal to vector64 from each source type
# (vector1bit, vector32, vector64, vector16, vectorb16, in that order). Each
# statement returns the extracted text followed by the hex of the blob; every
# expected hex value here ends with the byte 02.
do_execsql_test vector-1-conversion-to-f64 {
SELECT vector_extract(vector64(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector64(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector64(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector64(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector64(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector64(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector64(vector16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector64(vector16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector64(vectorb16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector64(vectorb16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
} {
{[-1,-1,1,-1,1,-1,1]} 000000000000F0BF000000000000F0BF000000000000F03F000000000000F0BF000000000000F03F000000000000F0BF000000000000F03F02
{[-1e-06,0,Inf,-1e+10,1e-10,0,1.5]} 000000A0F7C6B0BE0000000000000000000000000000F07F000000205FA002C2000000E0DF7CDB3D0000000000000000000000000000F83F02
{[-1e-06,1e-100,1e+100,-1e+10,1e-10,0,1.5]} 8DEDB5A0F7C6B0BE30058EE42EFF2B2B7DC39425AD49B254000000205FA002C2BBBDD7D9DF7CDB3D0000000000000000000000000000F83F02
{[-1.01328e-06,0,Inf,-Inf,0,0,1.5]} 000000000000B1BE0000000000000000000000000000F07F000000000000F0FF00000000000000000000000000000000000000000000F83F02
{[-9.98378e-07,0,Inf,-9.99922e+09,9.95897e-11,0,1.5]} 0000000000C0B0BE0000000000000000000000000000F07F0000000000A002C2000000000060DB3D0000000000000000000000000000F83F02
}
# Convert the same 7-component literal to vector1bit from each source type
# (vector1bit, vector32, vector64, vector16, vectorb16, in that order). Each
# statement returns the extracted text followed by the hex of the blob. The
# expected rows show every component extracted as either -1 or 1, and the
# per-source differences (e.g. the 1e-100 and 1e-10 components) are part of
# what is being pinned.
do_execsql_test vector-1-conversion-to-f1bit {
SELECT vector_extract(vector1bit(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector1bit(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector1bit(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector1bit(vector16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
SELECT vector_extract(vector1bit(vectorb16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vectorb16('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]')));
} {
{[-1,-1,1,-1,1,-1,1]} 540903
{[-1,-1,1,-1,1,-1,1]} 540903
{[-1,1,1,-1,1,-1,1]} 560903
{[-1,-1,1,-1,-1,-1,1]} 440903
{[-1,-1,1,-1,1,-1,1]} 540903
}
# Convert a 10-component literal to vector16 from each source type
# (vector1bit, vector32, vector64, vector16, vectorb16, in that order). Each
# statement returns the extracted text followed by the hex of the blob; every
# expected hex value here ends with the byte 05.
do_execsql_test vector-1-conversion-to-f16 {
SELECT vector_extract(vector16(vector1bit('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector16(vector1bit('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector16(vector32('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector16(vector32('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector16(vector64('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector16(vector64('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector16(vector16('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector16(vector16('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector16(vectorb16('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector16(vectorb16('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
} {
{[-1,-1,1,1,1,1,1,1,1,1]} 00BC00BC003C003C003C003C003C003C003C003C05
{[-20,-35.4375,1,1.5,2,3,10,100,105,110]} 00CD6ED0003C003E00400042004940569056E05605
{[-20,-35.4375,1,1.5,2,3,10,100,105,110]} 00CD6ED0003C003E00400042004940569056E05605
{[-20,-35.4375,1,1.5,2,3,10,100,105,110]} 00CD6ED0003C003E00400042004940569056E05605
{[-20,-35.25,1,1.5,2,3,10,100,105,110]} 00CD68D0003C003E00400042004940569056E05605
}
# Conversions involving vector8, in both directions, over one 10-component
# literal. The first six statements convert INTO vector8 (from vector8,
# vector1bit, vector32, vector64, vector16, vectorb16); the last five convert
# FROM vector8 (to vector1bit, vector32, vector64, vector16, vectorb16). Each
# statement returns the extracted text followed by the hex of the blob, and
# the 11 expected rows are in statement order.
do_execsql_test vector-1-conversion-f8 {
-- slightly more tests because f8 is lossy compression and it's better to pick values according to the algorithm internals
SELECT vector_extract(vector8(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector8(vector1bit('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector1bit('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector8(vector32('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector32('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector8(vector64('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector64('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector8(vector16('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector16('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector8(vectorb16('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vectorb16('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector1bit(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector1bit(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector32(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector32(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector64(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector64(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vector16(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector16(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
SELECT vector_extract(vectorb16(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vectorb16(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]')));
} {
{[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} 1B004041424350EDF6FF0000A702123F8FC20DC2000204
{[-1,-1,1,1,1,1,1,1,1,1]} 0000FFFFFFFFFFFFFFFF00008180003C000080BF000204
{[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} 1B004041424350EDF6FF0000A702123F8FC20DC2000204
{[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} 1B004041424350EDF6FF0000A702123F8FC20DC2000204
{[-20.0382,-35.4375,1.06446,1.6348,2.20515,2.77549,10.1899,99.7338,104.867,110]} 1B004041424350EDF6FF00000202123F00C00DC2000204
{[-19.8706,-35.25,1.2049,1.77451,1.77451,2.91373,9.74902,99.7471,104.874,110]} 1B00404141434FEDF6FF0000D2D1113F00000DC2000204
{[-1,-1,1,1,1,1,1,1,1,1]} FC03001603
{[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} E152A0C18FC20DC20003883F6004D13FD0020D408083314008032341A277C742D0BBD1420000DC42
{[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} 000000205C0A34C0000000E051B841C0000000006000F13F000000008C20FA3F000000005AA001400000000070300640000000006160244000000040F4EE5840000000007A375A400000000000805B4002
{[-20.0469,-35.4375,1.0625,1.63281,2.20313,2.77344,10.1875,99.75,104.875,110]} 03CD6ED0403C883E68408C4118493C568E56E05605
{[-20,-35.25,1.0625,1.63281,2.20313,2.76563,10.1875,99.5,104.5,110]} A0C10DC2883FD13F0D4031402341C742D142DC4206
}
# error_messages SQL
#
# Prepare SQL on the "db" connection, step the resulting statement once,
# finalize it, and return whatever [sqlite3_errmsg db] reports afterwards.
# The outcomes of the step and finalize calls themselves are discarded; only
# the connection's error message string is returned to the caller.
proc error_messages {sql} {
  set handle [sqlite3_prepare db $sql -1 ignored]
  sqlite3_step $handle
  sqlite3_finalize $handle
  return [sqlite3_errmsg db]
}
# Run a series of statements through error_messages, collecting one message
# per statement, and compare the resulting list with the expected messages
# below. The comparison is positional (13 statements, 13 messages). The
# statements cover non-TEXT/BLOB arguments to vector(), malformed text
# literals, a blob that does not decode as a vector, and vector_distance_*
# calls with mismatched lengths, mismatched types, or an unsupported type.
do_test vector-1-func-errors {
set ret [list]
lappend ret [error_messages {SELECT vector(1.2)}]
lappend ret [error_messages {SELECT vector(10)}]
lappend ret [error_messages {SELECT vector(NULL)}]
lappend ret [error_messages {SELECT vector('')}]
lappend ret [error_messages {SELECT vector('test')}]
lappend ret [error_messages {SELECT vector('[1]]')}]
lappend ret [error_messages {SELECT vector('[[1]')}]
lappend ret [error_messages {SELECT vector('[1, 2, 1.1.1, 4]')}]
lappend ret [error_messages {SELECT vector('[1.2')}]
lappend ret [error_messages {SELECT vector(x'0000000000')}]
lappend ret [error_messages {SELECT vector_distance_cos('[1,2,3]', '[1,2]')}]
lappend ret [error_messages {SELECT vector_distance_cos(vector32('[1,2,3]'), vector64('[1,2,3]'))}]
lappend ret [error_messages {SELECT vector_distance_l2(vector1bit('[1,2,2,3,4,1,5]'), vector1bit('[2,3,1,-1,2,4,5]'))}]
} [list {*}{
{vector: unexpected value type: got FLOAT, expected TEXT or BLOB}
{vector: unexpected value type: got INTEGER, expected TEXT or BLOB}
{vector: unexpected value type: got NULL, expected TEXT or BLOB}
{vector: must start with '['}
{vector: must start with '['}
{vector: non-space symbols after closing ']' are forbidden}
{vector: invalid float at position 0: '[1'}
{vector: invalid float at position 2: '1.1.1'}
{vector: must end with ']'}
{vector: unexpected binary type: 0}
{vector_distance: vectors must have the same length: 3 != 2}
{vector_distance: vectors must have the same type: 1 != 2}
{vector_distance: l2 distance is not supported for float1bit vectors}
}]