diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index d32819cd00..8ae0065f20 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -130,6 +130,33 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){ return 0; } +void vectorMult(Vector *pVector, double k){ + switch (pVector->type) { + case VECTOR_TYPE_FLOAT32: + vectorF32Mult(pVector, k); + break; + case VECTOR_TYPE_FLOAT64: + vectorF64Mult(pVector, k); + break; + default: + assert(0); + } +} + +void vectorAdd(Vector *v1, const Vector *v2){ + assert( pVector1->type == pVector2->type ); + assert( pVector1->dims == pVector2->dims ); + switch (v1->type) { + case VECTOR_TYPE_FLOAT32: + vectorF32Add(v1, v2); + break; + case VECTOR_TYPE_FLOAT64: + vectorF64Add(v1, v2); + break; + default: + assert(0); + } +} const char *sqlite3_type_repr(int type){ switch( type ){ case SQLITE_NULL: @@ -590,6 +617,250 @@ static void vectorDistanceCosFunc( } } +/* +** Implementation of vector_sum(V...) scalar function. +*/ +static void vectorSumFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + char *pzErrMsg = NULL; + Vector *pSum = NULL, *pVector = NULL; + int i; + int typeSum, dimsSum, typeVector, dimsVector; + + if( argc < 1 ){ + return; + } + if( detectVectorParameters(argv[0], 0, &typeSum, &dimsSum, &pzErrMsg) != 0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + pSum = vectorContextAlloc(context, typeSum, dimsSum); + if( pSum == NULL ){ + goto out_free; + } + if( vectorParse(argv[0], pSum, &pzErrMsg) < 0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + pVector = vectorContextAlloc(context, typeSum, dimsSum); + if( pVector == NULL ){ + goto out_free; + } + for(i = 1; i < argc; i++){ + if( detectVectorParameters(argv[i], 0, &typeVector, &dimsVector, &pzErrMsg) != 0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + if( typeSum != typeVector ){ + pzErrMsg = sqlite3_mprintf("vector_sum: vectors must have the same type: %d != %d", typeSum, typeVector); + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + if( dimsSum != dimsVector ){ + pzErrMsg = sqlite3_mprintf("vector_sum: vectors must have the same length: %d != %d", dimsSum, dimsVector); + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + if( vectorParse(argv[i], pVector, &pzErrMsg) < 0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + vectorAdd(pSum, pVector); + } + vectorSerialize(context, pSum); +out_free: + if( pSum != NULL ){ + vectorFree(pSum); + } + if( pVector != NULL ){ + vectorFree(pVector); + } +} + +struct VectorSumCtx { + i64 count; + Vector *pSum; + Vector *pVector; +}; + +static void vectorSumAdd( + sqlite3_context *context, + int argc, + sqlite3_value **argv, + double k +){ + char *pzErrMsg; + struct VectorSumCtx *p; + int type, dims; + assert( argc == 1 ); + UNUSED_PARAMETER(argc); + p = sqlite3_aggregate_context(context, sizeof(*p)); + if( detectVectorParameters(argv[0], 0, &type, &dims, &pzErrMsg) != 0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + return; + } + if( p->count == 0 ){ + p->pSum = vectorContextAlloc(context, type, dims); + if( p->pSum == NULL ){ + return; + } + } + if( p->pSum->type != type ){ + pzErrMsg = sqlite3_mprintf("vector_sum: vectors must have the same type: %d != %d", p->pSum->type, type); + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + return; + } + if( p->pSum->dims != dims ){ + pzErrMsg = sqlite3_mprintf("vector_sum: vectors must have the same length: %d != %d", p->pSum->dims, dims); + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + return; + } + if( p->count == 0 ){ + if( vectorParse(argv[0], p->pSum, &pzErrMsg) < 0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + }else{ + vectorMult(p->pSum, k); + p->count++; + } + return; + } + if( p->pVector == NULL ){ + p->pVector = vectorContextAlloc(context, type, dims); + if( p->pVector == NULL ){ + return; + } + } + if( vectorParse(argv[0], p->pVector, &pzErrMsg) < 0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + return; + } + vectorMult(p->pVector, k); + vectorAdd(p->pSum, p->pVector); + p->count++; +} + +static void vectorSumEnd(sqlite3_context *context, int freeMem){ + struct VectorSumCtx *p; + p = sqlite3_aggregate_context(context, 0); + if( p && p->count>0 ){ + vectorSerialize(context, p->pSum); + } + if( p && p->pSum != NULL && freeMem ){ + vectorFree(p->pSum); + } + if( p && p->pVector != NULL && freeMem ){ + vectorFree(p->pVector); + } +} + +/* +** Implementation of vector_sum aggregate function (step part) +*/ +static void vectorSumStep(sqlite3_context *context, int argc, sqlite3_value **argv){ + vectorSumAdd(context, argc, argv, 1.0); +} + +/* +** Implementation of vector_sum aggregate function (inverse part) +*/ +static void vectorSumInverse(sqlite3_context *context, int argc, sqlite3_value **argv){ + vectorSumAdd(context, argc, argv, -1.0); +} + +/* +** Implementation of vector_sum aggregate function (finalize part) +*/ +static void vectorSumFinalize(sqlite3_context *context){ + vectorSumEnd(context, 1); +} + +/* +** Implementation of vector_sum aggregate function (value part) +*/ +static void vectorSumValue(sqlite3_context *context){ + vectorSumEnd(context, 0); +} + +/* +** Implementation of vector_mult(V, k) / vector_mult(k, V) function. +*/ +static void vectorMultFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + char *pzErrMsg; + sqlite3_value *pMultValue = NULL, *pVectorValue = NULL; + int type, dims; + Vector *pVector; + double k; + + assert( argc == 2 ); + + if( sqlite3_value_type(argv[0]) == SQLITE_INTEGER || sqlite3_value_type(argv[0]) == SQLITE_FLOAT ){ + pMultValue = argv[0]; + } + if( sqlite3_value_type(argv[1]) == SQLITE_INTEGER || sqlite3_value_type(argv[1]) == SQLITE_FLOAT ){ + pMultValue = argv[1]; + } + if( sqlite3_value_type(argv[0]) == SQLITE_BLOB || sqlite3_value_type(argv[0]) == SQLITE_TEXT ){ + pVectorValue = argv[0]; + } + if( sqlite3_value_type(argv[1]) == SQLITE_BLOB || sqlite3_value_type(argv[1]) == SQLITE_TEXT ){ + pVectorValue = argv[1]; + } + if( pMultValue == NULL || pVectorValue == NULL ){ + pzErrMsg = sqlite3_mprintf( + "vector_mult: unexpected parameters: got %s and %s, but expected vector-compatible and float-compatible types", + sqlite3_type_repr(sqlite3_value_type(argv[0])), + sqlite3_type_repr(sqlite3_value_type(argv[1])) + ); + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + return; + } + + if( detectVectorParameters(pVectorValue, 0, &type, &dims, &pzErrMsg) != 0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + return; + } + if( sqlite3_value_type(pMultValue) == SQLITE_INTEGER ){ + k = sqlite3_value_int64(pMultValue); + } + if( sqlite3_value_type(pMultValue) == SQLITE_FLOAT ){ + k = sqlite3_value_double(pMultValue); + } + pVector = vectorContextAlloc(context, type, dims); + if( pVector == NULL ){ + return; + } + if( vectorParse(pVectorValue, pVector, &pzErrMsg)<0 ){ + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + + vectorMult(pVector, k); + vectorSerialize(context, pVector); +out_free: + vectorFree(pVector); +} + /* * Marker function which is used in index creation syntax: CREATE INDEX idx ON t(libsql_vector_idx(emb)); */ @@ -607,7 +878,10 @@ void sqlite3RegisterVectorFunctions(void){ FUNCTION(vector32, 1, 0, 0, vector32Func), FUNCTION(vector64, 1, 0, 0, vector64Func), FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc), + FUNCTION(vector_sum, -1, 0, 0, vectorSumFunc), + FUNCTION(vector_mult, 2, 0, 0, vectorMultFunc), FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc), + WAGGREGATE(vector_sum, 1, 0, 0, vectorSumStep, vectorSumFinalize, vectorSumFinalize, vectorSumInverse, SQLITE_FUNC_ANYORDER), FUNCTION(libsql_vector_idx, -1, 0, 0, libsqlVectorIdx), }; diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index 8c9138b94f..fc1e6de73d 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -86,6 +86,19 @@ float vectorDistanceL2 (const Vector *, const Vector *); float vectorF32DistanceL2 (const Vector *, const Vector *); double vectorF64DistanceL2(const Vector *, const Vector *); +/* + * Multiply vector in-place by floating point constant k +*/ +void vectorMult (Vector *, double); +void vectorF32Mult(Vector *, double); +void vectorF64Mult(Vector *, double); + +/* + * Add second vector argument to first vector in-place +*/ +void vectorAdd (Vector *, const Vector *); +void vectorF32Add(Vector *, const Vector *); +void vectorF64Add(Vector *, const Vector *); /* * Serializes vector to the sqlite_blob in little-endian format according to the IEEE-754 standard * LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob diff --git a/libsql-sqlite3/src/vectorfloat32.c b/libsql-sqlite3/src/vectorfloat32.c index 8aeae2eb23..a77fa97b85 100644 --- a/libsql-sqlite3/src/vectorfloat32.c +++ b/libsql-sqlite3/src/vectorfloat32.c @@ -215,6 +215,31 @@ float vectorF32DistanceL2(const Vector *v1, const Vector *v2){ return sqrt(sum); } +void vectorF32Mult(Vector *v, double k){ + float *e = v->data; + int i; + + assert( v->type == VECTOR_TYPE_FLOAT32 ); + + for(i = 0; i < v->dims; i++){ + e[i] *= k; + } +} + +void vectorF32Add(Vector *v1, const Vector *v2){ + float *e1 = v1->data; + float *e2 = v2->data; + int i; + + assert( v1->type == VECTOR_TYPE_FLOAT32 ); + assert( v1->type == v2->type ); + assert( v1->dims == v2->dims ); + + for(i = 0; i < v1->dims; i++){ + e1[i] += e2[i]; + } +} + void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ pVector->dims = nBlobSize / sizeof(float); pVector->data = (void*)pBlob; diff --git a/libsql-sqlite3/src/vectorfloat64.c b/libsql-sqlite3/src/vectorfloat64.c index ced2be1843..45a74e1779 100644 --- a/libsql-sqlite3/src/vectorfloat64.c +++ b/libsql-sqlite3/src/vectorfloat64.c @@ -222,6 +222,31 @@ double vectorF64DistanceL2(const Vector *v1, const Vector *v2){ return sqrt(sum); } +void vectorF64Mult(Vector *v, double k){ + double *e = v->data; + int i; + + assert( v->type == VECTOR_TYPE_FLOAT64 ); + + for(i = 0; i < v->dims; i++){ + e[i] *= k; + } +} + +void vectorF64Add(Vector *v1, const Vector *v2){ + double *e1 = v1->data; + double *e2 = v2->data; + int i; + + assert( v1->type == VECTOR_TYPE_FLOAT64 ); + assert( v1->type == v2->type ); + assert( v1->dims == v2->dims ); + + for(i = 0; i < v1->dims; i++){ + e1[i] += e2[i]; + } +} + void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ pVector->dims = nBlobSize / sizeof(double); pVector->data = (void*)pBlob; diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index cf91a7fa18..c4e41335ba 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -50,6 +50,16 @@ do_execsql_test vector-1-func-valid { SELECT vector_distance_cos('[1,1]', '[-1,-1]'); SELECT vector_distance_cos('[1,1]', '[-1,1]'); SELECT vector_distance_cos('[1,2]', '[2,1]'); + SELECT + vector_extract(vector_mult(-0.5, '[2,-4]')), + vector_extract(vector_mult('[2,-4]', -0.5)), + vector_extract(vector_mult(-1, '[2,-4]')), + vector_extract(vector_mult(-0.5, vector('[2,-4]'))); + SELECT vector_extract(vector_sum( + '[1,2]', + vector('[2,3]'), + vector_mult(2, vector('[3,4]')) + )); } { {[]} {[]} @@ -65,6 +75,26 @@ do_execsql_test vector-1-func-valid { {2.0} {1.0} {0.200000002980232} + {[-1,2]} {[-1,2]} {[-2,4]} {[-1,2]} + {[9,13]} +} + +do_execsql_test vector-1-agg-valid { + CREATE TABLE t_vec_agg ( a FLOAT32(2), b FLOAT32(2), c FLOAT32(2), k FLOAT ); + INSERT INTO t_vec_agg VALUES (vector('[1,2]'), '[2,3]', '[3,4]', 1); + INSERT INTO t_vec_agg VALUES (vector('[2,3]'), '[3,4]', vector('[4,5]'), 2); + SELECT vector_extract(vector_sum(a)) FROM t_vec_agg; + SELECT vector_extract(vector_sum(b)) FROM t_vec_agg; + SELECT vector_extract(vector_sum(c)) FROM t_vec_agg; + INSERT INTO t_vec_agg VALUES (vector('[3,4]'), '[4,5]', '6', 3); + SELECT vector_extract(vector_sum(a)) FROM t_vec_agg; + SELECT vector_extract(vector_sum(vector_mult(b, k))) FROM t_vec_agg; +} { + {[3,5]} + {[5,7]} + {[7,9]} + {[6,9]} + {[20,26]} } proc error_messages {sql} { @@ -89,6 +119,10 @@ do_test vector-1-func-errors { lappend ret [error_messages {SELECT vector(x'0000000000')}] lappend ret [error_messages {SELECT vector_distance_cos('[1,2,3]', '[1,2]')}] lappend ret [error_messages {SELECT vector_distance_cos(vector32('[1,2,3]'), vector64('[1,2,3]'))}] + lappend ret [error_messages {SELECT vector_sum(vector32('[1,2,3]'), vector64('[1,2,3]'))}] + lappend ret [error_messages {SELECT vector_sum(vector('[1,2,3]'), vector('[1,2]'))}] + lappend ret [error_messages {SELECT vector_sum(1, vector64('[1,2,3]'))}] + lappend ret [error_messages {SELECT vector_mult(1, 2)}] } [list {*}{ {vector: unexpected value type: got FLOAT, expected TEXT or BLOB} {vector: unexpected value type: got INTEGER, expected TEXT or BLOB} @@ -102,4 +136,8 @@ do_test vector-1-func-errors { {vector: unexpected binary type: got 0, expected 1 or 2} {vector_distance_cos: vectors must have the same length: 3 != 2} {vector_distance_cos: vectors must have the same type: 1 != 2} + {vector_sum: vectors must have the same type: 1 != 2} + {vector_sum: vectors must have the same length: 3 != 2} + {vector: unexpected value type: got INTEGER, expected TEXT or BLOB} + {vector_mult: unexpected parameters: got INTEGER and INTEGER, but expected vector-compatible and float-compatible types} }]