0
0
mirror of https://github.com/tursodatabase/libsql.git synced 2025-05-29 09:33:21 +00:00

float16 implementation

This commit is contained in:
Nikita Sivukhin
2024-08-19 13:37:37 +04:00
committed by Sivukhin Nikita
parent 59f189e0d1
commit f8128d27e6
8 changed files with 276 additions and 1 deletions

@ -62,3 +62,5 @@ libsql
/crates/target/
/has_tclsh*
/libsql.wasm
test_libsql_f16_table.h
test_libsql_f16

@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \
sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \
table.lo threads.lo tokenize.lo treeview.lo trigger.lo \
update.lo userauth.lo upsert.lo util.lo vacuum.lo \
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo \
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo vectorfloat16.lo \
vectorIndex.lo vectordiskann.lo vectorvtab.lo \
vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \
vdbetrace.lo vdbevtab.lo \
@ -304,6 +304,7 @@ SRC = \
$(TOP)/src/vector.c \
$(TOP)/src/vectorInt.h \
$(TOP)/src/vectorfloat1bit.c \
$(TOP)/src/vectorfloat16.c \
$(TOP)/src/vectorfloat32.c \
$(TOP)/src/vectorfloat64.c \
$(TOP)/src/vectorfloat8.c \
@ -1143,6 +1144,9 @@ vector.lo: $(TOP)/src/vector.c $(HDR)
vectorfloat1bit.lo: $(TOP)/src/vectorfloat1bit.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat1bit.c
vectorfloat16.lo: $(TOP)/src/vectorfloat16.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat16.c
vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c

@ -0,0 +1,28 @@
/*
* BUILD: cc test_libsql_diskann.c -I ../ -L ../.libs -llibsql -o test_libsql_diskann
* RUN: LD_LIBRARY_PATH=../.libs ./test_libsql_diskann
*/
#include "assert.h"
#include "stdbool.h"
#include "stdarg.h"
#include "stddef.h"
#include "vectorfloat16.c"
#include "test_libsql_f16_table.h"
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#define ensure(condition, ...) { if (!(condition)) { eprintf(__VA_ARGS__); exit(1); } }
int main() {
for(int i = 0; i < 65536; i++){
u32 expected = F16ToF32[i];
float actual = vectorF16ToFloat(i);
u32 actual_u32 = *((u32*)&actual);
ensure(expected == actual_u32, "conversion from %x failed: %f != %f (%x != %x)", i, *(float*)&expected, *(float*)&actual_u32, expected, actual_u32);
}
for(int i = 0; i < 65536; i++){
u16 expected = F32ToF16[i];
u16 actual = vectorF16FromFloat(*(float*)&F32[i]);
ensure(expected == actual, "conversion from %x (%f, it=%d) failed: %x != %x", F32[i], *(float*)&F32[i], i, expected, actual);
}
}

@ -0,0 +1,42 @@
import random
import struct
import numpy as np
u32_list = [random.randint(0, 2**32) for _ in range(65536)]
print("""
u32 F32[65536] = {
""")
for i, x in enumerate(u32_list):
if i % 8 == 0: print(" ", end='');
print('{:>10}u, '.format(x), end='')
if i % 8 == 7: print()
print("};")
print("""
u16 F32ToF16[65536] = {
""")
for i, x in enumerate(u32_list):
if i % 8 == 0: print(" ", end='');
u32_bytes = struct.pack('<I', x)
f32 = np.float16(struct.unpack('<f', u32_bytes)[0])
f16_bytes = struct.pack('<e', f32)
u16 = struct.unpack('<H', f16_bytes)[0]
print('{:>10}, '.format(u16), end='')
if i % 8 == 7: print()
print("};")
print("""
u32 F16ToF32[65536] = {
""")
for x in range(65536):
if x % 8 == 0: print(" ", end='');
u16_bytes = struct.pack('<H', x)
f16 = struct.unpack('<e', u16_bytes)[0]
f32_bytes = struct.pack('<f', f16)
u32 = struct.unpack('<I', f32_bytes)[0]
print('{:>10}u, '.format(u32), end='')
if x % 8 == 7: print()
print("};")

@ -45,6 +45,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
return (dims + 7) / 8;
case VECTOR_TYPE_FLOAT8:
return ALIGN(dims, sizeof(float)) + sizeof(float) /* alpha */ + sizeof(float) /* shift */;
case VECTOR_TYPE_FLOAT16:
return dims * sizeof(u16);
default:
assert(0);
}

@ -53,6 +53,7 @@ typedef u32 VectorDims;
#define VECTOR_TYPE_FLOAT64 2
#define VECTOR_TYPE_FLOAT1BIT 3
#define VECTOR_TYPE_FLOAT8 4
#define VECTOR_TYPE_FLOAT16 5
#define VECTOR_FLAGS_STATIC 1
@ -80,6 +81,7 @@ void vectorInit(Vector *, VectorType, VectorDims, void *);
*/
void vectorDump (const Vector *v);
void vectorF8Dump (const Vector *v);
void vectorF16Dump (const Vector *v);
void vectorF32Dump (const Vector *v);
void vectorF64Dump (const Vector *v);
void vector1BitDump(const Vector *v);
@ -99,6 +101,7 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *);
*/
void vectorSerializeToBlob (const Vector *, unsigned char *, size_t);
void vectorF8SerializeToBlob (const Vector *, unsigned char *, size_t);
void vectorF16SerializeToBlob (const Vector *, unsigned char *, size_t);
void vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t);
void vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t);
void vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t);
@ -108,6 +111,7 @@ void vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t);
*/
float vectorDistanceCos (const Vector *, const Vector *);
float vectorF8DistanceCos (const Vector *, const Vector *);
float vectorF16DistanceCos (const Vector *, const Vector *);
float vectorF32DistanceCos (const Vector *, const Vector *);
double vectorF64DistanceCos(const Vector *, const Vector *);
@ -121,6 +125,7 @@ int vector1BitDistanceHamming(const Vector *, const Vector *);
*/
float vectorDistanceL2 (const Vector *, const Vector *);
float vectorF8DistanceL2 (const Vector *, const Vector *);
float vectorF16DistanceL2 (const Vector *, const Vector *);
float vectorF32DistanceL2 (const Vector *, const Vector *);
double vectorF64DistanceL2(const Vector *, const Vector *);
@ -137,6 +142,7 @@ void vectorSerializeWithMeta(sqlite3_context *, const Vector *);
int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **);
void vectorF8DeserializeFromBlob (Vector *, const unsigned char *, size_t);
void vectorF16DeserializeFromBlob (Vector *, const unsigned char *, size_t);
void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t);
void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t);
void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t);

@ -0,0 +1,190 @@
/*
** 2024-07-04
**
** Copyright 2024 the libSQL authors
**
** Permission is hereby granted, free of charge, to any person obtaining a copy of
** this software and associated documentation files (the "Software"), to deal in
** the Software without restriction, including without limitation the rights to
** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
** the Software, and to permit persons to whom the Software is furnished to do so,
** subject to the following conditions:
**
** The above copyright notice and this permission notice shall be included in all
** copies or substantial portions of the Software.
**
** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
**
******************************************************************************
**
** 16-bit (FLOAT16) floating point vector format utilities.
*/
#ifndef SQLITE_OMIT_VECTOR
#include "sqliteInt.h"
#include "vectorInt.h"
#include <math.h>
/**************************************************************************
** Utility routines for vector serialization and deserialization
**************************************************************************/
// f32: [fffffffffffffffffffffffeeeeeeees]
// 01234567890123456789012345678901
// f16: [ffffffffffeeeees]
// 0123456789012345
static float vectorF16ToFloat(u16 f16){
u32 f32;
// sng: [0000000000000000000000000000000s]
u32 sgn = ((u32)f16 & 0x8000) << 16;
int expBits = (f16 >> 10) & 0x1f;
int exp = expBits - 15; // 15 is exp bias for f16
u32 mnt = ((u32)f16 & 0x3ff);
u32 mntNonZero = !!mnt;
if( exp == 16 ){ // NaN or +/- Infinity
exp = 128, mnt = mntNonZero << 22; // set mnt high bit to represent NaN if it was NaN in f16
}else if( exp == -15 && mnt == 0 ){ // zero
exp = -127, mnt = 0;
}else if( exp == -15 ){ // denormalized value
// shift mantissa until we get 1 as a high bit
exp++;
while( (mnt & 0x400) == 0 ){
mnt <<= 1;
exp--;
}
// then reset high bit as this will be normal value (not denormalized) in f32
mnt &= 0x3ff;
mnt <<= 13;
}else{
mnt <<= 13;
}
f32 = sgn | ((u32)(exp + 127) << 23) | mnt;
return *((float*)&f32);
}
static u16 vectorF16FromFloat(float f){
u32 i = *((u32*)&f);
// sng: [000000000000000s]
u32 sgn = (i >> 16) & (0x8000);
// expBits: [eeeeeeee]
int expBits = (i >> 23) & (0xff);
int exp = expBits - 127; // 127 is exp bias for f32
// mntBits: [fffffffffffffffffffffff]
u32 mntBits = (i & 0x7fffff);
u32 mntNonZero = !!mntBits;
u32 mnt;
if( exp == 128 ){ // NaN or +/- Infinity
exp = 16, mntBits = mntNonZero << 22; // set mnt high bit to represent NaN if it was NaN in f32
}else if( exp > 15 ){ // just too big numbers for f16
exp = 16, mntBits = 0;
}else if( exp < -14 && exp >= -25 ){ // small value, but we can be represented as denormalized f16
// set high bit to 1 as normally mantissa has form 1.[mnt] but denormalized mantissa has form 0.[mnt]
mntBits = (mntBits | 0x800000) >> (-exp - 14);
exp = -15;
}else if( exp < -24 ){ // very small or denormalized value
exp = -15, mntBits = 0;
}
// round to nearest, ties to even
if( (mntBits & 0x1fff) > (0x1000 - ((mntBits >> 13) & 1)) ){
mntBits += 0x2000;
}
mnt = mntBits >> 13;
// handle overflow here (note, that overflow can happen only if exp < 16)
return sgn | ((u32)(exp + 15 + (mnt >> 10)) << 10) | (mnt & 0x3ff);
}
void vectorF16Dump(const Vector *pVec){
u16 *elems = pVec->data;
unsigned i;
assert( pVec->type == VECTOR_TYPE_FLOAT16 );
printf("f16: [");
for(i = 0; i < pVec->dims; i++){
printf("%s%f", i == 0 ? "" : ", ", vectorF16ToFloat(elems[i]));
}
printf("]\n");
}
void vectorF16SerializeToBlob(
const Vector *pVector,
unsigned char *pBlob,
size_t nBlobSize
){
float alpha, shift;
assert( pVector->type == VECTOR_TYPE_FLOAT16 );
assert( pVector->dims <= MAX_VECTOR_SZ );
assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) );
memcpy(pBlob, pVector->data, pVector->dims * sizeof(u16));
}
float vectorF16DistanceCos(const Vector *v1, const Vector *v2){
int i;
float dot = 0, norm1 = 0, norm2 = 0;
float value1, value2;
u16 *data1 = v1->data, *data2 = v2->data;
assert( v1->dims == v2->dims );
assert( v1->type == VECTOR_TYPE_FLOAT16 );
assert( v2->type == VECTOR_TYPE_FLOAT16 );
for(i = 0; i < v1->dims; i++){
value1 = vectorF16ToFloat(data1[i]);
value2 = vectorF16ToFloat(data2[i]);
dot += value1*value2;
norm1 += value1*value1;
norm2 += value2*value2;
}
return 1.0 - (dot / sqrt(norm1 * norm2));
}
float vectorF16DistanceL2(const Vector *v1, const Vector *v2){
int i;
float sum = 0;
float value1, value2;
u8 *data1 = v1->data, *data2 = v2->data;
assert( v1->dims == v2->dims );
assert( v1->type == VECTOR_TYPE_FLOAT16 );
assert( v2->type == VECTOR_TYPE_FLOAT16 );
for(i = 0; i < v1->dims; i++){
value1 = vectorF16ToFloat(data1[i]);
value2 = vectorF16ToFloat(data2[i]);
float d = (value1 - value2);
sum += d*d;
}
return sqrt(sum);
}
void vectorF16DeserializeFromBlob(
Vector *pVector,
const unsigned char *pBlob,
size_t nBlobSize
){
assert( pVector->type == VECTOR_TYPE_FLOAT16 );
assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ );
assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) );
memcpy((u8*)pVector->data, (u8*)pBlob, pVector->dims * sizeof(u16));
}
#endif /* !defined(SQLITE_OMIT_VECTOR) */

@ -473,6 +473,7 @@ set flist {
vectorfloat32.c
vectorfloat64.c
vectorfloat8.c
vectorfloat16.c
vectorIndex.c
vectorvtab.c
rtree.c