From f510359793bd24d5427ef67ccaa5ab0caf6a95f2 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Thu, 1 Aug 2024 11:47:24 +0200 Subject: [PATCH 001/121] Breaking change: Make Rows and Row API more consistent. A few notes: I went the path of least resistance also assuming it would break fewer folks, i.e. make Row more like Rows and thus usize -> i32. Arguably, an unsigned type might be more appropriate both for length and indexes. I understand that the i32 stems from the sqlite bindings, which returns/accepts c_ints. Yet, Statement and Row made the jump to an unsigned, probably drawing the same conclusion, whereas Rows preserved its proximity to the c-bindings. This is probably an artifact? Also going on a limb, mapping c_int -> i32 is already a non-portable choice, with precision for c_int being platform dependent. --- libsql/src/de.rs | 2 +- libsql/src/hrana/mod.rs | 29 ++++++++++++++++++---------- libsql/src/local/impls.rs | 18 ++++++++++------- libsql/src/local/rows.rs | 18 ++++++++++------- libsql/src/local/statement.rs | 21 ++++++++++---------- libsql/src/replication/connection.rs | 18 ++++++++++------- libsql/src/rows.rs | 21 +++++++++----------- 7 files changed, 72 insertions(+), 55 deletions(-) diff --git a/libsql/src/de.rs b/libsql/src/de.rs index 63ee71f598..44f231c134 100644 --- a/libsql/src/de.rs +++ b/libsql/src/de.rs @@ -68,7 +68,7 @@ impl<'de> Deserializer<'de> for RowDeserializer<'de> { visitor.visit_map(RowMapAccess { row: self.row, - idx: 0..self.row.inner.column_count(), + idx: 0..(self.row.inner.column_count() as usize), value: None, }) } diff --git a/libsql/src/hrana/mod.rs b/libsql/src/hrana/mod.rs index 9befe549de..4a6fd0c63a 100644 --- a/libsql/src/hrana/mod.rs +++ b/libsql/src/hrana/mod.rs @@ -24,7 +24,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use super::rows::{RowInner, RowsInner}; +use super::rows::{ColumnsInner, RowInner, RowsInner}; pub(crate) type Result = std::result::Result; @@ -261,7 +261,12 @@ where 
async fn next(&mut self) -> crate::Result> { self.next().await } +} +impl ColumnsInner for HranaRows +where + S: Stream> + Send + Sync + Unpin, +{ fn column_count(&self) -> i32 { self.column_count() } @@ -303,13 +308,6 @@ impl RowInner for Row { Ok(into_value2(v)) } - fn column_name(&self, idx: i32) -> Option<&str> { - self.cols - .get(idx as usize) - .and_then(|c| c.name.as_ref()) - .map(|s| s.as_str()) - } - fn column_str(&self, idx: i32) -> crate::Result<&str> { if let Some(value) = self.inner.get(idx as usize) { if let proto::Value::Text { value } = value { @@ -321,6 +319,15 @@ impl RowInner for Row { Err(crate::Error::ColumnNotFound(idx)) } } +} + +impl ColumnsInner for Row { + fn column_name(&self, idx: i32) -> Option<&str> { + self.cols + .get(idx as usize) + .and_then(|c| c.name.as_ref()) + .map(|s| s.as_str()) + } fn column_type(&self, idx: i32) -> crate::Result { if let Some(value) = self.inner.get(idx as usize) { @@ -337,8 +344,8 @@ impl RowInner for Row { } } - fn column_count(&self) -> usize { - self.cols.len() + fn column_count(&self) -> i32 { + self.cols.len() as i32 } } @@ -417,7 +424,9 @@ impl RowsInner for StmtResultRows { inner: Box::new(row), })) } +} +impl ColumnsInner for StmtResultRows { fn column_count(&self) -> i32 { self.cols.len() as i32 } diff --git a/libsql/src/local/impls.rs b/libsql/src/local/impls.rs index 8a9a5f440e..2338317a34 100644 --- a/libsql/src/local/impls.rs +++ b/libsql/src/local/impls.rs @@ -5,7 +5,7 @@ use crate::connection::BatchRows; use crate::{ connection::Conn, params::Params, - rows::{RowInner, RowsInner}, + rows::{ColumnsInner, RowInner, RowsInner}, statement::Stmt, transaction::Tx, Column, Connection, Result, Row, Rows, Statement, Transaction, TransactionBehavior, Value, @@ -159,7 +159,9 @@ impl RowsInner for LibsqlRows { Ok(row) } +} +impl ColumnsInner for LibsqlRows { fn column_count(&self) -> i32 { self.0.column_count() } @@ -180,20 +182,22 @@ impl RowInner for LibsqlRow { self.0.get_value(idx) } - fn 
column_name(&self, idx: i32) -> Option<&str> { - self.0.column_name(idx) - } - fn column_str(&self, idx: i32) -> Result<&str> { self.0.get::<&str>(idx) } +} + +impl ColumnsInner for LibsqlRow { + fn column_name(&self, idx: i32) -> Option<&str> { + self.0.column_name(idx) + } fn column_type(&self, idx: i32) -> Result { self.0.column_type(idx).map(ValueType::from) } - fn column_count(&self) -> usize { - self.0.stmt.column_count() + fn column_count(&self) -> i32 { + self.0.stmt.column_count() as i32 } } diff --git a/libsql/src/local/rows.rs b/libsql/src/local/rows.rs index 7eb52d461b..4d4e622c75 100644 --- a/libsql/src/local/rows.rs +++ b/libsql/src/local/rows.rs @@ -1,6 +1,6 @@ use crate::local::{Connection, Statement}; use crate::params::Params; -use crate::rows::{RowInner, RowsInner}; +use crate::rows::{ColumnsInner, RowInner, RowsInner}; use crate::{errors, Error, Result}; use crate::{Value, ValueRef}; use libsql_sys::ValueType; @@ -213,7 +213,9 @@ impl RowsInner for BatchedRows { Ok(None) } } +} +impl ColumnsInner for BatchedRows { fn column_count(&self) -> i32 { self.cols.len() as i32 } @@ -244,10 +246,6 @@ impl RowInner for BatchedRow { .ok_or(Error::InvalidColumnIndex) } - fn column_name(&self, idx: i32) -> Option<&str> { - self.cols.get(idx as usize).map(|c| c.0.as_str()) - } - fn column_str(&self, idx: i32) -> Result<&str> { self.row .get(idx as usize) @@ -258,9 +256,15 @@ impl RowInner for BatchedRow { .ok_or(Error::InvalidColumnType) }) } +} + +impl ColumnsInner for BatchedRow { + fn column_name(&self, idx: i32) -> Option<&str> { + self.cols.get(idx as usize).map(|c| c.0.as_str()) + } - fn column_count(&self) -> usize { - self.cols.len() + fn column_count(&self) -> i32 { + self.cols.len() as i32 } fn column_type(&self, idx: i32) -> Result { diff --git a/libsql/src/local/statement.rs b/libsql/src/local/statement.rs index 70116a152e..c28a66f18f 100644 --- a/libsql/src/local/statement.rs +++ b/libsql/src/local/statement.rs @@ -250,15 +250,15 @@ impl Statement 
{ /// sure that current statement has already been stepped once before /// calling this method. pub fn column_names(&self) -> Vec<&str> { - let n = self.column_count(); - let mut cols = Vec::with_capacity(n); - for i in 0..n { - let s = self.column_name(i); - if let Some(s) = s { - cols.push(s); - } - } - cols + let n = self.column_count(); + let mut cols = Vec::with_capacity(n); + for i in 0..n { + let s = self.column_name(i); + if let Some(s) = s { + cols.push(s); + } + } + cols } /// Return the number of columns in the result set returned by the prepared @@ -314,12 +314,11 @@ impl Statement { /// the specified `name`. pub fn column_index(&self, name: &str) -> Result { let bytes = name.as_bytes(); - let n = self.column_count() as i32; + let n = self.column_count(); for i in 0..n { // Note: `column_name` is only fallible if `i` is out of bounds, // which we've already checked. let col_name = self - .inner .column_name(i) .ok_or_else(|| Error::InvalidColumnName(name.to_string()))?; if bytes.eq_ignore_ascii_case(col_name.as_bytes()) { diff --git a/libsql/src/replication/connection.rs b/libsql/src/replication/connection.rs index c82f523559..c720838798 100644 --- a/libsql/src/replication/connection.rs +++ b/libsql/src/replication/connection.rs @@ -11,7 +11,7 @@ use parking_lot::Mutex; use crate::parser; use crate::parser::StmtKind; -use crate::rows::{RowInner, RowsInner}; +use crate::rows::{ColumnsInner, RowInner, RowsInner}; use crate::statement::Stmt; use crate::transaction::Tx; use crate::{ @@ -780,7 +780,9 @@ impl RowsInner for RemoteRows { let row = RemoteRow(values, self.0.column_descriptions.clone()); Ok(Some(row).map(Box::new).map(|inner| Row { inner })) } +} +impl ColumnsInner for RemoteRows { fn column_count(&self) -> i32 { self.0.column_descriptions.len() as i32 } @@ -813,10 +815,6 @@ impl RowInner for RemoteRow { .ok_or(Error::InvalidColumnIndex) } - fn column_name(&self, idx: i32) -> Option<&str> { - self.1.get(idx as usize).map(|s| s.name.as_str()) - } - 
fn column_str(&self, idx: i32) -> Result<&str> { let value = self.0.get(idx as usize).ok_or(Error::InvalidColumnIndex)?; @@ -825,6 +823,12 @@ impl RowInner for RemoteRow { _ => Err(Error::InvalidColumnType), } } +} + +impl ColumnsInner for RemoteRow { + fn column_name(&self, idx: i32) -> Option<&str> { + self.1.get(idx as usize).map(|s| s.name.as_str()) + } fn column_type(&self, idx: i32) -> Result { let col = self.1.get(idx as usize).unwrap(); @@ -835,8 +839,8 @@ impl RowInner for RemoteRow { .ok_or(Error::InvalidColumnType) } - fn column_count(&self) -> usize { - self.1.len() + fn column_count(&self) -> i32 { + self.1.len() as i32 } } diff --git a/libsql/src/rows.rs b/libsql/src/rows.rs index b97aeac203..a10d82b827 100644 --- a/libsql/src/rows.rs +++ b/libsql/src/rows.rs @@ -38,14 +38,8 @@ impl Column<'_> { } #[async_trait::async_trait] -pub(crate) trait RowsInner { +pub(crate) trait RowsInner: ColumnsInner { async fn next(&mut self) -> Result>; - - fn column_count(&self) -> i32; - - fn column_name(&self, idx: i32) -> Option<&str>; - - fn column_type(&self, idx: i32) -> Result; } /// A set of rows returned from a connection. @@ -131,7 +125,7 @@ impl Row { } /// Get the count of columns in this set of rows. 
- pub fn column_count(&self) -> usize { + pub fn column_count(&self) -> i32 { self.inner.column_count() } @@ -284,12 +278,15 @@ where } impl Sealed for Option {} -pub(crate) trait RowInner: fmt::Debug { - fn column_value(&self, idx: i32) -> Result; - fn column_str(&self, idx: i32) -> Result<&str>; +pub(crate) trait ColumnsInner { fn column_name(&self, idx: i32) -> Option<&str>; fn column_type(&self, idx: i32) -> Result; - fn column_count(&self) -> usize; + fn column_count(&self) -> i32; +} + +pub(crate) trait RowInner: ColumnsInner + fmt::Debug { + fn column_value(&self, idx: i32) -> Result; + fn column_str(&self, idx: i32) -> Result<&str>; } mod sealed { From a68f042914cd637904123c7375e297a00ac5cecb Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 2 Aug 2024 09:26:28 -0700 Subject: [PATCH 002/121] libsql: release v0.5.0 --- Cargo.lock | 12 ++++++------ Cargo.toml | 2 +- libsql-ffi/Cargo.toml | 2 +- libsql-replication/Cargo.toml | 4 ++-- libsql-sys/Cargo.toml | 4 ++-- libsql/Cargo.toml | 8 ++++---- vendored/rusqlite/Cargo.toml | 4 ++-- vendored/sqlite3-parser/Cargo.toml | 2 +- 8 files changed, 19 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 04ae728065..17cfc0e090 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3430,7 +3430,7 @@ dependencies = [ [[package]] name = "libsql" -version = "0.5.0-alpha.2" +version = "0.5.0" dependencies = [ "anyhow", "async-stream", @@ -3488,7 +3488,7 @@ dependencies = [ [[package]] name = "libsql-ffi" -version = "0.3.0" +version = "0.4.0" dependencies = [ "bindgen 0.66.1", "cc", @@ -3508,7 +3508,7 @@ dependencies = [ [[package]] name = "libsql-rusqlite" -version = "0.31.0" +version = "0.32.0" dependencies = [ "bencher", "bitflags 2.6.0", @@ -3631,7 +3631,7 @@ dependencies = [ [[package]] name = "libsql-sqlite3-parser" -version = "0.12.0" +version = "0.13.0" dependencies = [ "bitflags 2.6.0", "cc", @@ -3687,7 +3687,7 @@ dependencies = [ [[package]] name = "libsql-sys" -version = "0.6.0" +version = "0.7.0" 
dependencies = [ "bytes", "libsql-ffi", @@ -3768,7 +3768,7 @@ dependencies = [ [[package]] name = "libsql_replication" -version = "0.4.0" +version = "0.5.0" dependencies = [ "aes", "arbitrary", diff --git a/Cargo.toml b/Cargo.toml index 92487ecdd0..9381fb83f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ codegen-units = 1 panic = "unwind" [workspace.dependencies] -rusqlite = { package = "libsql-rusqlite", path = "vendored/rusqlite", version = "0.31", default-features = false, features = [ +rusqlite = { package = "libsql-rusqlite", path = "vendored/rusqlite", version = "0.32", default-features = false, features = [ "libsql-experimental", "column_decltype", "load_extension", diff --git a/libsql-ffi/Cargo.toml b/libsql-ffi/Cargo.toml index 9b5cbced11..ef9ade1726 100644 --- a/libsql-ffi/Cargo.toml +++ b/libsql-ffi/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql-ffi" -version = "0.3.0" +version = "0.4.0" edition = "2021" build = "build.rs" license = "MIT" diff --git a/libsql-replication/Cargo.toml b/libsql-replication/Cargo.toml index 56f00d7a7d..d2a9431cba 100644 --- a/libsql-replication/Cargo.toml +++ b/libsql-replication/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql_replication" -version = "0.4.0" +version = "0.5.0" edition = "2021" description = "libSQL replication protocol" repository = "https://github.com/tursodatabase/libsql" @@ -11,7 +11,7 @@ license = "MIT" [dependencies] tonic = { version = "0.11", features = ["tls"] } prost = "0.12" -libsql-sys = { version = "0.6", path = "../libsql-sys", default-features = false, features = ["wal", "rusqlite", "api"] } +libsql-sys = { version = "0.7", path = "../libsql-sys", default-features = false, features = ["wal", "rusqlite", "api"] } rusqlite = { workspace = true } parking_lot = "0.12.1" bytes = { version = "1.5.0", features = ["serde"] } diff --git a/libsql-sys/Cargo.toml b/libsql-sys/Cargo.toml index 26dd091ea9..8351012d9f 100644 --- a/libsql-sys/Cargo.toml +++ b/libsql-sys/Cargo.toml @@ -1,6 
+1,6 @@ [package] name = "libsql-sys" -version = "0.6.0" +version = "0.7.0" edition = "2021" license = "MIT" description = "Native bindings to libSQL" @@ -12,7 +12,7 @@ categories = ["external-ffi-bindings"] [dependencies] bytes = "1.5.0" -libsql-ffi = { version = "0.3", path = "../libsql-ffi/" } +libsql-ffi = { version = "0.4", path = "../libsql-ffi/" } once_cell = "1.18.0" rusqlite = { workspace = true, features = ["trace"], optional = true } tracing = "0.1.37" diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index fa89cc68ad..efae2abea3 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql" -version = "0.5.0-alpha.2" +version = "0.5.0" edition = "2021" description = "libSQL library: the main gateway for interacting with the database" repository = "https://github.com/tursodatabase/libsql" @@ -11,7 +11,7 @@ tracing = { version = "0.1.37", default-features = false } thiserror = "1.0.40" futures = { version = "0.3.28", optional = true } -libsql-sys = { version = "0.6", path = "../libsql-sys", optional = true } +libsql-sys = { version = "0.7", path = "../libsql-sys", optional = true } libsql-hrana = { version = "0.2", path = "../libsql-hrana", optional = true } tokio = { version = "1.29.1", features = ["sync"], optional = true } tokio-util = { version = "0.7", features = ["io-util", "codec"], optional = true } @@ -37,10 +37,10 @@ tower-http = { version = "0.4.4", features = ["trace", "set-header", "util"], op http = { version = "0.2", optional = true } zerocopy = { version = "0.7.28", optional = true } -sqlite3-parser = { package = "libsql-sqlite3-parser", path = "../vendored/sqlite3-parser", version = "0.12", optional = true } +sqlite3-parser = { package = "libsql-sqlite3-parser", path = "../vendored/sqlite3-parser", version = "0.13", optional = true } fallible-iterator = { version = "0.3", optional = true } -libsql_replication = { version = "0.4", path = "../libsql-replication", optional = true } +libsql_replication = { 
version = "0.5", path = "../libsql-replication", optional = true } async-stream = { version = "0.3.5", optional = true } [dev-dependencies] diff --git a/vendored/rusqlite/Cargo.toml b/vendored/rusqlite/Cargo.toml index 2d332f3279..d9fbcc525e 100644 --- a/vendored/rusqlite/Cargo.toml +++ b/vendored/rusqlite/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "libsql-rusqlite" # Note: Update version in README.md when you change this. -version = "0.31.0" +version = "0.32.0" authors = ["The rusqlite developers"] edition = "2018" description = "Ergonomic wrapper for SQLite (libsql fork)" @@ -109,7 +109,7 @@ fallible-iterator = "0.2" fallible-streaming-iterator = "0.1" uuid = { version = "1.0", optional = true } smallvec = "1.6.1" -libsql-ffi = { version = "0.3", path = "../../libsql-ffi" } +libsql-ffi = { version = "0.4", path = "../../libsql-ffi" } [dev-dependencies] doc-comment = "0.3" diff --git a/vendored/sqlite3-parser/Cargo.toml b/vendored/sqlite3-parser/Cargo.toml index 5ed9e31f4d..0381ac1d99 100644 --- a/vendored/sqlite3-parser/Cargo.toml +++ b/vendored/sqlite3-parser/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql-sqlite3-parser" -version = "0.12.0" +version = "0.13.0" edition = "2021" authors = ["gwenn"] description = "SQL parser (as understood by SQLite) (libsql fork)" From 0917f84cde2d272141eaa79d933ed0736fb668e5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 2 Aug 2024 23:54:55 +0200 Subject: [PATCH 003/121] introduce NamespaceConfigurator --- .../src/namespace/configurator/mod.rs | 56 ++++++ .../src/namespace/configurator/primary.rs | 128 ++++++++++++ .../src/namespace/configurator/replica.rs | 190 ++++++++++++++++++ libsql-server/src/namespace/mod.rs | 17 +- 4 files changed, 383 insertions(+), 8 deletions(-) create mode 100644 libsql-server/src/namespace/configurator/mod.rs create mode 100644 libsql-server/src/namespace/configurator/primary.rs create mode 100644 libsql-server/src/namespace/configurator/replica.rs diff --git 
a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs new file mode 100644 index 0000000000..0caa1de149 --- /dev/null +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -0,0 +1,56 @@ +use std::pin::Pin; + +use futures::Future; + +use super::broadcasters::BroadcasterHandle; +use super::meta_store::MetaStoreHandle; +use super::{NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; + +mod replica; +mod primary; + +type DynConfigurator = Box; + +#[derive(Default)] +struct NamespaceConfigurators { + replica_configurator: Option, + primary_configurator: Option, + schema_configurator: Option, +} + +impl NamespaceConfigurators { + pub fn with_primary( + &mut self, + c: impl ConfigureNamespace + Send + Sync + 'static, + ) -> &mut Self { + self.primary_configurator = Some(Box::new(c)); + self + } + + pub fn with_replica( + &mut self, + c: impl ConfigureNamespace + Send + Sync + 'static, + ) -> &mut Self { + self.replica_configurator = Some(Box::new(c)); + self + } + + pub fn with_schema(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { + self.schema_configurator = Some(Box::new(c)); + self + } +} + +pub trait ConfigureNamespace { + fn setup<'a>( + &'a self, + ns_config: &'a NamespaceConfig, + db_config: MetaStoreHandle, + restore_option: RestoreOption, + name: &'a NamespaceName, + reset: ResetCb, + resolve_attach_path: ResolveNamespacePathFn, + store: NamespaceStore, + broadcaster: BroadcasterHandle, + ) -> Pin> + Send + 'a>>; +} diff --git a/libsql-server/src/namespace/configurator/primary.rs b/libsql-server/src/namespace/configurator/primary.rs new file mode 100644 index 0000000000..f28d288a97 --- /dev/null +++ b/libsql-server/src/namespace/configurator/primary.rs @@ -0,0 +1,128 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::{path::Path, pin::Pin, sync::Arc}; + +use futures::prelude::Future; +use tokio::task::JoinSet; + +use 
crate::connection::MakeConnection; +use crate::database::{Database, PrimaryDatabase}; +use crate::namespace::{Namespace, NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::broadcasters::BroadcasterHandle; +use crate::run_periodic_checkpoint; +use crate::schema::{has_pending_migration_task, setup_migration_table}; + +use super::ConfigureNamespace; + +pub struct PrimaryConfigurator; + +impl ConfigureNamespace for PrimaryConfigurator { + fn setup<'a>( + &'a self, + config: &'a NamespaceConfig, + meta_store_handle: MetaStoreHandle, + restore_option: RestoreOption, + name: &'a NamespaceName, + _reset: ResetCb, + resolve_attach_path: ResolveNamespacePathFn, + _store: NamespaceStore, + broadcaster: BroadcasterHandle, + ) -> Pin> + Send + 'a>> + { + Box::pin(async move { + let db_path: Arc = config.base_path.join("dbs").join(name.as_str()).into(); + let fresh_namespace = !db_path.try_exists()?; + // FIXME: make that truly atomic. 
explore the idea of using temp directories, and it's implications + match try_new_primary( + config, + name.clone(), + meta_store_handle, + restore_option, + resolve_attach_path, + db_path.clone(), + broadcaster, + ) + .await + { + Ok(this) => Ok(this), + Err(e) if fresh_namespace => { + tracing::error!("an error occured while deleting creating namespace, cleaning..."); + if let Err(e) = tokio::fs::remove_dir_all(&db_path).await { + tracing::error!("failed to remove dirty namespace directory: {e}") + } + Err(e) + } + Err(e) => Err(e), + } + }) + } +} + +#[tracing::instrument(skip_all, fields(namespace))] +async fn try_new_primary( + ns_config: &NamespaceConfig, + namespace: NamespaceName, + meta_store_handle: MetaStoreHandle, + restore_option: RestoreOption, + resolve_attach_path: ResolveNamespacePathFn, + db_path: Arc, + broadcaster: BroadcasterHandle, +) -> crate::Result { + let mut join_set = JoinSet::new(); + + tokio::fs::create_dir_all(&db_path).await?; + + let block_writes = Arc::new(AtomicBool::new(false)); + let (connection_maker, wal_wrapper, stats) = Namespace::make_primary_connection_maker( + ns_config, + &meta_store_handle, + &db_path, + &namespace, + restore_option, + block_writes.clone(), + &mut join_set, + resolve_attach_path, + broadcaster, + ) + .await?; + let connection_maker = Arc::new(connection_maker); + + if meta_store_handle.get().shared_schema_name.is_some() { + let block_writes = block_writes.clone(); + let conn = connection_maker.create().await?; + tokio::task::spawn_blocking(move || { + conn.with_raw(|conn| -> crate::Result<()> { + setup_migration_table(conn)?; + if has_pending_migration_task(conn)? 
{ + block_writes.store(true, Ordering::SeqCst); + } + Ok(()) + }) + }) + .await + .unwrap()?; + } + + if let Some(checkpoint_interval) = ns_config.checkpoint_interval { + join_set.spawn(run_periodic_checkpoint( + connection_maker.clone(), + checkpoint_interval, + namespace.clone(), + )); + } + + tracing::debug!("Done making new primary"); + + Ok(Namespace { + tasks: join_set, + db: Database::Primary(PrimaryDatabase { + wal_wrapper, + connection_maker, + block_writes, + }), + name: namespace, + stats, + db_config_store: meta_store_handle, + path: db_path.into(), + }) +} diff --git a/libsql-server/src/namespace/configurator/replica.rs b/libsql-server/src/namespace/configurator/replica.rs new file mode 100644 index 0000000000..4d3ca1dadf --- /dev/null +++ b/libsql-server/src/namespace/configurator/replica.rs @@ -0,0 +1,190 @@ +use std::pin::Pin; +use std::sync::Arc; + +use futures::Future; +use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; +use tokio::task::JoinSet; + +use crate::connection::write_proxy::MakeWriteProxyConn; +use crate::connection::MakeConnection; +use crate::database::{Database, ReplicaDatabase}; +use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{Namespace, RestoreOption}; +use crate::namespace::{ + make_stats, NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResetOp, + ResolveNamespacePathFn, +}; +use crate::{DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT}; + +use super::ConfigureNamespace; + +pub struct ReplicaConfigurator; + +impl ConfigureNamespace for ReplicaConfigurator { + fn setup<'a>( + &'a self, + config: &'a NamespaceConfig, + meta_store_handle: MetaStoreHandle, + restore_option: RestoreOption, + name: &'a NamespaceName, + reset: ResetCb, + resolve_attach_path: ResolveNamespacePathFn, + store: NamespaceStore, + broadcaster: BroadcasterHandle, + ) -> Pin> + Send + 'a>> + { + Box::pin(async move { + 
tracing::debug!("creating replica namespace"); + let db_path = config.base_path.join("dbs").join(name.as_str()); + let channel = config.channel.clone().expect("bad replica config"); + let uri = config.uri.clone().expect("bad replica config"); + + let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); + let client = crate::replication::replicator_client::Client::new( + name.clone(), + rpc_client, + &db_path, + meta_store_handle.clone(), + store.clone(), + ) + .await?; + let applied_frame_no_receiver = client.current_frame_no_notifier.subscribe(); + let mut replicator = libsql_replication::replicator::Replicator::new( + client, + db_path.join("data"), + DEFAULT_AUTO_CHECKPOINT, + config.encryption_config.clone(), + ) + .await?; + + tracing::debug!("try perform handshake"); + // force a handshake now, to retrieve the primary's current replication index + match replicator.try_perform_handshake().await { + Err(libsql_replication::replicator::Error::Meta( + libsql_replication::meta::Error::LogIncompatible, + )) => { + tracing::error!( + "trying to replicate incompatible logs, reseting replica and nuking db dir" + ); + std::fs::remove_dir_all(&db_path).unwrap(); + return self.setup( + config, + meta_store_handle, + restore_option, + name, + reset, + resolve_attach_path, + store, + broadcaster, + ) + .await; + } + Err(e) => Err(e)?, + Ok(_) => (), + } + + tracing::debug!("done performing handshake"); + + let primary_current_replicatio_index = replicator.client_mut().primary_replication_index; + + let mut join_set = JoinSet::new(); + let namespace = name.clone(); + join_set.spawn(async move { + use libsql_replication::replicator::Error; + loop { + match replicator.run().await { + err @ Error::Fatal(_) => Err(err)?, + err @ Error::NamespaceDoesntExist => { + tracing::error!("namespace {namespace} doesn't exist, destroying..."); + (reset)(ResetOp::Destroy(namespace.clone())); + Err(err)?; + } + e @ Error::Injector(_) => { + 
tracing::error!("potential corruption detected while replicating, reseting replica: {e}"); + (reset)(ResetOp::Reset(namespace.clone())); + Err(e)?; + }, + Error::Meta(err) => { + use libsql_replication::meta::Error; + match err { + Error::LogIncompatible => { + tracing::error!("trying to replicate incompatible logs, reseting replica"); + (reset)(ResetOp::Reset(namespace.clone())); + Err(err)?; + } + Error::InvalidMetaFile + | Error::Io(_) + | Error::InvalidLogId + | Error::FailedToCommit(_) + | Error::InvalidReplicationPath + | Error::RequiresCleanDatabase => { + // We retry from last frame index? + tracing::warn!("non-fatal replication error, retrying from last commit index: {err}"); + }, + } + } + e @ (Error::Internal(_) + | Error::Client(_) + | Error::PrimaryHandshakeTimeout + | Error::NeedSnapshot) => { + tracing::warn!("non-fatal replication error, retrying from last commit index: {e}"); + }, + Error::NoHandshake => { + // not strictly necessary, but in case the handshake error goes uncaught, + // we reset the client state. + replicator.client_mut().reset_token(); + } + Error::SnapshotPending => unreachable!(), + } + } + }); + + let stats = make_stats( + &db_path, + &mut join_set, + meta_store_handle.clone(), + config.stats_sender.clone(), + name.clone(), + applied_frame_no_receiver.clone(), + config.encryption_config.clone(), + ) + .await?; + + let connection_maker = MakeWriteProxyConn::new( + db_path.clone(), + config.extensions.clone(), + channel.clone(), + uri.clone(), + stats.clone(), + broadcaster, + meta_store_handle.clone(), + applied_frame_no_receiver, + config.max_response_size, + config.max_total_response_size, + primary_current_replicatio_index, + config.encryption_config.clone(), + resolve_attach_path, + config.make_wal_manager.clone(), + ) + .await? 
+ .throttled( + config.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + config.max_total_response_size, + config.max_concurrent_requests, + ); + + Ok(Namespace { + tasks: join_set, + db: Database::Replica(ReplicaDatabase { + connection_maker: Arc::new(connection_maker), + }), + name: name.clone(), + stats, + db_config_store: meta_store_handle, + path: db_path.into(), + }) + }) + } +} diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 6e48e7f1d8..6a04b11fb8 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -1,11 +1,3 @@ -pub mod broadcasters; -mod fork; -pub mod meta_store; -mod name; -pub mod replication_wal; -mod schema_lock; -mod store; - use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Weak}; @@ -57,6 +49,15 @@ pub use self::name::NamespaceName; use self::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; pub use self::store::NamespaceStore; +pub mod broadcasters; +mod fork; +pub mod meta_store; +mod name; +pub mod replication_wal; +mod schema_lock; +mod store; +mod configurator; + pub type ResetCb = Box; pub type ResolveNamespacePathFn = Arc crate::Result> + Sync + Send + 'static>; From f9daa9e08f58efdebd52acb4d45435266bec34af Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 3 Aug 2024 00:00:45 +0200 Subject: [PATCH 004/121] add configurators to namespace store --- libsql-server/src/namespace/configurator/mod.rs | 2 +- libsql-server/src/namespace/store.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index 0caa1de149..a692f75652 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -12,7 +12,7 @@ mod primary; type DynConfigurator = Box; #[derive(Default)] -struct NamespaceConfigurators { +pub(crate) struct 
NamespaceConfigurators { replica_configurator: Option, primary_configurator: Option, schema_configurator: Option, diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index e0147fc2e8..984a520154 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -19,6 +19,7 @@ use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, Nam use crate::stats::Stats; use super::broadcasters::{BroadcasterHandle, BroadcasterRegistry}; +use super::configurator::NamespaceConfigurators; use super::meta_store::{MetaStore, MetaStoreHandle}; use super::schema_lock::SchemaLocksRegistry; use super::{Namespace, NamespaceConfig, ResetCb, ResetOp, ResolveNamespacePathFn, RestoreOption}; @@ -47,6 +48,7 @@ pub struct NamespaceStoreInner { pub config: NamespaceConfig, schema_locks: SchemaLocksRegistry, broadcasters: BroadcasterRegistry, + configurators: NamespaceConfigurators, } impl NamespaceStore { @@ -90,6 +92,7 @@ impl NamespaceStore { config, schema_locks: Default::default(), broadcasters: Default::default(), + configurators: NamespaceConfigurators::default(), }), }) } From 8b377a6e06dfc051bedb68976577adb8de339f40 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 3 Aug 2024 21:18:51 +0200 Subject: [PATCH 005/121] add shcema configurator --- .../src/namespace/configurator/schema.rs | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 libsql-server/src/namespace/configurator/schema.rs diff --git a/libsql-server/src/namespace/configurator/schema.rs b/libsql-server/src/namespace/configurator/schema.rs new file mode 100644 index 0000000000..864b75239f --- /dev/null +++ b/libsql-server/src/namespace/configurator/schema.rs @@ -0,0 +1,65 @@ +use std::sync::{atomic::AtomicBool, Arc}; + +use futures::prelude::Future; +use tokio::task::JoinSet; + +use crate::database::{Database, SchemaDatabase}; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{ + Namespace, 
NamespaceConfig, NamespaceName, NamespaceStore, + ResetCb, ResolveNamespacePathFn, RestoreOption, +}; +use crate::namespace::broadcasters::BroadcasterHandle; + +use super::ConfigureNamespace; + +pub struct SchemaConfigurator; + +impl ConfigureNamespace for SchemaConfigurator { + fn setup<'a>( + &'a self, + ns_config: &'a NamespaceConfig, + db_config: MetaStoreHandle, + restore_option: RestoreOption, + name: &'a NamespaceName, + _reset: ResetCb, + resolve_attach_path: ResolveNamespacePathFn, + _store: NamespaceStore, + broadcaster: BroadcasterHandle, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + let mut join_set = JoinSet::new(); + let db_path = ns_config.base_path.join("dbs").join(name.as_str()); + + tokio::fs::create_dir_all(&db_path).await?; + + let (connection_maker, wal_manager, stats) = Namespace::make_primary_connection_maker( + ns_config, + &db_config, + &db_path, + &name, + restore_option, + Arc::new(AtomicBool::new(false)), // this is always false for schema + &mut join_set, + resolve_attach_path, + broadcaster, + ) + .await?; + + Ok(Namespace { + db: Database::Schema(SchemaDatabase::new( + ns_config.migration_scheduler.clone(), + name.clone(), + connection_maker, + wal_manager, + db_config.clone(), + )), + name: name.clone(), + tasks: join_set, + stats, + db_config_store: db_config.clone(), + path: db_path.into(), + }) + }) + } +} From 978dd7147d0304397f261d918c70c2806eff82a5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 3 Aug 2024 21:19:00 +0200 Subject: [PATCH 006/121] instanciate namesapces from configurators --- .../src/namespace/configurator/mod.rs | 25 +- libsql-server/src/namespace/fork.rs | 26 +- libsql-server/src/namespace/mod.rs | 374 +----------------- libsql-server/src/namespace/store.rs | 76 ++-- 4 files changed, 74 insertions(+), 427 deletions(-) diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index a692f75652..d3cd390b34 100644 --- 
a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -8,14 +8,19 @@ use super::{NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveName mod replica; mod primary; +mod schema; -type DynConfigurator = Box; +pub use replica::ReplicaConfigurator; +pub use primary::PrimaryConfigurator; +pub use schema::SchemaConfigurator; + +type DynConfigurator = dyn ConfigureNamespace + Send + Sync + 'static; #[derive(Default)] pub(crate) struct NamespaceConfigurators { - replica_configurator: Option, - primary_configurator: Option, - schema_configurator: Option, + replica_configurator: Option>, + primary_configurator: Option>, + schema_configurator: Option>, } impl NamespaceConfigurators { @@ -39,6 +44,18 @@ impl NamespaceConfigurators { self.schema_configurator = Some(Box::new(c)); self } + + pub fn configure_schema(&self) -> crate::Result<&DynConfigurator> { + self.schema_configurator.as_deref().ok_or_else(|| todo!()) + } + + pub fn configure_primary(&self) -> crate::Result<&DynConfigurator> { + self.primary_configurator.as_deref().ok_or_else(|| todo!()) + } + + pub fn configure_replica(&self) -> crate::Result<&DynConfigurator> { + self.replica_configurator.as_deref().ok_or_else(|| todo!()) + } } pub trait ConfigureNamespace { diff --git a/libsql-server/src/namespace/fork.rs b/libsql-server/src/namespace/fork.rs index dfa053b43d..f25bf7a9a9 100644 --- a/libsql-server/src/namespace/fork.rs +++ b/libsql-server/src/namespace/fork.rs @@ -12,14 +12,12 @@ use tokio::io::{AsyncSeekExt, AsyncWriteExt}; use tokio::time::Duration; use tokio_stream::StreamExt; -use crate::namespace::ResolveNamespacePathFn; use crate::replication::primary::frame_stream::FrameStream; use crate::replication::{LogReadError, ReplicationLogger}; use crate::{BLOCKING_RT, LIBSQL_PAGE_SIZE}; -use super::broadcasters::BroadcasterHandle; use super::meta_store::MetaStoreHandle; -use super::{Namespace, NamespaceConfig, NamespaceName, NamespaceStore, 
RestoreOption}; +use super::{NamespaceName, NamespaceStore, RestoreOption}; type Result = crate::Result; @@ -54,16 +52,13 @@ async fn write_frame(frame: &FrameBorrowed, temp_file: &mut tokio::fs::File) -> Ok(()) } -pub struct ForkTask<'a> { +pub struct ForkTask { pub base_path: Arc, pub logger: Arc, pub to_namespace: NamespaceName, pub to_config: MetaStoreHandle, pub restore_to: Option, - pub ns_config: &'a NamespaceConfig, - pub resolve_attach: ResolveNamespacePathFn, pub store: NamespaceStore, - pub broadcaster: BroadcasterHandle, } pub struct PointInTimeRestore { @@ -71,7 +66,7 @@ pub struct PointInTimeRestore { pub replicator_options: bottomless::replicator::Options, } -impl<'a> ForkTask<'a> { +impl ForkTask { pub async fn fork(self) -> Result { let base_path = self.base_path.clone(); let dest_namespace = self.to_namespace.clone(); @@ -105,18 +100,9 @@ impl<'a> ForkTask<'a> { let dest_path = self.base_path.join("dbs").join(self.to_namespace.as_str()); tokio::fs::rename(temp_dir.path(), dest_path).await?; - Namespace::from_config( - self.ns_config, - self.to_config.clone(), - RestoreOption::Latest, - &self.to_namespace, - Box::new(|_op| {}), - self.resolve_attach.clone(), - self.store.clone(), - self.broadcaster, - ) - .await - .map_err(|e| ForkError::CreateNamespace(Box::new(e))) + self.store.make_namespace(&self.to_namespace, self.to_config, RestoreOption::Latest) + .await + .map_err(|e| ForkError::CreateNamespace(Box::new(e))) } /// Restores the database state from a local log file. 
diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 6a04b11fb8..41bb3ab9cc 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -1,5 +1,5 @@ use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::AtomicBool; use std::sync::{Arc, Weak}; use anyhow::{Context as _, Error}; @@ -10,7 +10,6 @@ use chrono::NaiveDateTime; use enclose::enclose; use futures_core::{Future, Stream}; use hyper::Uri; -use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use libsql_sys::wal::Sqlite3WalManager; use libsql_sys::EncryptionConfig; use tokio::io::AsyncBufReadExt; @@ -25,20 +24,17 @@ use crate::auth::parse_jwt_keys; use crate::connection::config::DatabaseConfig; use crate::connection::connection_manager::InnerWalManager; use crate::connection::libsql::{open_conn, MakeLibSqlConn}; -use crate::connection::write_proxy::MakeWriteProxyConn; -use crate::connection::Connection; -use crate::connection::MakeConnection; +use crate::connection::{Connection as _, MakeConnection}; use crate::database::{ - Database, DatabaseKind, PrimaryConnection, PrimaryConnectionMaker, PrimaryDatabase, - ReplicaDatabase, SchemaDatabase, + Database, DatabaseKind, PrimaryConnection, PrimaryConnectionMaker, }; use crate::error::LoadDumpError; use crate::replication::script_backup_manager::ScriptBackupManager; use crate::replication::{FrameNo, ReplicationLogger}; -use crate::schema::{has_pending_migration_task, setup_migration_table, SchedulerHandle}; +use crate::schema::SchedulerHandle; use crate::stats::Stats; use crate::{ - run_periodic_checkpoint, StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT, + StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT, }; pub use fork::ForkError; @@ -101,54 +97,6 @@ pub struct Namespace { } impl Namespace { - async fn from_config( - ns_config: &NamespaceConfig, - db_config: 
MetaStoreHandle, - restore_option: RestoreOption, - name: &NamespaceName, - reset: ResetCb, - resolve_attach_path: ResolveNamespacePathFn, - store: NamespaceStore, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - match ns_config.db_kind { - DatabaseKind::Primary if db_config.get().is_shared_schema => { - Self::new_schema( - ns_config, - name.clone(), - db_config, - restore_option, - resolve_attach_path, - broadcaster, - ) - .await - } - DatabaseKind::Primary => { - Self::new_primary( - ns_config, - name.clone(), - db_config, - restore_option, - resolve_attach_path, - broadcaster, - ) - .await - } - DatabaseKind::Replica => { - Self::new_replica( - ns_config, - name.clone(), - db_config, - reset, - resolve_attach_path, - store, - broadcaster, - ) - .await - } - } - } - pub(crate) fn name(&self) -> &NamespaceName { &self.name } @@ -248,40 +196,6 @@ impl Namespace { self.db_config_store.changed() } - async fn new_primary( - config: &NamespaceConfig, - name: NamespaceName, - meta_store_handle: MetaStoreHandle, - restore_option: RestoreOption, - resolve_attach_path: ResolveNamespacePathFn, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - let db_path: Arc = config.base_path.join("dbs").join(name.as_str()).into(); - let fresh_namespace = !db_path.try_exists()?; - // FIXME: make that truly atomic. 
explore the idea of using temp directories, and it's implications - match Self::try_new_primary( - config, - name.clone(), - meta_store_handle, - restore_option, - resolve_attach_path, - db_path.clone(), - broadcaster, - ) - .await - { - Ok(this) => Ok(this), - Err(e) if fresh_namespace => { - tracing::error!("an error occured while deleting creating namespace, cleaning..."); - if let Err(e) = tokio::fs::remove_dir_all(&db_path).await { - tracing::error!("failed to remove dirty namespace directory: {e}") - } - Err(e) - } - Err(e) => Err(e), - } - } - #[tracing::instrument(skip_all)] async fn make_primary_connection_maker( ns_config: &NamespaceConfig, @@ -417,237 +331,6 @@ impl Namespace { Ok((connection_maker, wal_wrapper, stats)) } - #[tracing::instrument(skip_all, fields(namespace))] - async fn try_new_primary( - ns_config: &NamespaceConfig, - namespace: NamespaceName, - meta_store_handle: MetaStoreHandle, - restore_option: RestoreOption, - resolve_attach_path: ResolveNamespacePathFn, - db_path: Arc, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - let mut join_set = JoinSet::new(); - - tokio::fs::create_dir_all(&db_path).await?; - - let block_writes = Arc::new(AtomicBool::new(false)); - let (connection_maker, wal_wrapper, stats) = Self::make_primary_connection_maker( - ns_config, - &meta_store_handle, - &db_path, - &namespace, - restore_option, - block_writes.clone(), - &mut join_set, - resolve_attach_path, - broadcaster, - ) - .await?; - let connection_maker = Arc::new(connection_maker); - - if meta_store_handle.get().shared_schema_name.is_some() { - let block_writes = block_writes.clone(); - let conn = connection_maker.create().await?; - tokio::task::spawn_blocking(move || { - conn.with_raw(|conn| -> crate::Result<()> { - setup_migration_table(conn)?; - if has_pending_migration_task(conn)? 
{ - block_writes.store(true, Ordering::SeqCst); - } - Ok(()) - }) - }) - .await - .unwrap()?; - } - - if let Some(checkpoint_interval) = ns_config.checkpoint_interval { - join_set.spawn(run_periodic_checkpoint( - connection_maker.clone(), - checkpoint_interval, - namespace.clone(), - )); - } - - tracing::debug!("Done making new primary"); - - Ok(Self { - tasks: join_set, - db: Database::Primary(PrimaryDatabase { - wal_wrapper, - connection_maker, - block_writes, - }), - name: namespace, - stats, - db_config_store: meta_store_handle, - path: db_path.into(), - }) - } - - #[tracing::instrument(skip_all, fields(name))] - #[async_recursion::async_recursion] - async fn new_replica( - config: &NamespaceConfig, - name: NamespaceName, - meta_store_handle: MetaStoreHandle, - reset: ResetCb, - resolve_attach_path: ResolveNamespacePathFn, - store: NamespaceStore, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - tracing::debug!("creating replica namespace"); - let db_path = config.base_path.join("dbs").join(name.as_str()); - let channel = config.channel.clone().expect("bad replica config"); - let uri = config.uri.clone().expect("bad replica config"); - - let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); - let client = crate::replication::replicator_client::Client::new( - name.clone(), - rpc_client, - &db_path, - meta_store_handle.clone(), - store.clone(), - ) - .await?; - let applied_frame_no_receiver = client.current_frame_no_notifier.subscribe(); - let mut replicator = libsql_replication::replicator::Replicator::new( - client, - db_path.join("data"), - DEFAULT_AUTO_CHECKPOINT, - config.encryption_config.clone(), - ) - .await?; - - tracing::debug!("try perform handshake"); - // force a handshake now, to retrieve the primary's current replication index - match replicator.try_perform_handshake().await { - Err(libsql_replication::replicator::Error::Meta( - libsql_replication::meta::Error::LogIncompatible, - )) => { - tracing::error!( - 
"trying to replicate incompatible logs, reseting replica and nuking db dir" - ); - std::fs::remove_dir_all(&db_path).unwrap(); - return Self::new_replica( - config, - name, - meta_store_handle, - reset, - resolve_attach_path, - store, - broadcaster, - ) - .await; - } - Err(e) => Err(e)?, - Ok(_) => (), - } - - tracing::debug!("done performing handshake"); - - let primary_current_replicatio_index = replicator.client_mut().primary_replication_index; - - let mut join_set = JoinSet::new(); - let namespace = name.clone(); - join_set.spawn(async move { - use libsql_replication::replicator::Error; - loop { - match replicator.run().await { - err @ Error::Fatal(_) => Err(err)?, - err @ Error::NamespaceDoesntExist => { - tracing::error!("namespace {namespace} doesn't exist, destroying..."); - (reset)(ResetOp::Destroy(namespace.clone())); - Err(err)?; - } - e @ Error::Injector(_) => { - tracing::error!("potential corruption detected while replicating, reseting replica: {e}"); - (reset)(ResetOp::Reset(namespace.clone())); - Err(e)?; - }, - Error::Meta(err) => { - use libsql_replication::meta::Error; - match err { - Error::LogIncompatible => { - tracing::error!("trying to replicate incompatible logs, reseting replica"); - (reset)(ResetOp::Reset(namespace.clone())); - Err(err)?; - } - Error::InvalidMetaFile - | Error::Io(_) - | Error::InvalidLogId - | Error::FailedToCommit(_) - | Error::InvalidReplicationPath - | Error::RequiresCleanDatabase => { - // We retry from last frame index? - tracing::warn!("non-fatal replication error, retrying from last commit index: {err}"); - }, - } - } - e @ (Error::Internal(_) - | Error::Client(_) - | Error::PrimaryHandshakeTimeout - | Error::NeedSnapshot) => { - tracing::warn!("non-fatal replication error, retrying from last commit index: {e}"); - }, - Error::NoHandshake => { - // not strictly necessary, but in case the handshake error goes uncaught, - // we reset the client state. 
- replicator.client_mut().reset_token(); - } - Error::SnapshotPending => unreachable!(), - } - } - }); - - let stats = make_stats( - &db_path, - &mut join_set, - meta_store_handle.clone(), - config.stats_sender.clone(), - name.clone(), - applied_frame_no_receiver.clone(), - config.encryption_config.clone(), - ) - .await?; - - let connection_maker = MakeWriteProxyConn::new( - db_path.clone(), - config.extensions.clone(), - channel.clone(), - uri.clone(), - stats.clone(), - broadcaster, - meta_store_handle.clone(), - applied_frame_no_receiver, - config.max_response_size, - config.max_total_response_size, - primary_current_replicatio_index, - config.encryption_config.clone(), - resolve_attach_path, - config.make_wal_manager.clone(), - ) - .await? - .throttled( - config.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - config.max_total_response_size, - config.max_concurrent_requests, - ); - - Ok(Self { - tasks: join_set, - db: Database::Replica(ReplicaDatabase { - connection_maker: Arc::new(connection_maker), - }), - name, - stats, - db_config_store: meta_store_handle, - path: db_path.into(), - }) - } - async fn fork( ns_config: &NamespaceConfig, from_ns: &Namespace, @@ -655,9 +338,7 @@ impl Namespace { to_ns: NamespaceName, to_config: MetaStoreHandle, timestamp: Option, - resolve_attach: ResolveNamespacePathFn, store: NamespaceStore, - broadcaster: BroadcasterHandle, ) -> crate::Result { let from_config = from_config.get(); match ns_config.db_kind { @@ -696,10 +377,7 @@ impl Namespace { logger, restore_to, to_config, - ns_config, - resolve_attach, store, - broadcaster: broadcaster.handle(to_ns), }; let ns = fork_task.fork().await?; @@ -708,48 +386,6 @@ impl Namespace { DatabaseKind::Replica => Err(ForkError::ForkReplica.into()), } } - - async fn new_schema( - ns_config: &NamespaceConfig, - name: NamespaceName, - meta_store_handle: MetaStoreHandle, - restore_option: RestoreOption, - resolve_attach_path: ResolveNamespacePathFn, - broadcaster: 
BroadcasterHandle, - ) -> crate::Result { - let mut join_set = JoinSet::new(); - let db_path = ns_config.base_path.join("dbs").join(name.as_str()); - - tokio::fs::create_dir_all(&db_path).await?; - - let (connection_maker, wal_manager, stats) = Self::make_primary_connection_maker( - ns_config, - &meta_store_handle, - &db_path, - &name, - restore_option, - Arc::new(AtomicBool::new(false)), // this is always false for schema - &mut join_set, - resolve_attach_path, - broadcaster, - ) - .await?; - - Ok(Namespace { - db: Database::Schema(SchemaDatabase::new( - ns_config.migration_scheduler.clone(), - name.clone(), - connection_maker, - wal_manager, - meta_store_handle.clone(), - )), - name, - tasks: join_set, - stats, - db_config_store: meta_store_handle, - path: db_path.into(), - }) - } } pub struct NamespaceConfig { diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index 984a520154..5a94a7f8eb 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -13,8 +13,10 @@ use tokio_stream::wrappers::BroadcastStream; use crate::auth::Authenticated; use crate::broadcaster::BroadcastMsg; use crate::connection::config::DatabaseConfig; +use crate::database::DatabaseKind; use crate::error::Error; use crate::metrics::NAMESPACE_LOAD_LATENCY; +use crate::namespace::configurator::{PrimaryConfigurator, ReplicaConfigurator, SchemaConfigurator}; use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName}; use crate::stats::Stats; @@ -82,6 +84,12 @@ impl NamespaceStore { .time_to_idle(Duration::from_secs(86400)) .build(); + let mut configurators = NamespaceConfigurators::default(); + configurators + .with_primary(PrimaryConfigurator) + .with_replica(ReplicaConfigurator) + .with_schema(SchemaConfigurator); + Ok(Self { inner: Arc::new(NamespaceStoreInner { store, @@ -92,7 +100,7 @@ impl NamespaceStore { config, schema_locks: Default::default(), broadcasters: Default::default(), - 
configurators: NamespaceConfigurators::default(), + configurators, }), }) } @@ -177,27 +185,17 @@ impl NamespaceStore { ns.destroy().await?; } - let handle = self.inner.metadata.handle(namespace.clone()); + let db_config = self.inner.metadata.handle(namespace.clone()); // destroy on-disk database Namespace::cleanup( &self.inner.config, &namespace, - &handle.get(), + &db_config.get(), false, NamespaceBottomlessDbIdInit::FetchFromConfig, ) .await?; - let ns = Namespace::from_config( - &self.inner.config, - handle, - restore_option, - &namespace, - self.make_reset_cb(), - self.resolve_attach_fn(), - self.clone(), - self.broadcaster(namespace.clone()), - ) - .await?; + let ns = self.make_namespace(&namespace, db_config, restore_option).await?; lock.replace(ns); @@ -304,9 +302,7 @@ impl NamespaceStore { to.clone(), handle.clone(), timestamp, - self.resolve_attach_fn(), self.clone(), - self.broadcaster(to), ) .await?; @@ -381,30 +377,42 @@ impl NamespaceStore { .clone() } + pub(crate) async fn make_namespace( + &self, + namespace: &NamespaceName, + config: MetaStoreHandle, + restore_option: RestoreOption, + ) -> crate::Result { + let configurator = match self.inner.config.db_kind { + DatabaseKind::Primary if config.get().is_shared_schema => { + self.inner.configurators.configure_schema()? 
+ } + DatabaseKind::Primary => self.inner.configurators.configure_primary()?, + DatabaseKind::Replica => self.inner.configurators.configure_replica()?, + }; + let ns = configurator.setup( + &self.inner.config, + config, + restore_option, + namespace, + self.make_reset_cb(), + self.resolve_attach_fn(), + self.clone(), + self.broadcaster(namespace.clone()), + ).await?; + + Ok(ns) + } + async fn load_namespace( &self, namespace: &NamespaceName, db_config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { - let init = { - let namespace = namespace.clone(); - async move { - let ns = Namespace::from_config( - &self.inner.config, - db_config, - restore_option, - &namespace, - self.make_reset_cb(), - self.resolve_attach_fn(), - self.clone(), - self.broadcaster(namespace.clone()), - ) - .await?; - tracing::info!("loaded namespace: `{namespace}`"); - - Ok(Some(ns)) - } + let init = async { + let ns = self.make_namespace(namespace, db_config, restore_option).await?; + Ok(Some(ns)) }; let before_load = Instant::now(); From 907f2f9381783b09254e605473455c72e97cd40b Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 5 Aug 2024 11:14:40 +0200 Subject: [PATCH 007/121] pass configurators to NamespaceStore::new --- libsql-server/src/lib.rs | 5 +++ .../src/namespace/configurator/mod.rs | 38 ++++++++++++------- libsql-server/src/namespace/mod.rs | 2 +- libsql-server/src/namespace/store.rs | 10 +---- libsql-server/src/schema/scheduler.rs | 38 ++++++++++++++----- 5 files changed, 62 insertions(+), 31 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 5404a11108..3d816d6bc3 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -60,6 +60,7 @@ use utils::services::idle_shutdown::IdleShutdownKicker; use self::config::MetaStoreConfig; use self::connection::connection_manager::InnerWalManager; +use self::namespace::configurator::NamespaceConfigurators; use self::namespace::NamespaceStore; use self::net::AddrIncoming; use 
self::replication::script_backup_manager::{CommandHandler, ScriptBackupManager}; @@ -488,12 +489,16 @@ where meta_store_wal_manager, ) .await?; + + let configurators = NamespaceConfigurators::default(); + let namespace_store: NamespaceStore = NamespaceStore::new( db_kind.is_replica(), self.db_config.snapshot_at_shutdown, self.max_active_namespaces, ns_config, meta_store, + configurators, ) .await?; diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index d3cd390b34..a240c3e410 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -4,43 +4,55 @@ use futures::Future; use super::broadcasters::BroadcasterHandle; use super::meta_store::MetaStoreHandle; -use super::{NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; +use super::{ + NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption, +}; -mod replica; mod primary; +mod replica; mod schema; -pub use replica::ReplicaConfigurator; pub use primary::PrimaryConfigurator; +pub use replica::ReplicaConfigurator; pub use schema::SchemaConfigurator; type DynConfigurator = dyn ConfigureNamespace + Send + Sync + 'static; -#[derive(Default)] pub(crate) struct NamespaceConfigurators { replica_configurator: Option>, primary_configurator: Option>, schema_configurator: Option>, } +impl Default for NamespaceConfigurators { + fn default() -> Self { + Self::empty() + .with_primary(PrimaryConfigurator) + .with_replica(ReplicaConfigurator) + .with_schema(SchemaConfigurator) + } +} + impl NamespaceConfigurators { - pub fn with_primary( - &mut self, - c: impl ConfigureNamespace + Send + Sync + 'static, - ) -> &mut Self { + pub fn empty() -> Self { + Self { + replica_configurator: None, + primary_configurator: None, + schema_configurator: None, + } + } + + pub fn with_primary(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) 
-> Self { self.primary_configurator = Some(Box::new(c)); self } - pub fn with_replica( - &mut self, - c: impl ConfigureNamespace + Send + Sync + 'static, - ) -> &mut Self { + pub fn with_replica(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { self.replica_configurator = Some(Box::new(c)); self } - pub fn with_schema(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { + pub fn with_schema(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { self.schema_configurator = Some(Box::new(c)); self } diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 41bb3ab9cc..5ccda74c54 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -52,7 +52,7 @@ mod name; pub mod replication_wal; mod schema_lock; mod store; -mod configurator; +pub(crate) mod configurator; pub type ResetCb = Box; pub type ResolveNamespacePathFn = diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index 5a94a7f8eb..fbce8cd78b 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -16,7 +16,6 @@ use crate::connection::config::DatabaseConfig; use crate::database::DatabaseKind; use crate::error::Error; use crate::metrics::NAMESPACE_LOAD_LATENCY; -use crate::namespace::configurator::{PrimaryConfigurator, ReplicaConfigurator, SchemaConfigurator}; use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName}; use crate::stats::Stats; @@ -54,12 +53,13 @@ pub struct NamespaceStoreInner { } impl NamespaceStore { - pub async fn new( + pub(crate) async fn new( allow_lazy_creation: bool, snapshot_at_shutdown: bool, max_active_namespaces: usize, config: NamespaceConfig, metadata: MetaStore, + configurators: NamespaceConfigurators, ) -> crate::Result { tracing::trace!("Max active namespaces: {max_active_namespaces}"); let store = Cache::::builder() @@ -84,12 +84,6 @@ impl 
NamespaceStore { .time_to_idle(Duration::from_secs(86400)) .build(); - let mut configurators = NamespaceConfigurators::default(); - configurators - .with_primary(PrimaryConfigurator) - .with_replica(ReplicaConfigurator) - .with_schema(SchemaConfigurator); - Ok(Self { inner: Arc::new(NamespaceStoreInner { store, diff --git a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index 17fdfb3143..17ce655064 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -808,6 +808,9 @@ mod test { use crate::connection::config::DatabaseConfig; use crate::database::DatabaseKind; + use crate::namespace::configurator::{ + NamespaceConfigurators, PrimaryConfigurator, SchemaConfigurator, + }; use crate::namespace::meta_store::{metastore_connection_maker, MetaStore}; use crate::namespace::{NamespaceConfig, RestoreOption}; use crate::schema::SchedulerHandle; @@ -826,9 +829,16 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) - .await - .unwrap(); + let store = NamespaceStore::new( + false, + false, + 10, + config, + meta_store, + NamespaceConfigurators::default(), + ) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -936,9 +946,16 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) - .await - .unwrap(); + let store = NamespaceStore::new( + false, + false, + 10, + config, + meta_store, + NamespaceConfigurators::default(), + ) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -1012,7 +1029,7 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = 
make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) + let store = NamespaceStore::new(false, false, 10, config, meta_store, NamespaceConfigurators::default()) .await .unwrap(); @@ -1039,7 +1056,10 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) + let configurators = NamespaceConfigurators::default() + .with_schema(SchemaConfigurator) + .with_primary(PrimaryConfigurator); + let store = NamespaceStore::new(false, false, 10, config, meta_store, configurators) .await .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) @@ -1112,7 +1132,7 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) + let store = NamespaceStore::new(false, false, 10, config, meta_store, NamespaceConfigurators::default()) .await .unwrap(); let scheduler = Scheduler::new(store.clone(), maker().unwrap()) From fd03144bcc30e7b85f07abf4630ba9ea58a87d11 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 5 Aug 2024 15:26:34 +0200 Subject: [PATCH 008/121] decoupled namespace configurators --- libsql-server/src/error.rs | 2 +- libsql-server/src/lib.rs | 86 ++- .../src/namespace/{ => configurator}/fork.rs | 62 +- .../src/namespace/configurator/helpers.rs | 451 ++++++++++++++ .../src/namespace/configurator/mod.rs | 67 ++- .../src/namespace/configurator/primary.rs | 249 +++++--- .../src/namespace/configurator/replica.rs | 140 +++-- .../src/namespace/configurator/schema.rs | 71 ++- libsql-server/src/namespace/mod.rs | 551 +----------------- libsql-server/src/namespace/store.rs | 103 ++-- libsql-server/src/schema/scheduler.rs | 80 ++- 11 files changed, 1056 insertions(+), 806 deletions(-) rename 
libsql-server/src/namespace/{ => configurator}/fork.rs (77%) create mode 100644 libsql-server/src/namespace/configurator/helpers.rs diff --git a/libsql-server/src/error.rs b/libsql-server/src/error.rs index 371630abdf..9cd0b81485 100644 --- a/libsql-server/src/error.rs +++ b/libsql-server/src/error.rs @@ -4,7 +4,7 @@ use tonic::metadata::errors::InvalidMetadataValueBytes; use crate::{ auth::AuthError, - namespace::{ForkError, NamespaceName}, + namespace::{configurator::fork::ForkError, NamespaceName}, query_result_builder::QueryResultBuilderError, }; diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 3d816d6bc3..8bd3ea4fac 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -46,7 +46,7 @@ use libsql_wal::registry::WalRegistry; use libsql_wal::storage::NoStorage; use libsql_wal::wal::LibsqlWalManager; use namespace::meta_store::MetaStoreHandle; -use namespace::{NamespaceConfig, NamespaceName}; +use namespace::NamespaceName; use net::Connector; use once_cell::sync::Lazy; use rusqlite::ffi::SQLITE_CONFIG_MALLOC; @@ -60,7 +60,7 @@ use utils::services::idle_shutdown::IdleShutdownKicker; use self::config::MetaStoreConfig; use self::connection::connection_manager::InnerWalManager; -use self::namespace::configurator::NamespaceConfigurators; +use self::namespace::configurator::{BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, ReplicaConfigurator, SchemaConfigurator}; use self::namespace::NamespaceStore; use self::net::AddrIncoming; use self::replication::script_backup_manager::{CommandHandler, ScriptBackupManager}; @@ -425,11 +425,6 @@ where let user_auth_strategy = self.user_api_config.auth_strategy.clone(); let service_shutdown = Arc::new(Notify::new()); - let db_kind = if self.rpc_client_config.is_some() { - DatabaseKind::Replica - } else { - DatabaseKind::Primary - }; let scripted_backup = match self.db_config.snapshot_exec { Some(ref command) => { @@ -457,27 +452,6 @@ where // chose the wal 
backend let (make_wal_manager, registry_shutdown) = self.configure_wal_manager(&mut join_set)?; - let ns_config = NamespaceConfig { - db_kind, - base_path: self.path.clone(), - max_log_size: self.db_config.max_log_size, - max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), - bottomless_replication: self.db_config.bottomless_replication.clone(), - extensions, - stats_sender: stats_sender.clone(), - max_response_size: self.db_config.max_response_size, - max_total_response_size: self.db_config.max_total_response_size, - checkpoint_interval: self.db_config.checkpoint_interval, - encryption_config: self.db_config.encryption_config.clone(), - max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), - scripted_backup, - max_concurrent_requests: self.db_config.max_concurrent_requests, - channel: channel.clone(), - uri: uri.clone(), - migration_scheduler: scheduler_sender.into(), - make_wal_manager, - }; - let (metastore_conn_maker, meta_store_wal_manager) = metastore_connection_maker(self.meta_store_config.bottomless.clone(), &self.path) .await?; @@ -490,15 +464,67 @@ where ) .await?; - let configurators = NamespaceConfigurators::default(); + let base_config = BaseNamespaceConfig { + base_path: self.path.clone(), + extensions, + stats_sender, + max_response_size: self.db_config.max_response_size, + max_total_response_size: self.db_config.max_total_response_size, + max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), + max_concurrent_requests: self.db_config.max_concurrent_requests, + }; + + let mut configurators = NamespaceConfigurators::default(); + + let db_kind = match channel.clone().zip(uri.clone()) { + // replica mode + Some((channel, uri)) => { + let replica_configurator = ReplicaConfigurator::new( + base_config, + channel, + uri, + make_wal_manager, + ); + configurators.with_replica(replica_configurator); + DatabaseKind::Replica + } + // primary mode + None => { + let 
primary_config = PrimaryExtraConfig { + max_log_size: self.db_config.max_log_size, + max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), + bottomless_replication: self.db_config.bottomless_replication.clone(), + scripted_backup, + checkpoint_interval: self.db_config.checkpoint_interval, + }; + + let primary_configurator = PrimaryConfigurator::new( + base_config.clone(), + primary_config.clone(), + make_wal_manager.clone(), + ); + + let schema_configurator = SchemaConfigurator::new( + base_config.clone(), + primary_config, + make_wal_manager.clone(), + scheduler_sender.into(), + ); + + configurators.with_schema(schema_configurator); + configurators.with_primary(primary_configurator); + + DatabaseKind::Primary + }, + }; let namespace_store: NamespaceStore = NamespaceStore::new( db_kind.is_replica(), self.db_config.snapshot_at_shutdown, self.max_active_namespaces, - ns_config, meta_store, configurators, + db_kind, ) .await?; diff --git a/libsql-server/src/namespace/fork.rs b/libsql-server/src/namespace/configurator/fork.rs similarity index 77% rename from libsql-server/src/namespace/fork.rs rename to libsql-server/src/namespace/configurator/fork.rs index f25bf7a9a9..26a0b99b61 100644 --- a/libsql-server/src/namespace/fork.rs +++ b/libsql-server/src/namespace/configurator/fork.rs @@ -12,15 +12,71 @@ use tokio::io::{AsyncSeekExt, AsyncWriteExt}; use tokio::time::Duration; use tokio_stream::StreamExt; +use crate::database::Database; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{Namespace, NamespaceBottomlessDbId}; use crate::replication::primary::frame_stream::FrameStream; use crate::replication::{LogReadError, ReplicationLogger}; use crate::{BLOCKING_RT, LIBSQL_PAGE_SIZE}; -use super::meta_store::MetaStoreHandle; -use super::{NamespaceName, NamespaceStore, RestoreOption}; +use super::helpers::make_bottomless_options; +use super::{NamespaceName, NamespaceStore, PrimaryExtraConfig, RestoreOption}; type Result = 
crate::Result; +pub(super) async fn fork( + from_ns: &Namespace, + from_config: MetaStoreHandle, + to_ns: NamespaceName, + to_config: MetaStoreHandle, + timestamp: Option, + store: NamespaceStore, + primary_config: &PrimaryExtraConfig, + base_path: Arc, +) -> crate::Result { + let from_config = from_config.get(); + let bottomless_db_id = NamespaceBottomlessDbId::from_config(&from_config); + let restore_to = if let Some(timestamp) = timestamp { + if let Some(ref options) = primary_config.bottomless_replication { + Some(PointInTimeRestore { + timestamp, + replicator_options: make_bottomless_options( + options, + bottomless_db_id.clone(), + from_ns.name().clone(), + ), + }) + } else { + return Err(crate::Error::Fork(ForkError::BackupServiceNotConfigured)); + } + } else { + None + }; + + let logger = match &from_ns.db { + Database::Primary(db) => db.wal_wrapper.wrapper().logger(), + Database::Schema(db) => db.wal_wrapper.wrapper().logger(), + _ => { + return Err(crate::Error::Fork(ForkError::Internal(anyhow::Error::msg( + "Invalid source database type for fork", + )))); + } + }; + + let fork_task = ForkTask { + base_path, + to_namespace: to_ns.clone(), + logger, + restore_to, + to_config, + store, + }; + + let ns = fork_task.fork().await?; + + Ok(ns) +} + #[derive(Debug, thiserror::Error)] pub enum ForkError { #[error("internal error: {0}")] @@ -58,7 +114,7 @@ pub struct ForkTask { pub to_namespace: NamespaceName, pub to_config: MetaStoreHandle, pub restore_to: Option, - pub store: NamespaceStore, + pub store: NamespaceStore } pub struct PointInTimeRestore { diff --git a/libsql-server/src/namespace/configurator/helpers.rs b/libsql-server/src/namespace/configurator/helpers.rs new file mode 100644 index 0000000000..f43fa8a192 --- /dev/null +++ b/libsql-server/src/namespace/configurator/helpers.rs @@ -0,0 +1,451 @@ +use std::path::{Path, PathBuf}; +use std::sync::Weak; +use std::sync::{atomic::AtomicBool, Arc}; +use std::time::Duration; + +use anyhow::Context as _; +use 
bottomless::replicator::Options; +use bytes::Bytes; +use futures::Stream; +use libsql_sys::wal::Sqlite3WalManager; +use tokio::io::AsyncBufReadExt as _; +use tokio::sync::watch; +use tokio::task::JoinSet; +use tokio_util::io::StreamReader; +use enclose::enclose; + +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; +use crate::connection::libsql::{open_conn, MakeLibSqlConn}; +use crate::connection::{Connection as _, MakeConnection as _}; +use crate::error::LoadDumpError; +use crate::replication::{FrameNo, ReplicationLogger}; +use crate::stats::Stats; +use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName, ResolveNamespacePathFn, RestoreOption}; +use crate::namespace::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::broadcasters::BroadcasterHandle; +use crate::database::{PrimaryConnection, PrimaryConnectionMaker}; +use crate::{StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT}; + +use super::{BaseNamespaceConfig, PrimaryExtraConfig}; + +const WASM_TABLE_CREATE: &str = + "CREATE TABLE libsql_wasm_func_table (name text PRIMARY KEY, body text) WITHOUT ROWID;"; + +#[tracing::instrument(skip_all)] +pub(super) async fn make_primary_connection_maker( + primary_config: &PrimaryExtraConfig, + base_config: &BaseNamespaceConfig, + meta_store_handle: &MetaStoreHandle, + db_path: &Path, + name: &NamespaceName, + restore_option: RestoreOption, + block_writes: Arc, + join_set: &mut JoinSet>, + resolve_attach_path: ResolveNamespacePathFn, + broadcaster: BroadcasterHandle, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, +) -> crate::Result<(PrimaryConnectionMaker, ReplicationWalWrapper, Arc)> { + let db_config = meta_store_handle.get(); + let bottomless_db_id = NamespaceBottomlessDbId::from_config(&db_config); + // FIXME: figure how to do it per-db + let mut
is_dirty = { + let sentinel_path = db_path.join(".sentinel"); + if sentinel_path.try_exists()? { + true + } else { + tokio::fs::File::create(&sentinel_path).await?; + false + } + }; + + // FIXME: due to a bug in logger::checkpoint_db we call regular checkpointing code + // instead of our virtual WAL one. It's a bit tangled to fix right now, because + // we need WAL context for checkpointing, and WAL context needs the ReplicationLogger... + // So instead we checkpoint early, *before* bottomless gets initialized. That way + // we're sure bottomless won't try to back up any existing WAL frames and will instead + // treat the existing db file as the source of truth. + + let bottomless_replicator = match primary_config.bottomless_replication { + Some(ref options) => { + tracing::debug!("Checkpointing before initializing bottomless"); + crate::replication::primary::logger::checkpoint_db(&db_path.join("data"))?; + tracing::debug!("Checkpointed before initializing bottomless"); + let options = make_bottomless_options(options, bottomless_db_id, name.clone()); + let (replicator, did_recover) = + init_bottomless_replicator(db_path.join("data"), options, &restore_option) + .await?; + tracing::debug!("Completed init of bottomless replicator"); + is_dirty |= did_recover; + Some(replicator) + } + None => None, + }; + + tracing::debug!("Checking fresh db"); + let is_fresh_db = check_fresh_db(&db_path)?; + // switch frame-count checkpoint to time-based one + let auto_checkpoint = if primary_config.checkpoint_interval.is_some() { + 0 + } else { + DEFAULT_AUTO_CHECKPOINT + }; + + let logger = Arc::new(ReplicationLogger::open( + &db_path, + primary_config.max_log_size, + primary_config.max_log_duration, + is_dirty, + auto_checkpoint, + primary_config.scripted_backup.clone(), + name.clone(), + None, + )?); + + tracing::debug!("sending stats"); + + let stats = make_stats( + &db_path, + join_set, + meta_store_handle.clone(), + base_config.stats_sender.clone(), + name.clone(), + 
logger.new_frame_notifier.subscribe(), + ) + .await?; + + tracing::debug!("Making replication wal wrapper"); + let wal_wrapper = make_replication_wal_wrapper(bottomless_replicator, logger.clone()); + + tracing::debug!("Opening libsql connection"); + + let connection_maker = MakeLibSqlConn::new( + db_path.to_path_buf(), + wal_wrapper.clone(), + stats.clone(), + broadcaster, + meta_store_handle.clone(), + base_config.extensions.clone(), + base_config.max_response_size, + base_config.max_total_response_size, + auto_checkpoint, + logger.new_frame_notifier.subscribe(), + None, + block_writes, + resolve_attach_path, + make_wal_manager.clone(), + ) + .await? + .throttled( + base_config.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + base_config.max_total_response_size, + base_config.max_concurrent_requests, + ); + + tracing::debug!("Completed opening libsql connection"); + + // this must happen after we create the connection maker. The connection maker holds on to a + // connection to ensure that no other connection is closing while we try to open the dump. + // that would cause a SQLITE_LOCKED error.
+ match restore_option { + RestoreOption::Dump(_) if !is_fresh_db => { + Err(LoadDumpError::LoadDumpExistingDb)?; + } + RestoreOption::Dump(dump) => { + let conn = connection_maker.create().await?; + tracing::debug!("Loading dump"); + load_dump(dump, conn).await?; + tracing::debug!("Done loading dump"); + } + _ => { /* other cases were already handled when creating bottomless */ } + } + + join_set.spawn(run_periodic_compactions(logger.clone())); + + tracing::debug!("Done making primary connection"); + + Ok((connection_maker, wal_wrapper, stats)) +} + +pub(super) fn make_bottomless_options( + options: &Options, + namespace_db_id: NamespaceBottomlessDbId, + name: NamespaceName, +) -> Options { + let mut options = options.clone(); + let mut db_id = match namespace_db_id { + NamespaceBottomlessDbId::Namespace(id) => id, + // FIXME(marin): I don't like that, if bottomless is enabled, proper config must be passed. + NamespaceBottomlessDbId::NotProvided => options.db_id.unwrap_or_default(), + }; + + db_id = format!("ns-{db_id}:{name}"); + options.db_id = Some(db_id); + options +} + +async fn init_bottomless_replicator( + path: impl AsRef, + options: bottomless::replicator::Options, + restore_option: &RestoreOption, +) -> anyhow::Result<(bottomless::replicator::Replicator, bool)> { + tracing::debug!("Initializing bottomless replication"); + let path = path + .as_ref() + .to_str() + .ok_or_else(|| anyhow::anyhow!("Invalid db path"))? 
+ .to_owned(); + let mut replicator = bottomless::replicator::Replicator::with_options(path, options).await?; + + let (generation, timestamp) = match restore_option { + RestoreOption::Latest | RestoreOption::Dump(_) => (None, None), + RestoreOption::Generation(generation) => (Some(*generation), None), + RestoreOption::PointInTime(timestamp) => (None, Some(*timestamp)), + }; + + let (action, did_recover) = replicator.restore(generation, timestamp).await?; + match action { + bottomless::replicator::RestoreAction::SnapshotMainDbFile => { + replicator.new_generation().await; + if let Some(_handle) = replicator.snapshot_main_db_file(true).await? { + tracing::trace!("got snapshot handle after restore with generation upgrade"); + } + // Restoration process only leaves the local WAL file if it was + // detected to be newer than its remote counterpart. + replicator.maybe_replicate_wal().await? + } + bottomless::replicator::RestoreAction::ReuseGeneration(gen) => { + replicator.set_generation(gen); + } + } + + Ok((replicator, did_recover)) +} + +async fn run_periodic_compactions(logger: Arc) -> anyhow::Result<()> { + // calling `ReplicationLogger::maybe_compact()` is cheap if the compaction does not actually + // take place, so we can afford to poll it very often for simplicity + let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(1000)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + interval.tick().await; + let handle = BLOCKING_RT.spawn_blocking(enclose! 
{(logger) move || { + logger.maybe_compact() + }}); + handle + .await + .expect("Compaction task crashed") + .context("Compaction failed")?; + } +} + +async fn load_dump(dump: S, conn: PrimaryConnection) -> crate::Result<(), LoadDumpError> +where + S: Stream> + Unpin, +{ + let mut reader = tokio::io::BufReader::new(StreamReader::new(dump)); + let mut curr = String::new(); + let mut line = String::new(); + let mut skipped_wasm_table = false; + let mut n_stmt = 0; + let mut line_id = 0; + + while let Ok(n) = reader.read_line(&mut curr).await { + line_id += 1; + if n == 0 { + break; + } + let trimmed = curr.trim(); + if trimmed.is_empty() || trimmed.starts_with("--") { + curr.clear(); + continue; + } + // FIXME: it's a well-known bug that a comment ending with a semicolon will be handled incorrectly by the current dump processing code + let statement_end = trimmed.ends_with(';'); + + // we want to concat original(non-trimmed) lines as trimming will join them all into one + // single-line statement which is incorrect if comments in the end are present + line.push_str(&curr); + curr.clear(); + + // This is a hack to ignore the libsql_wasm_func_table table because it is already created + // by the system.
+ if !skipped_wasm_table && line.trim() == WASM_TABLE_CREATE { + skipped_wasm_table = true; + line.clear(); + continue; + } + + if statement_end { + n_stmt += 1; + // dump must be performed within a txn + if n_stmt > 2 && conn.is_autocommit().await.unwrap() { + return Err(LoadDumpError::NoTxn); + } + + line = tokio::task::spawn_blocking({ + let conn = conn.clone(); + move || -> crate::Result { + conn.with_raw(|conn| conn.execute(&line, ())).map_err(|e| { + LoadDumpError::Internal(format!("line: {}, error: {}", line_id, e)) + })?; + Ok(line) + } + }) + .await??; + line.clear(); + } else { + line.push(' '); + } + } + tracing::debug!("loaded {} lines from dump", line_id); + + if !conn.is_autocommit().await.unwrap() { + tokio::task::spawn_blocking({ + let conn = conn.clone(); + move || -> crate::Result<(), LoadDumpError> { + conn.with_raw(|conn| conn.execute("rollback", ()))?; + Ok(()) + } + }) + .await??; + return Err(LoadDumpError::NoCommit); + } + + Ok(()) +} + +fn check_fresh_db(path: &Path) -> crate::Result { + let is_fresh = !path.join("wallog").try_exists()?; + Ok(is_fresh) +} + +pub(super) async fn make_stats( + db_path: &Path, + join_set: &mut JoinSet>, + meta_store_handle: MetaStoreHandle, + stats_sender: StatsSender, + name: NamespaceName, + mut current_frame_no: watch::Receiver>, +) -> anyhow::Result> { + tracing::debug!("creating stats type"); + let stats = Stats::new(name.clone(), db_path, join_set).await?; + + // the storage monitor is optional, so we ignore the error here.
+ tracing::debug!("stats created, sending stats"); + let _ = stats_sender + .send((name.clone(), meta_store_handle, Arc::downgrade(&stats))) + .await; + + join_set.spawn({ + let stats = stats.clone(); + // initialize the current_frame_no value + current_frame_no + .borrow_and_update() + .map(|fno| stats.set_current_frame_no(fno)); + async move { + while current_frame_no.changed().await.is_ok() { + current_frame_no + .borrow_and_update() + .map(|fno| stats.set_current_frame_no(fno)); + } + Ok(()) + } + }); + + join_set.spawn(run_storage_monitor( + db_path.into(), + Arc::downgrade(&stats), + )); + + tracing::debug!("done sending stats, and creating bg tasks"); + + Ok(stats) +} + +// Periodically check the storage used by the database and save it in the Stats structure. +// TODO: Once we have a separate fiber that does WAL checkpoints, running this routine +// right after checkpointing is exactly where it should be done. +async fn run_storage_monitor( + db_path: PathBuf, + stats: Weak, +) -> anyhow::Result<()> { + // on initialization, the database file doesn't exist yet, so we wait a bit for it to be + // created + tokio::time::sleep(Duration::from_secs(1)).await; + + let duration = tokio::time::Duration::from_secs(60); + let db_path: Arc = db_path.into(); + loop { + let db_path = db_path.clone(); + let Some(stats) = stats.upgrade() else { + return Ok(()); + }; + + let _ = tokio::task::spawn_blocking(move || { + // because closing the last connection interferes with opening a new one, we lazily + // initialize a connection here, and keep it alive for the entirety of the program. If we + // fail to open it, we wait for `duration` and try again later. 
+ match open_conn(&db_path, Sqlite3WalManager::new(), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), None) { + Ok(mut conn) => { + if let Ok(tx) = conn.transaction() { + let page_count = tx.query_row("pragma page_count;", [], |row| { row.get::(0) }); + let freelist_count = tx.query_row("pragma freelist_count;", [], |row| { row.get::(0) }); + if let (Ok(page_count), Ok(freelist_count)) = (page_count, freelist_count) { + let storage_bytes_used = (page_count - freelist_count) * 4096; + stats.set_storage_bytes_used(storage_bytes_used); + } + } + }, + Err(e) => { + tracing::warn!("failed to open connection for storager monitor: {e}, trying again in {duration:?}"); + }, + } + }).await; + + tokio::time::sleep(duration).await; + } +} + +pub(super) async fn cleanup_primary( + base: &BaseNamespaceConfig, + primary_config: &PrimaryExtraConfig, + namespace: &NamespaceName, + db_config: &DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: NamespaceBottomlessDbIdInit, +) -> crate::Result<()> { + let ns_path = base.base_path.join("dbs").join(namespace.as_str()); + if let Some(ref options) = primary_config.bottomless_replication { + let bottomless_db_id = match bottomless_db_id_init { + NamespaceBottomlessDbIdInit::Provided(db_id) => db_id, + NamespaceBottomlessDbIdInit::FetchFromConfig => { + NamespaceBottomlessDbId::from_config(db_config) + } + }; + let options = make_bottomless_options(options, bottomless_db_id, namespace.clone()); + let replicator = bottomless::replicator::Replicator::with_options( + ns_path.join("data").to_str().unwrap(), + options, + ) + .await?; + if prune_all { + let delete_all = replicator.delete_all(None).await?; + // perform hard deletion in the background + tokio::spawn(delete_all.commit()); + } else { + // for soft delete make sure that local db is fully backed up + replicator.savepoint().confirmed().await?; + } + } + + if ns_path.try_exists()? 
{ + tracing::debug!("removing database directory: {}", ns_path.display()); + tokio::fs::remove_dir_all(ns_path).await?; + } + + Ok(()) +} diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index a240c3e410..e5db335ff6 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -1,22 +1,51 @@ +use std::path::{Path, PathBuf}; use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use chrono::NaiveDateTime; use futures::Future; +use tokio::sync::Semaphore; + +use crate::connection::config::DatabaseConfig; +use crate::replication::script_backup_manager::ScriptBackupManager; +use crate::StatsSender; use super::broadcasters::BroadcasterHandle; use super::meta_store::MetaStoreHandle; -use super::{ - NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption, -}; +use super::{Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; +mod helpers; mod primary; mod replica; mod schema; +pub mod fork; pub use primary::PrimaryConfigurator; pub use replica::ReplicaConfigurator; pub use schema::SchemaConfigurator; -type DynConfigurator = dyn ConfigureNamespace + Send + Sync + 'static; +#[derive(Clone, Debug)] +pub struct BaseNamespaceConfig { + pub(crate) base_path: Arc, + pub(crate) extensions: Arc<[PathBuf]>, + pub(crate) stats_sender: StatsSender, + pub(crate) max_response_size: u64, + pub(crate) max_total_response_size: u64, + pub(crate) max_concurrent_connections: Arc, + pub(crate) max_concurrent_requests: u64, +} + +#[derive(Clone)] +pub struct PrimaryExtraConfig { + pub(crate) max_log_size: u64, + pub(crate) max_log_duration: Option, + pub(crate) bottomless_replication: Option, + pub(crate) scripted_backup: Option, + pub(crate) checkpoint_interval: Option, +} + +pub type DynConfigurator = dyn ConfigureNamespace + Send + Sync + 'static; pub(crate) 
struct NamespaceConfigurators { replica_configurator: Option>, @@ -27,9 +56,6 @@ pub(crate) struct NamespaceConfigurators { impl Default for NamespaceConfigurators { fn default() -> Self { Self::empty() - .with_primary(PrimaryConfigurator) - .with_replica(ReplicaConfigurator) - .with_schema(SchemaConfigurator) } } @@ -42,17 +68,17 @@ impl NamespaceConfigurators { } } - pub fn with_primary(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { + pub fn with_primary(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { self.primary_configurator = Some(Box::new(c)); self } - pub fn with_replica(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { + pub fn with_replica(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { self.replica_configurator = Some(Box::new(c)); self } - pub fn with_schema(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { + pub fn with_schema(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { self.schema_configurator = Some(Box::new(c)); self } @@ -73,7 +99,6 @@ impl NamespaceConfigurators { pub trait ConfigureNamespace { fn setup<'a>( &'a self, - ns_config: &'a NamespaceConfig, db_config: MetaStoreHandle, restore_option: RestoreOption, name: &'a NamespaceName, @@ -81,5 +106,23 @@ pub trait ConfigureNamespace { resolve_attach_path: ResolveNamespacePathFn, store: NamespaceStore, broadcaster: BroadcasterHandle, - ) -> Pin> + Send + 'a>>; + ) -> Pin> + Send + 'a>>; + + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + db_config: &'a DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> Pin> + Send + 'a>>; + + fn fork<'a>( + &'a self, + from_ns: &'a Namespace, + from_config: MetaStoreHandle, + to_ns: NamespaceName, + to_config: MetaStoreHandle, + timestamp: Option, + store: NamespaceStore, + ) -> Pin> + Send + 'a>>; } diff --git 
a/libsql-server/src/namespace/configurator/primary.rs b/libsql-server/src/namespace/configurator/primary.rs index f28d288a97..4351f6a3ac 100644 --- a/libsql-server/src/namespace/configurator/primary.rs +++ b/libsql-server/src/namespace/configurator/primary.rs @@ -4,22 +4,117 @@ use std::{path::Path, pin::Pin, sync::Arc}; use futures::prelude::Future; use tokio::task::JoinSet; +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; use crate::connection::MakeConnection; use crate::database::{Database, PrimaryDatabase}; -use crate::namespace::{Namespace, NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; -use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::configurator::helpers::make_primary_connection_maker; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{ + Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, + ResetCb, ResolveNamespacePathFn, RestoreOption, +}; use crate::run_periodic_checkpoint; use crate::schema::{has_pending_migration_task, setup_migration_table}; -use super::ConfigureNamespace; +use super::helpers::cleanup_primary; +use super::{BaseNamespaceConfig, ConfigureNamespace, PrimaryExtraConfig}; + +pub struct PrimaryConfigurator { + base: BaseNamespaceConfig, + primary_config: PrimaryExtraConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, +} + +impl PrimaryConfigurator { + pub fn new( + base: BaseNamespaceConfig, + primary_config: PrimaryExtraConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + ) -> Self { + Self { + base, + primary_config, + make_wal_manager, + } + } + + #[tracing::instrument(skip_all, fields(namespace))] + async fn try_new_primary( + &self, + namespace: NamespaceName, + meta_store_handle: MetaStoreHandle, + restore_option: RestoreOption, + resolve_attach_path: 
ResolveNamespacePathFn, + db_path: Arc, + broadcaster: BroadcasterHandle, + ) -> crate::Result { + let mut join_set = JoinSet::new(); + + tokio::fs::create_dir_all(&db_path).await?; + + let block_writes = Arc::new(AtomicBool::new(false)); + let (connection_maker, wal_wrapper, stats) = make_primary_connection_maker( + &self.primary_config, + &self.base, + &meta_store_handle, + &db_path, + &namespace, + restore_option, + block_writes.clone(), + &mut join_set, + resolve_attach_path, + broadcaster, + self.make_wal_manager.clone(), + ) + .await?; + let connection_maker = Arc::new(connection_maker); + + if meta_store_handle.get().shared_schema_name.is_some() { + let block_writes = block_writes.clone(); + let conn = connection_maker.create().await?; + tokio::task::spawn_blocking(move || { + conn.with_raw(|conn| -> crate::Result<()> { + setup_migration_table(conn)?; + if has_pending_migration_task(conn)? { + block_writes.store(true, Ordering::SeqCst); + } + Ok(()) + }) + }) + .await + .unwrap()?; + } + + if let Some(checkpoint_interval) = self.primary_config.checkpoint_interval { + join_set.spawn(run_periodic_checkpoint( + connection_maker.clone(), + checkpoint_interval, + namespace.clone(), + )); + } + + tracing::debug!("Done making new primary"); -pub struct PrimaryConfigurator; + Ok(Namespace { + tasks: join_set, + db: Database::Primary(PrimaryDatabase { + wal_wrapper, + connection_maker, + block_writes, + }), + name: namespace, + stats, + db_config_store: meta_store_handle, + path: db_path.into(), + }) + } +} impl ConfigureNamespace for PrimaryConfigurator { fn setup<'a>( &'a self, - config: &'a NamespaceConfig, meta_store_handle: MetaStoreHandle, restore_option: RestoreOption, name: &'a NamespaceName, @@ -27,102 +122,74 @@ impl ConfigureNamespace for PrimaryConfigurator { resolve_attach_path: ResolveNamespacePathFn, _store: NamespaceStore, broadcaster: BroadcasterHandle, - ) -> Pin> + Send + 'a>> - { + ) -> Pin> + Send + 'a>> { Box::pin(async move { - let db_path: Arc 
= config.base_path.join("dbs").join(name.as_str()).into(); + let db_path: Arc = self.base.base_path.join("dbs").join(name.as_str()).into(); let fresh_namespace = !db_path.try_exists()?; // FIXME: make that truly atomic. explore the idea of using temp directories, and it's implications - match try_new_primary( - config, - name.clone(), - meta_store_handle, - restore_option, - resolve_attach_path, - db_path.clone(), - broadcaster, - ) + match self + .try_new_primary( + name.clone(), + meta_store_handle, + restore_option, + resolve_attach_path, + db_path.clone(), + broadcaster, + ) .await - { - Ok(this) => Ok(this), - Err(e) if fresh_namespace => { - tracing::error!("an error occured while deleting creating namespace, cleaning..."); - if let Err(e) = tokio::fs::remove_dir_all(&db_path).await { - tracing::error!("failed to remove dirty namespace directory: {e}") - } - Err(e) + { + Ok(this) => Ok(this), + Err(e) if fresh_namespace => { + tracing::error!( + "an error occured while deleting creating namespace, cleaning..." 
+ ); + if let Err(e) = tokio::fs::remove_dir_all(&db_path).await { + tracing::error!("failed to remove dirty namespace directory: {e}") } - Err(e) => Err(e), + Err(e) } + Err(e) => Err(e), + } }) } -} -#[tracing::instrument(skip_all, fields(namespace))] -async fn try_new_primary( - ns_config: &NamespaceConfig, - namespace: NamespaceName, - meta_store_handle: MetaStoreHandle, - restore_option: RestoreOption, - resolve_attach_path: ResolveNamespacePathFn, - db_path: Arc, - broadcaster: BroadcasterHandle, -) -> crate::Result { - let mut join_set = JoinSet::new(); - - tokio::fs::create_dir_all(&db_path).await?; - - let block_writes = Arc::new(AtomicBool::new(false)); - let (connection_maker, wal_wrapper, stats) = Namespace::make_primary_connection_maker( - ns_config, - &meta_store_handle, - &db_path, - &namespace, - restore_option, - block_writes.clone(), - &mut join_set, - resolve_attach_path, - broadcaster, - ) - .await?; - let connection_maker = Arc::new(connection_maker); - - if meta_store_handle.get().shared_schema_name.is_some() { - let block_writes = block_writes.clone(); - let conn = connection_maker.create().await?; - tokio::task::spawn_blocking(move || { - conn.with_raw(|conn| -> crate::Result<()> { - setup_migration_table(conn)?; - if has_pending_migration_task(conn)? 
{ - block_writes.store(true, Ordering::SeqCst); - } - Ok(()) - }) + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + db_config: &'a DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + cleanup_primary( + &self.base, + &self.primary_config, + namespace, + db_config, + prune_all, + bottomless_db_id_init, + ).await }) - .await - .unwrap()?; - } - - if let Some(checkpoint_interval) = ns_config.checkpoint_interval { - join_set.spawn(run_periodic_checkpoint( - connection_maker.clone(), - checkpoint_interval, - namespace.clone(), - )); } - tracing::debug!("Done making new primary"); - - Ok(Namespace { - tasks: join_set, - db: Database::Primary(PrimaryDatabase { - wal_wrapper, - connection_maker, - block_writes, - }), - name: namespace, - stats, - db_config_store: meta_store_handle, - path: db_path.into(), - }) + fn fork<'a>( + &'a self, + from_ns: &'a Namespace, + from_config: MetaStoreHandle, + to_ns: NamespaceName, + to_config: MetaStoreHandle, + timestamp: Option, + store: NamespaceStore, + ) -> Pin> + Send + 'a>> { + Box::pin(super::fork::fork( + from_ns, + from_config, + to_ns, + to_config, + timestamp, + store, + &self.primary_config, + self.base.base_path.clone())) + } } + diff --git a/libsql-server/src/namespace/configurator/replica.rs b/libsql-server/src/namespace/configurator/replica.rs index 4d3ca1dadf..61dd48b0bf 100644 --- a/libsql-server/src/namespace/configurator/replica.rs +++ b/libsql-server/src/namespace/configurator/replica.rs @@ -2,29 +2,51 @@ use std::pin::Pin; use std::sync::Arc; use futures::Future; +use hyper::Uri; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use tokio::task::JoinSet; +use tonic::transport::Channel; +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; use crate::connection::write_proxy::MakeWriteProxyConn; use 
crate::connection::MakeConnection; use crate::database::{Database, ReplicaDatabase}; use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::configurator::helpers::make_stats; use crate::namespace::meta_store::MetaStoreHandle; -use crate::namespace::{Namespace, RestoreOption}; -use crate::namespace::{ - make_stats, NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResetOp, - ResolveNamespacePathFn, -}; +use crate::namespace::{Namespace, NamespaceBottomlessDbIdInit, RestoreOption}; +use crate::namespace::{NamespaceName, NamespaceStore, ResetCb, ResetOp, ResolveNamespacePathFn}; use crate::{DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT}; -use super::ConfigureNamespace; +use super::{BaseNamespaceConfig, ConfigureNamespace}; -pub struct ReplicaConfigurator; +pub struct ReplicaConfigurator { + base: BaseNamespaceConfig, + channel: Channel, + uri: Uri, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, +} + +impl ReplicaConfigurator { + pub fn new( + base: BaseNamespaceConfig, + channel: Channel, + uri: Uri, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + ) -> Self { + Self { + base, + channel, + uri, + make_wal_manager, + } + } +} impl ConfigureNamespace for ReplicaConfigurator { fn setup<'a>( &'a self, - config: &'a NamespaceConfig, meta_store_handle: MetaStoreHandle, restore_option: RestoreOption, name: &'a NamespaceName, @@ -32,13 +54,12 @@ impl ConfigureNamespace for ReplicaConfigurator { resolve_attach_path: ResolveNamespacePathFn, store: NamespaceStore, broadcaster: BroadcasterHandle, - ) -> Pin> + Send + 'a>> - { + ) -> Pin> + Send + 'a>> { Box::pin(async move { tracing::debug!("creating replica namespace"); - let db_path = config.base_path.join("dbs").join(name.as_str()); - let channel = config.channel.clone().expect("bad replica config"); - let uri = config.uri.clone().expect("bad replica config"); + let db_path = self.base.base_path.join("dbs").join(name.as_str()); + let channel = 
self.channel.clone(); + let uri = self.uri.clone(); let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); let client = crate::replication::replicator_client::Client::new( @@ -48,45 +69,46 @@ impl ConfigureNamespace for ReplicaConfigurator { meta_store_handle.clone(), store.clone(), ) - .await?; + .await?; let applied_frame_no_receiver = client.current_frame_no_notifier.subscribe(); let mut replicator = libsql_replication::replicator::Replicator::new( client, db_path.join("data"), DEFAULT_AUTO_CHECKPOINT, - config.encryption_config.clone(), + None, ) - .await?; + .await?; tracing::debug!("try perform handshake"); // force a handshake now, to retrieve the primary's current replication index match replicator.try_perform_handshake().await { Err(libsql_replication::replicator::Error::Meta( - libsql_replication::meta::Error::LogIncompatible, + libsql_replication::meta::Error::LogIncompatible, )) => { tracing::error!( "trying to replicate incompatible logs, reseting replica and nuking db dir" ); std::fs::remove_dir_all(&db_path).unwrap(); - return self.setup( - config, - meta_store_handle, - restore_option, - name, - reset, - resolve_attach_path, - store, - broadcaster, - ) + return self + .setup( + meta_store_handle, + restore_option, + name, + reset, + resolve_attach_path, + store, + broadcaster, + ) .await; - } + } Err(e) => Err(e)?, Ok(_) => (), } tracing::debug!("done performing handshake"); - let primary_current_replicatio_index = replicator.client_mut().primary_replication_index; + let primary_current_replicatio_index = + replicator.client_mut().primary_replication_index; let mut join_set = JoinSet::new(); let namespace = name.clone(); @@ -144,36 +166,35 @@ impl ConfigureNamespace for ReplicaConfigurator { &db_path, &mut join_set, meta_store_handle.clone(), - config.stats_sender.clone(), + self.base.stats_sender.clone(), name.clone(), applied_frame_no_receiver.clone(), - config.encryption_config.clone(), ) - .await?; + .await?; let 
connection_maker = MakeWriteProxyConn::new( db_path.clone(), - config.extensions.clone(), + self.base.extensions.clone(), channel.clone(), uri.clone(), stats.clone(), broadcaster, meta_store_handle.clone(), applied_frame_no_receiver, - config.max_response_size, - config.max_total_response_size, + self.base.max_response_size, + self.base.max_total_response_size, primary_current_replicatio_index, - config.encryption_config.clone(), + None, resolve_attach_path, - config.make_wal_manager.clone(), + self.make_wal_manager.clone(), ) - .await? - .throttled( - config.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - config.max_total_response_size, - config.max_concurrent_requests, - ); + .await? + .throttled( + self.base.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + self.base.max_total_response_size, + self.base.max_concurrent_requests, + ); Ok(Namespace { tasks: join_set, @@ -187,4 +208,35 @@ impl ConfigureNamespace for ReplicaConfigurator { }) }) } + + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + _db_config: &DatabaseConfig, + _prune_all: bool, + _bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let ns_path = self.base.base_path.join("dbs").join(namespace.as_str()); + if ns_path.try_exists()? 
{ + tracing::debug!("removing database directory: {}", ns_path.display()); + tokio::fs::remove_dir_all(ns_path).await?; + } + Ok(()) + }) + } + + fn fork<'a>( + &'a self, + _from_ns: &'a Namespace, + _from_config: MetaStoreHandle, + _to_ns: NamespaceName, + _to_config: MetaStoreHandle, + _timestamp: Option, + _store: NamespaceStore, + ) -> Pin> + Send + 'a>> { + Box::pin(std::future::ready(Err(crate::Error::Fork( + super::fork::ForkError::ForkReplica, + )))) + } } diff --git a/libsql-server/src/namespace/configurator/schema.rs b/libsql-server/src/namespace/configurator/schema.rs index 864b75239f..e55c706fec 100644 --- a/libsql-server/src/namespace/configurator/schema.rs +++ b/libsql-server/src/namespace/configurator/schema.rs @@ -3,22 +3,36 @@ use std::sync::{atomic::AtomicBool, Arc}; use futures::prelude::Future; use tokio::task::JoinSet; +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; use crate::database::{Database, SchemaDatabase}; use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::{ - Namespace, NamespaceConfig, NamespaceName, NamespaceStore, + Namespace, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption, }; use crate::namespace::broadcasters::BroadcasterHandle; +use crate::schema::SchedulerHandle; -use super::ConfigureNamespace; +use super::helpers::{cleanup_primary, make_primary_connection_maker}; +use super::{BaseNamespaceConfig, ConfigureNamespace, PrimaryExtraConfig}; -pub struct SchemaConfigurator; +pub struct SchemaConfigurator { + base: BaseNamespaceConfig, + primary_config: PrimaryExtraConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + migration_scheduler: SchedulerHandle, +} + +impl SchemaConfigurator { + pub fn new(base: BaseNamespaceConfig, primary_config: PrimaryExtraConfig, make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, migration_scheduler: SchedulerHandle) -> Self { + Self { base, primary_config, 
make_wal_manager, migration_scheduler } + } +} impl ConfigureNamespace for SchemaConfigurator { fn setup<'a>( &'a self, - ns_config: &'a NamespaceConfig, db_config: MetaStoreHandle, restore_option: RestoreOption, name: &'a NamespaceName, @@ -29,12 +43,13 @@ impl ConfigureNamespace for SchemaConfigurator { ) -> std::pin::Pin> + Send + 'a>> { Box::pin(async move { let mut join_set = JoinSet::new(); - let db_path = ns_config.base_path.join("dbs").join(name.as_str()); + let db_path = self.base.base_path.join("dbs").join(name.as_str()); tokio::fs::create_dir_all(&db_path).await?; - let (connection_maker, wal_manager, stats) = Namespace::make_primary_connection_maker( - ns_config, + let (connection_maker, wal_manager, stats) = make_primary_connection_maker( + &self.primary_config, + &self.base, &db_config, &db_path, &name, @@ -43,12 +58,13 @@ impl ConfigureNamespace for SchemaConfigurator { &mut join_set, resolve_attach_path, broadcaster, + self.make_wal_manager.clone() ) .await?; Ok(Namespace { db: Database::Schema(SchemaDatabase::new( - ns_config.migration_scheduler.clone(), + self.migration_scheduler.clone(), name.clone(), connection_maker, wal_manager, @@ -62,4 +78,43 @@ impl ConfigureNamespace for SchemaConfigurator { }) }) } + + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + db_config: &'a DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: crate::namespace::NamespaceBottomlessDbIdInit, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + cleanup_primary( + &self.base, + &self.primary_config, + namespace, + db_config, + prune_all, + bottomless_db_id_init, + ).await + }) + } + + fn fork<'a>( + &'a self, + from_ns: &'a Namespace, + from_config: MetaStoreHandle, + to_ns: NamespaceName, + to_config: MetaStoreHandle, + timestamp: Option, + store: NamespaceStore, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(super::fork::fork( + from_ns, + from_config, + to_ns, + to_config, + timestamp, + store, + &self.primary_config, + 
self.base.base_path.clone())) + } } diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 5ccda74c54..7cfa6b351c 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -1,52 +1,24 @@ -use std::path::{Path, PathBuf}; -use std::sync::atomic::AtomicBool; -use std::sync::{Arc, Weak}; +use std::path::Path; +use std::sync::Arc; -use anyhow::{Context as _, Error}; -use bottomless::replicator::Options; -use broadcasters::BroadcasterHandle; +use anyhow::Context as _; use bytes::Bytes; use chrono::NaiveDateTime; -use enclose::enclose; use futures_core::{Future, Stream}; -use hyper::Uri; -use libsql_sys::wal::Sqlite3WalManager; -use libsql_sys::EncryptionConfig; -use tokio::io::AsyncBufReadExt; -use tokio::sync::{watch, Semaphore}; use tokio::task::JoinSet; -use tokio::time::Duration; -use tokio_util::io::StreamReader; -use tonic::transport::Channel; use uuid::Uuid; use crate::auth::parse_jwt_keys; use crate::connection::config::DatabaseConfig; -use crate::connection::connection_manager::InnerWalManager; -use crate::connection::libsql::{open_conn, MakeLibSqlConn}; -use crate::connection::{Connection as _, MakeConnection}; -use crate::database::{ - Database, DatabaseKind, PrimaryConnection, PrimaryConnectionMaker, -}; -use crate::error::LoadDumpError; -use crate::replication::script_backup_manager::ScriptBackupManager; -use crate::replication::{FrameNo, ReplicationLogger}; -use crate::schema::SchedulerHandle; +use crate::connection::Connection as _; +use crate::database::Database; use crate::stats::Stats; -use crate::{ - StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT, -}; -pub use fork::ForkError; - -use self::fork::{ForkTask, PointInTimeRestore}; use self::meta_store::MetaStoreHandle; pub use self::name::NamespaceName; -use self::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; pub use self::store::NamespaceStore; pub mod broadcasters; -mod fork; pub mod 
meta_store; mod name; pub mod replication_wal; @@ -101,51 +73,6 @@ impl Namespace { &self.name } - /// completely remove resources associated with the namespace - pub(crate) async fn cleanup( - ns_config: &NamespaceConfig, - name: &NamespaceName, - db_config: &DatabaseConfig, - prune_all: bool, - bottomless_db_id_init: NamespaceBottomlessDbIdInit, - ) -> crate::Result<()> { - let ns_path = ns_config.base_path.join("dbs").join(name.as_str()); - match ns_config.db_kind { - DatabaseKind::Primary => { - if let Some(ref options) = ns_config.bottomless_replication { - let bottomless_db_id = match bottomless_db_id_init { - NamespaceBottomlessDbIdInit::Provided(db_id) => db_id, - NamespaceBottomlessDbIdInit::FetchFromConfig => { - NamespaceBottomlessDbId::from_config(&db_config) - } - }; - let options = make_bottomless_options(options, bottomless_db_id, name.clone()); - let replicator = bottomless::replicator::Replicator::with_options( - ns_path.join("data").to_str().unwrap(), - options, - ) - .await?; - if prune_all { - let delete_all = replicator.delete_all(None).await?; - // perform hard deletion in the background - tokio::spawn(delete_all.commit()); - } else { - // for soft delete make sure that local db is fully backed up - replicator.savepoint().confirmed().await?; - } - } - } - DatabaseKind::Replica => (), - } - - if ns_path.try_exists()? 
{ - tracing::debug!("removing database directory: {}", ns_path.display()); - tokio::fs::remove_dir_all(ns_path).await?; - } - - Ok(()) - } - async fn destroy(mut self) -> anyhow::Result<()> { self.tasks.shutdown().await; self.db.destroy(); @@ -195,293 +122,11 @@ impl Namespace { pub fn config_changed(&self) -> impl Future { self.db_config_store.changed() } - - #[tracing::instrument(skip_all)] - async fn make_primary_connection_maker( - ns_config: &NamespaceConfig, - meta_store_handle: &MetaStoreHandle, - db_path: &Path, - name: &NamespaceName, - restore_option: RestoreOption, - block_writes: Arc, - join_set: &mut JoinSet>, - resolve_attach_path: ResolveNamespacePathFn, - broadcaster: BroadcasterHandle, - ) -> crate::Result<(PrimaryConnectionMaker, ReplicationWalWrapper, Arc)> { - let db_config = meta_store_handle.get(); - let bottomless_db_id = NamespaceBottomlessDbId::from_config(&db_config); - // FIXME: figure how to to it per-db - let mut is_dirty = { - let sentinel_path = db_path.join(".sentinel"); - if sentinel_path.try_exists()? { - true - } else { - tokio::fs::File::create(&sentinel_path).await?; - false - } - }; - - // FIXME: due to a bug in logger::checkpoint_db we call regular checkpointing code - // instead of our virtual WAL one. It's a bit tangled to fix right now, because - // we need WAL context for checkpointing, and WAL context needs the ReplicationLogger... - // So instead we checkpoint early, *before* bottomless gets initialized. That way - // we're sure bottomless won't try to back up any existing WAL frames and will instead - // treat the existing db file as the source of truth. 
- - let bottomless_replicator = match ns_config.bottomless_replication { - Some(ref options) => { - tracing::debug!("Checkpointing before initializing bottomless"); - crate::replication::primary::logger::checkpoint_db(&db_path.join("data"))?; - tracing::debug!("Checkpointed before initializing bottomless"); - let options = make_bottomless_options(options, bottomless_db_id, name.clone()); - let (replicator, did_recover) = - init_bottomless_replicator(db_path.join("data"), options, &restore_option) - .await?; - tracing::debug!("Completed init of bottomless replicator"); - is_dirty |= did_recover; - Some(replicator) - } - None => None, - }; - - tracing::debug!("Checking fresh db"); - let is_fresh_db = check_fresh_db(&db_path)?; - // switch frame-count checkpoint to time-based one - let auto_checkpoint = if ns_config.checkpoint_interval.is_some() { - 0 - } else { - DEFAULT_AUTO_CHECKPOINT - }; - - let logger = Arc::new(ReplicationLogger::open( - &db_path, - ns_config.max_log_size, - ns_config.max_log_duration, - is_dirty, - auto_checkpoint, - ns_config.scripted_backup.clone(), - name.clone(), - ns_config.encryption_config.clone(), - )?); - - tracing::debug!("sending stats"); - - let stats = make_stats( - &db_path, - join_set, - meta_store_handle.clone(), - ns_config.stats_sender.clone(), - name.clone(), - logger.new_frame_notifier.subscribe(), - ns_config.encryption_config.clone(), - ) - .await?; - - tracing::debug!("Making replication wal wrapper"); - let wal_wrapper = make_replication_wal_wrapper(bottomless_replicator, logger.clone()); - - tracing::debug!("Opening libsql connection"); - - let connection_maker = MakeLibSqlConn::new( - db_path.to_path_buf(), - wal_wrapper.clone(), - stats.clone(), - broadcaster, - meta_store_handle.clone(), - ns_config.extensions.clone(), - ns_config.max_response_size, - ns_config.max_total_response_size, - auto_checkpoint, - logger.new_frame_notifier.subscribe(), - ns_config.encryption_config.clone(), - block_writes, - 
resolve_attach_path, - ns_config.make_wal_manager.clone(), - ) - .await? - .throttled( - ns_config.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - ns_config.max_total_response_size, - ns_config.max_concurrent_requests, - ); - - tracing::debug!("Completed opening libsql connection"); - - // this must happen after we create the connection maker. The connection maker old on a - // connection to ensure that no other connection is closing while we try to open the dump. - // that would cause a SQLITE_LOCKED error. - match restore_option { - RestoreOption::Dump(_) if !is_fresh_db => { - Err(LoadDumpError::LoadDumpExistingDb)?; - } - RestoreOption::Dump(dump) => { - let conn = connection_maker.create().await?; - tracing::debug!("Loading dump"); - load_dump(dump, conn).await?; - tracing::debug!("Done loading dump"); - } - _ => { /* other cases were already handled when creating bottomless */ } - } - - join_set.spawn(run_periodic_compactions(logger.clone())); - - tracing::debug!("Done making primary connection"); - - Ok((connection_maker, wal_wrapper, stats)) - } - - async fn fork( - ns_config: &NamespaceConfig, - from_ns: &Namespace, - from_config: MetaStoreHandle, - to_ns: NamespaceName, - to_config: MetaStoreHandle, - timestamp: Option, - store: NamespaceStore, - ) -> crate::Result { - let from_config = from_config.get(); - match ns_config.db_kind { - DatabaseKind::Primary => { - let bottomless_db_id = NamespaceBottomlessDbId::from_config(&from_config); - let restore_to = if let Some(timestamp) = timestamp { - if let Some(ref options) = ns_config.bottomless_replication { - Some(PointInTimeRestore { - timestamp, - replicator_options: make_bottomless_options( - options, - bottomless_db_id.clone(), - from_ns.name().clone(), - ), - }) - } else { - return Err(crate::Error::Fork(ForkError::BackupServiceNotConfigured)); - } - } else { - None - }; - - let logger = match &from_ns.db { - Database::Primary(db) => db.wal_wrapper.wrapper().logger(), - 
Database::Schema(db) => db.wal_wrapper.wrapper().logger(), - _ => { - return Err(crate::Error::Fork(ForkError::Internal(Error::msg( - "Invalid source database type for fork", - )))); - } - }; - - let fork_task = ForkTask { - base_path: ns_config.base_path.clone(), - to_namespace: to_ns.clone(), - logger, - restore_to, - to_config, - store, - }; - - let ns = fork_task.fork().await?; - Ok(ns) - } - DatabaseKind::Replica => Err(ForkError::ForkReplica.into()), - } - } -} - -pub struct NamespaceConfig { - /// Default database kind the store should be Creating - pub(crate) db_kind: DatabaseKind, - // Common config - pub(crate) base_path: Arc, - pub(crate) max_log_size: u64, - pub(crate) max_log_duration: Option, - pub(crate) extensions: Arc<[PathBuf]>, - pub(crate) stats_sender: StatsSender, - pub(crate) max_response_size: u64, - pub(crate) max_total_response_size: u64, - pub(crate) checkpoint_interval: Option, - pub(crate) max_concurrent_connections: Arc, - pub(crate) max_concurrent_requests: u64, - pub(crate) encryption_config: Option, - - // Replica specific config - /// grpc channel for replica - pub channel: Option, - /// grpc uri - pub uri: Option, - - // primary only config - pub(crate) bottomless_replication: Option, - pub(crate) scripted_backup: Option, - pub(crate) migration_scheduler: SchedulerHandle, - pub(crate) make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, } pub type DumpStream = Box> + Send + Sync + 'static + Unpin>; -fn make_bottomless_options( - options: &Options, - namespace_db_id: NamespaceBottomlessDbId, - name: NamespaceName, -) -> Options { - let mut options = options.clone(); - let mut db_id = match namespace_db_id { - NamespaceBottomlessDbId::Namespace(id) => id, - // FIXME(marin): I don't like that, if bottomless is enabled, proper config must be passed. 
- NamespaceBottomlessDbId::NotProvided => options.db_id.unwrap_or_default(), - }; - - db_id = format!("ns-{db_id}:{name}"); - options.db_id = Some(db_id); - options -} - -async fn make_stats( - db_path: &Path, - join_set: &mut JoinSet>, - meta_store_handle: MetaStoreHandle, - stats_sender: StatsSender, - name: NamespaceName, - mut current_frame_no: watch::Receiver>, - encryption_config: Option, -) -> anyhow::Result> { - tracing::debug!("creating stats type"); - let stats = Stats::new(name.clone(), db_path, join_set).await?; - - // the storage monitor is optional, so we ignore the error here. - tracing::debug!("stats created, sending stats"); - let _ = stats_sender - .send((name.clone(), meta_store_handle, Arc::downgrade(&stats))) - .await; - - join_set.spawn({ - let stats = stats.clone(); - // initialize the current_frame_no value - current_frame_no - .borrow_and_update() - .map(|fno| stats.set_current_frame_no(fno)); - async move { - while current_frame_no.changed().await.is_ok() { - current_frame_no - .borrow_and_update() - .map(|fno| stats.set_current_frame_no(fno)); - } - Ok(()) - } - }); - - join_set.spawn(run_storage_monitor( - db_path.into(), - Arc::downgrade(&stats), - encryption_config, - )); - - tracing::debug!("done sending stats, and creating bg tasks"); - - Ok(stats) -} - #[derive(Default)] pub enum RestoreOption { /// Restore database state from the most recent version found in a backup. @@ -495,189 +140,3 @@ pub enum RestoreOption { /// Granularity depends of how frequently WAL log pages are being snapshotted. 
PointInTime(NaiveDateTime), } - -const WASM_TABLE_CREATE: &str = - "CREATE TABLE libsql_wasm_func_table (name text PRIMARY KEY, body text) WITHOUT ROWID;"; - -async fn load_dump(dump: S, conn: PrimaryConnection) -> crate::Result<(), LoadDumpError> -where - S: Stream> + Unpin, -{ - let mut reader = tokio::io::BufReader::new(StreamReader::new(dump)); - let mut curr = String::new(); - let mut line = String::new(); - let mut skipped_wasm_table = false; - let mut n_stmt = 0; - let mut line_id = 0; - - while let Ok(n) = reader.read_line(&mut curr).await { - line_id += 1; - if n == 0 { - break; - } - let trimmed = curr.trim(); - if trimmed.is_empty() || trimmed.starts_with("--") { - curr.clear(); - continue; - } - // FIXME: it's well known bug that comment ending with semicolon will be handled incorrectly by currend dump processing code - let statement_end = trimmed.ends_with(';'); - - // we want to concat original(non-trimmed) lines as trimming will join all them in one - // single-line statement which is incorrect if comments in the end are present - line.push_str(&curr); - curr.clear(); - - // This is a hack to ignore the libsql_wasm_func_table table because it is already created - // by the system. 
- if !skipped_wasm_table && line.trim() == WASM_TABLE_CREATE { - skipped_wasm_table = true; - line.clear(); - continue; - } - - if statement_end { - n_stmt += 1; - // dump must be performd within a txn - if n_stmt > 2 && conn.is_autocommit().await.unwrap() { - return Err(LoadDumpError::NoTxn); - } - - line = tokio::task::spawn_blocking({ - let conn = conn.clone(); - move || -> crate::Result { - conn.with_raw(|conn| conn.execute(&line, ())).map_err(|e| { - LoadDumpError::Internal(format!("line: {}, error: {}", line_id, e)) - })?; - Ok(line) - } - }) - .await??; - line.clear(); - } else { - line.push(' '); - } - } - tracing::debug!("loaded {} lines from dump", line_id); - - if !conn.is_autocommit().await.unwrap() { - tokio::task::spawn_blocking({ - let conn = conn.clone(); - move || -> crate::Result<(), LoadDumpError> { - conn.with_raw(|conn| conn.execute("rollback", ()))?; - Ok(()) - } - }) - .await??; - return Err(LoadDumpError::NoCommit); - } - - Ok(()) -} - -pub async fn init_bottomless_replicator( - path: impl AsRef, - options: bottomless::replicator::Options, - restore_option: &RestoreOption, -) -> anyhow::Result<(bottomless::replicator::Replicator, bool)> { - tracing::debug!("Initializing bottomless replication"); - let path = path - .as_ref() - .to_str() - .ok_or_else(|| anyhow::anyhow!("Invalid db path"))? - .to_owned(); - let mut replicator = bottomless::replicator::Replicator::with_options(path, options).await?; - - let (generation, timestamp) = match restore_option { - RestoreOption::Latest | RestoreOption::Dump(_) => (None, None), - RestoreOption::Generation(generation) => (Some(*generation), None), - RestoreOption::PointInTime(timestamp) => (None, Some(*timestamp)), - }; - - let (action, did_recover) = replicator.restore(generation, timestamp).await?; - match action { - bottomless::replicator::RestoreAction::SnapshotMainDbFile => { - replicator.new_generation().await; - if let Some(_handle) = replicator.snapshot_main_db_file(true).await? 
{ - tracing::trace!("got snapshot handle after restore with generation upgrade"); - } - // Restoration process only leaves the local WAL file if it was - // detected to be newer than its remote counterpart. - replicator.maybe_replicate_wal().await? - } - bottomless::replicator::RestoreAction::ReuseGeneration(gen) => { - replicator.set_generation(gen); - } - } - - Ok((replicator, did_recover)) -} - -async fn run_periodic_compactions(logger: Arc) -> anyhow::Result<()> { - // calling `ReplicationLogger::maybe_compact()` is cheap if the compaction does not actually - // take place, so we can afford to poll it very often for simplicity - let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(1000)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - interval.tick().await; - let handle = BLOCKING_RT.spawn_blocking(enclose! {(logger) move || { - logger.maybe_compact() - }}); - handle - .await - .expect("Compaction task crashed") - .context("Compaction failed")?; - } -} - -fn check_fresh_db(path: &Path) -> crate::Result { - let is_fresh = !path.join("wallog").try_exists()?; - Ok(is_fresh) -} - -// Periodically check the storage used by the database and save it in the Stats structure. -// TODO: Once we have a separate fiber that does WAL checkpoints, running this routine -// right after checkpointing is exactly where it should be done. 
-async fn run_storage_monitor( - db_path: PathBuf, - stats: Weak, - encryption_config: Option, -) -> anyhow::Result<()> { - // on initialization, the database file doesn't exist yet, so we wait a bit for it to be - // created - tokio::time::sleep(Duration::from_secs(1)).await; - - let duration = tokio::time::Duration::from_secs(60); - let db_path: Arc = db_path.into(); - loop { - let db_path = db_path.clone(); - let Some(stats) = stats.upgrade() else { - return Ok(()); - }; - - let encryption_config = encryption_config.clone(); - let _ = tokio::task::spawn_blocking(move || { - // because closing the last connection interferes with opening a new one, we lazily - // initialize a connection here, and keep it alive for the entirety of the program. If we - // fail to open it, we wait for `duration` and try again later. - match open_conn(&db_path, Sqlite3WalManager::new(), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), encryption_config) { - Ok(mut conn) => { - if let Ok(tx) = conn.transaction() { - let page_count = tx.query_row("pragma page_count;", [], |row| { row.get::(0) }); - let freelist_count = tx.query_row("pragma freelist_count;", [], |row| { row.get::(0) }); - if let (Ok(page_count), Ok(freelist_count)) = (page_count, freelist_count) { - let storage_bytes_used = (page_count - freelist_count) * 4096; - stats.set_storage_bytes_used(storage_bytes_used); - } - } - }, - Err(e) => { - tracing::warn!("failed to open connection for storager monitor: {e}, trying again in {duration:?}"); - }, - } - }).await; - - tokio::time::sleep(duration).await; - } -} diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index fbce8cd78b..a78e4f59b0 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -20,10 +20,10 @@ use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, Nam use crate::stats::Stats; use super::broadcasters::{BroadcasterHandle, BroadcasterRegistry}; -use 
super::configurator::NamespaceConfigurators; +use super::configurator::{DynConfigurator, NamespaceConfigurators}; use super::meta_store::{MetaStore, MetaStoreHandle}; use super::schema_lock::SchemaLocksRegistry; -use super::{Namespace, NamespaceConfig, ResetCb, ResetOp, ResolveNamespacePathFn, RestoreOption}; +use super::{Namespace, ResetCb, ResetOp, ResolveNamespacePathFn, RestoreOption}; type NamespaceEntry = Arc>>; @@ -46,10 +46,10 @@ pub struct NamespaceStoreInner { allow_lazy_creation: bool, has_shutdown: AtomicBool, snapshot_at_shutdown: bool, - pub config: NamespaceConfig, schema_locks: SchemaLocksRegistry, broadcasters: BroadcasterRegistry, configurators: NamespaceConfigurators, + db_kind: DatabaseKind, } impl NamespaceStore { @@ -57,9 +57,9 @@ impl NamespaceStore { allow_lazy_creation: bool, snapshot_at_shutdown: bool, max_active_namespaces: usize, - config: NamespaceConfig, metadata: MetaStore, configurators: NamespaceConfigurators, + db_kind: DatabaseKind, ) -> crate::Result { tracing::trace!("Max active namespaces: {max_active_namespaces}"); let store = Cache::::builder() @@ -91,10 +91,10 @@ impl NamespaceStore { allow_lazy_creation, has_shutdown: AtomicBool::new(false), snapshot_at_shutdown, - config, schema_locks: Default::default(), broadcasters: Default::default(), configurators, + db_kind, }), }) } @@ -132,14 +132,8 @@ impl NamespaceStore { } } - Namespace::cleanup( - &self.inner.config, - &namespace, - &db_config, - prune_all, - bottomless_db_id_init, - ) - .await?; + self.cleanup(&namespace, &db_config, prune_all, bottomless_db_id_init) + .await?; tracing::info!("destroyed namespace: {namespace}"); @@ -181,15 +175,16 @@ impl NamespaceStore { let db_config = self.inner.metadata.handle(namespace.clone()); // destroy on-disk database - Namespace::cleanup( - &self.inner.config, + self.cleanup( &namespace, &db_config.get(), false, NamespaceBottomlessDbIdInit::FetchFromConfig, ) .await?; - let ns = self.make_namespace(&namespace, db_config, 
restore_option).await?; + let ns = self + .make_namespace(&namespace, db_config, restore_option) + .await?; lock.replace(ns); @@ -289,16 +284,17 @@ impl NamespaceStore { handle .store_and_maybe_flush(Some(to_config.into()), false) .await?; - let to_ns = Namespace::fork( - &self.inner.config, - from_ns, - from_config, - to.clone(), - handle.clone(), - timestamp, - self.clone(), - ) - .await?; + let to_ns = self + .get_configurator(&from_config.get()) + .fork( + from_ns, + from_config, + to.clone(), + handle.clone(), + timestamp, + self.clone(), + ) + .await?; to_lock.replace(to_ns); handle.flush().await?; @@ -377,23 +373,18 @@ impl NamespaceStore { config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { - let configurator = match self.inner.config.db_kind { - DatabaseKind::Primary if config.get().is_shared_schema => { - self.inner.configurators.configure_schema()? - } - DatabaseKind::Primary => self.inner.configurators.configure_primary()?, - DatabaseKind::Replica => self.inner.configurators.configure_replica()?, - }; - let ns = configurator.setup( - &self.inner.config, - config, - restore_option, - namespace, - self.make_reset_cb(), - self.resolve_attach_fn(), - self.clone(), - self.broadcaster(namespace.clone()), - ).await?; + let ns = self + .get_configurator(&config.get()) + .setup( + config, + restore_option, + namespace, + self.make_reset_cb(), + self.resolve_attach_fn(), + self.clone(), + self.broadcaster(namespace.clone()), + ) + .await?; Ok(ns) } @@ -405,7 +396,9 @@ impl NamespaceStore { restore_option: RestoreOption, ) -> crate::Result { let init = async { - let ns = self.make_namespace(namespace, db_config, restore_option).await?; + let ns = self + .make_namespace(namespace, db_config, restore_option) + .await?; Ok(Some(ns)) }; @@ -521,4 +514,26 @@ impl NamespaceStore { pub(crate) fn schema_locks(&self) -> &SchemaLocksRegistry { &self.inner.schema_locks } + + fn get_configurator(&self, db_config: &DatabaseConfig) -> &DynConfigurator { 
+ match self.inner.db_kind { + DatabaseKind::Primary if db_config.is_shared_schema => { + self.inner.configurators.configure_schema().unwrap() + } + DatabaseKind::Primary => self.inner.configurators.configure_primary().unwrap(), + DatabaseKind::Replica => self.inner.configurators.configure_replica().unwrap(), + } + } + + async fn cleanup( + &self, + namespace: &NamespaceName, + db_config: &DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> crate::Result<()> { + self.get_configurator(db_config) + .cleanup(namespace, db_config, prune_all, bottomless_db_id_init) + .await + } } diff --git a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index 17ce655064..a8195cbbd0 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -809,10 +809,11 @@ mod test { use crate::connection::config::DatabaseConfig; use crate::database::DatabaseKind; use crate::namespace::configurator::{ - NamespaceConfigurators, PrimaryConfigurator, SchemaConfigurator, + BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, + SchemaConfigurator, }; use crate::namespace::meta_store::{metastore_connection_maker, MetaStore}; - use crate::namespace::{NamespaceConfig, RestoreOption}; + use crate::namespace::RestoreOption; use crate::schema::SchedulerHandle; use super::super::migration::has_pending_migration_task; @@ -833,9 +834,9 @@ mod test { false, false, 10, - config, meta_store, - NamespaceConfigurators::default(), + config, + DatabaseKind::Primary ) .await .unwrap(); @@ -912,27 +913,41 @@ mod test { assert!(!block_write.load(std::sync::atomic::Ordering::Relaxed)); } - fn make_config(migration_scheduler: SchedulerHandle, path: &Path) -> NamespaceConfig { - NamespaceConfig { - db_kind: DatabaseKind::Primary, + fn make_config(migration_scheduler: SchedulerHandle, path: &Path) -> NamespaceConfigurators { + let mut configurators = NamespaceConfigurators::empty(); 
+ let base_config = BaseNamespaceConfig { base_path: path.to_path_buf().into(), - max_log_size: 1000000000, - max_log_duration: None, extensions: Arc::new([]), stats_sender: tokio::sync::mpsc::channel(1).0, max_response_size: 100000000000000, max_total_response_size: 100000000000, - checkpoint_interval: None, max_concurrent_connections: Arc::new(Semaphore::new(10)), max_concurrent_requests: 10000, - encryption_config: None, - channel: None, - uri: None, + }; + + let primary_config = PrimaryExtraConfig { + max_log_size: 1000000000, + max_log_duration: None, bottomless_replication: None, scripted_backup: None, + checkpoint_interval: None, + }; + + let make_wal_manager = Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())); + + configurators.with_schema(SchemaConfigurator::new( + base_config.clone(), + primary_config.clone(), + make_wal_manager.clone(), migration_scheduler, - make_wal_manager: Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())), - } + )); + configurators.with_primary(PrimaryConfigurator::new( + base_config, + primary_config, + make_wal_manager.clone(), + )); + + configurators } #[tokio::test] @@ -950,9 +965,9 @@ mod test { false, false, 10, - config, meta_store, - NamespaceConfigurators::default(), + config, + DatabaseKind::Primary ) .await .unwrap(); @@ -1029,9 +1044,16 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store, NamespaceConfigurators::default()) - .await - .unwrap(); + let store = NamespaceStore::new( + false, + false, + 10, + meta_store, + config, + DatabaseKind::Primary, + ) + .await + .unwrap(); store .with("ns".into(), |ns| { @@ -1056,10 +1078,7 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let configurators = NamespaceConfigurators::default() - .with_schema(SchemaConfigurator) - 
.with_primary(PrimaryConfigurator); - let store = NamespaceStore::new(false, false, 10, config, meta_store, configurators) + let store = NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) .await .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) @@ -1132,9 +1151,16 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store, NamespaceConfigurators::default()) - .await - .unwrap(); + let store = NamespaceStore::new( + false, + false, + 10, + meta_store, + config, + DatabaseKind::Primary + ) + .await + .unwrap(); let scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); From 0647711dd81736bcb1fa1f886b3076736becc4a9 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 08:30:53 +0200 Subject: [PATCH 009/121] legacy configurators --- libsql-server/src/http/admin/stats.rs | 2 + libsql-server/src/lib.rs | 425 +++++++++++++++----------- libsql-server/src/namespace/store.rs | 13 +- libsql-server/tests/cluster/mod.rs | 29 +- 4 files changed, 279 insertions(+), 190 deletions(-) diff --git a/libsql-server/src/http/admin/stats.rs b/libsql-server/src/http/admin/stats.rs index f2948d4d7b..5fce92ba0a 100644 --- a/libsql-server/src/http/admin/stats.rs +++ b/libsql-server/src/http/admin/stats.rs @@ -140,10 +140,12 @@ pub(super) async fn handle_stats( State(app_state): State>>, Path(namespace): Path, ) -> crate::Result> { + dbg!(); let stats = app_state .namespaces .stats(NamespaceName::from_string(namespace)?) 
.await?; + dbg!(); let resp: StatsResponse = stats.as_ref().into(); Ok(Json(resp)) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 8bd3ea4fac..4188365e03 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -4,7 +4,6 @@ use std::alloc::Layout; use std::ffi::c_void; use std::mem::{align_of, size_of}; use std::path::{Path, PathBuf}; -use std::pin::Pin; use std::str::FromStr; use std::sync::{Arc, Weak}; @@ -29,10 +28,10 @@ use auth::Auth; use config::{ AdminApiConfig, DbConfig, HeartbeatConfig, RpcClientConfig, RpcServerConfig, UserApiConfig, }; -use futures::future::ready; use futures::Future; use http::user::UserApi; use hyper::client::HttpConnector; +use hyper::Uri; use hyper_rustls::HttpsConnector; #[cfg(feature = "durable-wal")] use libsql_storage::{DurableWalManager, LockManager}; @@ -41,10 +40,6 @@ use libsql_sys::wal::either::Either as EitherWAL; #[cfg(feature = "durable-wal")] use libsql_sys::wal::either::Either3 as EitherWAL; use libsql_sys::wal::Sqlite3WalManager; -use libsql_wal::checkpointer::LibsqlCheckpointer; -use libsql_wal::registry::WalRegistry; -use libsql_wal::storage::NoStorage; -use libsql_wal::wal::LibsqlWalManager; use namespace::meta_store::MetaStoreHandle; use namespace::NamespaceName; use net::Connector; @@ -55,15 +50,19 @@ use tokio::runtime::Runtime; use tokio::sync::{mpsc, Notify, Semaphore}; use tokio::task::JoinSet; use tokio::time::Duration; +use tonic::transport::Channel; use url::Url; use utils::services::idle_shutdown::IdleShutdownKicker; use self::config::MetaStoreConfig; -use self::connection::connection_manager::InnerWalManager; -use self::namespace::configurator::{BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, ReplicaConfigurator, SchemaConfigurator}; +use self::namespace::configurator::{ + BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, + ReplicaConfigurator, SchemaConfigurator, +}; use 
self::namespace::NamespaceStore; use self::net::AddrIncoming; use self::replication::script_backup_manager::{CommandHandler, ScriptBackupManager}; +use self::schema::SchedulerHandle; pub mod auth; mod broadcaster; @@ -424,33 +423,44 @@ where let extensions = self.db_config.validate_extensions()?; let user_auth_strategy = self.user_api_config.auth_strategy.clone(); - let service_shutdown = Arc::new(Notify::new()); - let scripted_backup = match self.db_config.snapshot_exec { Some(ref command) => { let (scripted_backup, script_backup_task) = ScriptBackupManager::new(&self.path, CommandHandler::new(command.to_string())) .await?; - join_set.spawn(script_backup_task.run()); + self.spawn_until_shutdown(&mut join_set, script_backup_task.run()); Some(scripted_backup) } None => None, }; - let (channel, uri) = match self.rpc_client_config { - Some(ref config) => { - let (channel, uri) = config.configure().await?; - (Some(channel), Some(uri)) - } - None => (None, None), + let db_kind = match self.rpc_client_config { + Some(_) => DatabaseKind::Replica, + _ => DatabaseKind::Primary, }; + let client_config = self.get_client_config().await?; let (scheduler_sender, scheduler_receiver) = mpsc::channel(128); - let (stats_sender, stats_receiver) = mpsc::channel(1024); - // chose the wal backend - let (make_wal_manager, registry_shutdown) = self.configure_wal_manager(&mut join_set)?; + let base_config = BaseNamespaceConfig { + base_path: self.path.clone(), + extensions, + stats_sender, + max_response_size: self.db_config.max_response_size, + max_total_response_size: self.db_config.max_total_response_size, + max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), + max_concurrent_requests: self.db_config.max_concurrent_requests, + }; + + let configurators = self + .make_configurators( + base_config, + scripted_backup, + scheduler_sender.into(), + client_config.clone(), + ) + .await?; let (metastore_conn_maker, meta_store_wal_manager) = 
metastore_connection_maker(self.meta_store_config.bottomless.clone(), &self.path) @@ -464,60 +474,6 @@ where ) .await?; - let base_config = BaseNamespaceConfig { - base_path: self.path.clone(), - extensions, - stats_sender, - max_response_size: self.db_config.max_response_size, - max_total_response_size: self.db_config.max_total_response_size, - max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), - max_concurrent_requests: self.db_config.max_concurrent_requests, - }; - - let mut configurators = NamespaceConfigurators::default(); - - let db_kind = match channel.clone().zip(uri.clone()) { - // replica mode - Some((channel, uri)) => { - let replica_configurator = ReplicaConfigurator::new( - base_config, - channel, - uri, - make_wal_manager, - ); - configurators.with_replica(replica_configurator); - DatabaseKind::Replica - } - // primary mode - None => { - let primary_config = PrimaryExtraConfig { - max_log_size: self.db_config.max_log_size, - max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), - bottomless_replication: self.db_config.bottomless_replication.clone(), - scripted_backup, - checkpoint_interval: self.db_config.checkpoint_interval, - }; - - let primary_configurator = PrimaryConfigurator::new( - base_config.clone(), - primary_config.clone(), - make_wal_manager.clone(), - ); - - let schema_configurator = SchemaConfigurator::new( - base_config.clone(), - primary_config, - make_wal_manager.clone(), - scheduler_sender.into(), - ); - - configurators.with_schema(schema_configurator); - configurators.with_primary(primary_configurator); - - DatabaseKind::Primary - }, - }; - let namespace_store: NamespaceStore = NamespaceStore::new( db_kind.is_replica(), self.db_config.snapshot_at_shutdown, @@ -528,27 +484,9 @@ where ) .await?; - let meta_conn = metastore_conn_maker()?; - let scheduler = Scheduler::new(namespace_store.clone(), meta_conn).await?; - - join_set.spawn(async move { - 
scheduler.run(scheduler_receiver).await; - Ok(()) - }); self.spawn_monitoring_tasks(&mut join_set, stats_receiver)?; - // eagerly load the default namespace when namespaces are disabled - if self.disable_namespaces && db_kind.is_primary() { - namespace_store - .create( - NamespaceName::default(), - namespace::RestoreOption::Latest, - Default::default(), - ) - .await?; - } - // if namespaces are enabled, then bottomless must have set DB ID if !self.disable_namespaces { if let Some(bottomless) = &self.db_config.bottomless_replication { @@ -563,7 +501,7 @@ where let proxy_service = ProxyService::new(namespace_store.clone(), None, self.disable_namespaces); // Garbage collect proxy clients every 30 seconds - join_set.spawn({ + self.spawn_until_shutdown(&mut join_set, { let clients = proxy_service.clients(); async move { loop { @@ -572,7 +510,8 @@ where } } }); - join_set.spawn(run_rpc_server( + + self.spawn_until_shutdown(&mut join_set, run_rpc_server( proxy_service, config.acceptor, config.tls_config, @@ -584,9 +523,28 @@ where let shutdown_timeout = self.shutdown_timeout.clone(); let shutdown = self.shutdown.clone(); + let service_shutdown = Arc::new(Notify::new()); // setup user-facing rpc services match db_kind { DatabaseKind::Primary => { + // The migration scheduler is only useful on the primary + let meta_conn = metastore_conn_maker()?; + let scheduler = Scheduler::new(namespace_store.clone(), meta_conn).await?; + self.spawn_until_shutdown(&mut join_set, async move { + scheduler.run(scheduler_receiver).await; + Ok(()) + }); + + if self.disable_namespaces { + namespace_store + .create( + NamespaceName::default(), + namespace::RestoreOption::Latest, + Default::default(), + ) + .await?; + } + let replication_svc = ReplicationLogService::new( namespace_store.clone(), idle_shutdown_kicker.clone(), @@ -602,7 +560,7 @@ where ); // Garbage collect proxy clients every 30 seconds - join_set.spawn({ + self.spawn_until_shutdown(&mut join_set, { let clients = 
proxy_svc.clients(); async move { loop { @@ -623,16 +581,19 @@ where .configure(&mut join_set); } DatabaseKind::Replica => { + dbg!(); + let (channel, uri) = client_config.clone().unwrap(); let replication_svc = - ReplicationLogProxyService::new(channel.clone().unwrap(), uri.clone().unwrap()); + ReplicationLogProxyService::new(channel.clone(), uri.clone()); let proxy_svc = ReplicaProxyService::new( - channel.clone().unwrap(), - uri.clone().unwrap(), + channel, + uri, namespace_store.clone(), user_auth_strategy.clone(), self.disable_namespaces, ); + dbg!(); self.make_services( namespace_store.clone(), idle_shutdown_kicker, @@ -642,6 +603,7 @@ where service_shutdown.clone(), ) .configure(&mut join_set); + dbg!(); } }; @@ -651,7 +613,6 @@ where join_set.shutdown().await; service_shutdown.notify_waiters(); namespace_store.shutdown().await?; - registry_shutdown.await?; Ok::<_, crate::Error>(()) }; @@ -680,100 +641,200 @@ where Ok(()) } - fn setup_shutdown(&self) -> Option { - let shutdown_notify = self.shutdown.clone(); - self.idle_shutdown_timeout.map(|d| { - IdleShutdownKicker::new(d, self.initial_idle_shutdown_timeout, shutdown_notify) - }) - } - - fn configure_wal_manager( + async fn make_configurators( &self, - join_set: &mut JoinSet>, - ) -> anyhow::Result<( - Arc InnerWalManager + Sync + Send + 'static>, - Pin> + Send + Sync + 'static>>, - )> { - let wal_path = self.path.join("wals"); - let enable_libsql_wal_test = { - let is_primary = self.rpc_server_config.is_some(); - let is_libsql_wal_test = std::env::var("LIBSQL_WAL_TEST").is_ok(); - is_primary && is_libsql_wal_test - }; - let use_libsql_wal = - self.use_custom_wal == Some(CustomWAL::LibsqlWal) || enable_libsql_wal_test; - if !use_libsql_wal { - if wal_path.try_exists()? 
{ - anyhow::bail!("database was previously setup to use libsql-wal"); - } - } - - if self.use_custom_wal.is_some() { - if self.db_config.bottomless_replication.is_some() { - anyhow::bail!("bottomless not supported with custom WAL"); - } - if self.rpc_client_config.is_some() { - anyhow::bail!("custom WAL not supported in replica mode"); + base_config: BaseNamespaceConfig, + scripted_backup: Option, + migration_scheduler_handle: SchedulerHandle, + client_config: Option<(Channel, Uri)>, + ) -> anyhow::Result { + match self.use_custom_wal { + Some(CustomWAL::LibsqlWal) => self.libsql_wal_configurators(), + #[cfg(feature = "durable-wal")] + Some(CustomWAL::DurableWal) => self.durable_wal_configurators(), + None => { + self.legacy_configurators( + base_config, + scripted_backup, + migration_scheduler_handle, + client_config, + ) + .await } } + } - let namespace_resolver = |path: &Path| { - NamespaceName::from_string( - path.parent() - .unwrap() - .file_name() - .unwrap() - .to_str() - .unwrap() - .to_string(), - ) - .unwrap() - .into() - }; - - match self.use_custom_wal { - Some(CustomWAL::LibsqlWal) => { - let (sender, receiver) = tokio::sync::mpsc::channel(64); - let registry = Arc::new(WalRegistry::new(wal_path, NoStorage, sender)?); - let checkpointer = LibsqlCheckpointer::new(registry.clone(), receiver, 8); - join_set.spawn(async move { - checkpointer.run().await; - Ok(()) - }); + fn libsql_wal_configurators(&self) -> anyhow::Result { + todo!() + } - let wal = LibsqlWalManager::new(registry.clone(), Arc::new(namespace_resolver)); - let shutdown_notify = self.shutdown.clone(); - let shutdown_fut = Box::pin(async move { - shutdown_notify.notified().await; - registry.shutdown().await?; - Ok(()) - }); + #[cfg(feature = "durable-wal")] + fn durable_wal_configurators(&self) -> anyhow::Result { + todo!(); + } - tracing::info!("using libsql wal"); - Ok((Arc::new(move || EitherWAL::B(wal.clone())), shutdown_fut)) + fn spawn_until_shutdown(&self, join_set: &mut JoinSet>, fut: 
F) + where + F: Future> + Send + 'static, + { + let shutdown = self.shutdown.clone(); + join_set.spawn(async move { + tokio::select! { + _ = shutdown.notified() => Ok(()), + ret = fut => ret } - #[cfg(feature = "durable-wal")] - Some(CustomWAL::DurableWal) => { - tracing::info!("using durable wal"); - let lock_manager = Arc::new(std::sync::Mutex::new(LockManager::new())); - let wal = DurableWalManager::new( - lock_manager, - namespace_resolver, - self.storage_server_address.clone(), - ); - Ok(( - Arc::new(move || EitherWAL::C(wal.clone())), - Box::pin(ready(Ok(()))), - )) + }); + } + + async fn legacy_configurators( + &self, + base_config: BaseNamespaceConfig, + scripted_backup: Option, + migration_scheduler_handle: SchedulerHandle, + client_config: Option<(Channel, Uri)>, + ) -> anyhow::Result { + let make_wal_manager = Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())); + let mut configurators = NamespaceConfigurators::empty(); + + match client_config { + // replica mode + Some((channel, uri)) => { + let replica_configurator = + ReplicaConfigurator::new(base_config, channel, uri, make_wal_manager); + configurators.with_replica(replica_configurator); } + // primary mode None => { - tracing::info!("using sqlite3 wal"); - Ok(( - Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())), - Box::pin(ready(Ok(()))), - )) + let primary_config = PrimaryExtraConfig { + max_log_size: self.db_config.max_log_size, + max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), + bottomless_replication: self.db_config.bottomless_replication.clone(), + scripted_backup, + checkpoint_interval: self.db_config.checkpoint_interval, + }; + + let primary_configurator = PrimaryConfigurator::new( + base_config.clone(), + primary_config.clone(), + make_wal_manager.clone(), + ); + + let schema_configurator = SchemaConfigurator::new( + base_config.clone(), + primary_config, + make_wal_manager.clone(), + migration_scheduler_handle, + ); + + 
configurators.with_schema(schema_configurator); + configurators.with_primary(primary_configurator); } } + + Ok(configurators) + } + + fn setup_shutdown(&self) -> Option { + let shutdown_notify = self.shutdown.clone(); + self.idle_shutdown_timeout.map(|d| { + IdleShutdownKicker::new(d, self.initial_idle_shutdown_timeout, shutdown_notify) + }) + } + + // fn configure_wal_manager( + // &self, + // join_set: &mut JoinSet>, + // ) -> anyhow::Result<( + // Arc InnerWalManager + Sync + Send + 'static>, + // Pin> + Send + Sync + 'static>>, + // )> { + // let wal_path = self.path.join("wals"); + // let enable_libsql_wal_test = { + // let is_primary = self.rpc_server_config.is_some(); + // let is_libsql_wal_test = std::env::var("LIBSQL_WAL_TEST").is_ok(); + // is_primary && is_libsql_wal_test + // }; + // let use_libsql_wal = + // self.use_custom_wal == Some(CustomWAL::LibsqlWal) || enable_libsql_wal_test; + // if !use_libsql_wal { + // if wal_path.try_exists()? { + // anyhow::bail!("database was previously setup to use libsql-wal"); + // } + // } + // + // if self.use_custom_wal.is_some() { + // if self.db_config.bottomless_replication.is_some() { + // anyhow::bail!("bottomless not supported with custom WAL"); + // } + // if self.rpc_client_config.is_some() { + // anyhow::bail!("custom WAL not supported in replica mode"); + // } + // } + // + // let namespace_resolver = |path: &Path| { + // NamespaceName::from_string( + // path.parent() + // .unwrap() + // .file_name() + // .unwrap() + // .to_str() + // .unwrap() + // .to_string(), + // ) + // .unwrap() + // .into() + // }; + // + // match self.use_custom_wal { + // Some(CustomWAL::LibsqlWal) => { + // let (sender, receiver) = tokio::sync::mpsc::channel(64); + // let registry = Arc::new(WalRegistry::new(wal_path, NoStorage, sender)?); + // let checkpointer = LibsqlCheckpointer::new(registry.clone(), receiver, 8); + // join_set.spawn(async move { + // checkpointer.run().await; + // Ok(()) + // }); + // + // let wal = 
LibsqlWalManager::new(registry.clone(), Arc::new(namespace_resolver)); + // let shutdown_notify = self.shutdown.clone(); + // let shutdown_fut = Box::pin(async move { + // shutdown_notify.notified().await; + // registry.shutdown().await?; + // Ok(()) + // }); + // + // tracing::info!("using libsql wal"); + // Ok((Arc::new(move || EitherWAL::B(wal.clone())), shutdown_fut)) + // } + // #[cfg(feature = "durable-wal")] + // Some(CustomWAL::DurableWal) => { + // tracing::info!("using durable wal"); + // let lock_manager = Arc::new(std::sync::Mutex::new(LockManager::new())); + // let wal = DurableWalManager::new( + // lock_manager, + // namespace_resolver, + // self.storage_server_address.clone(), + // ); + // Ok(( + // Arc::new(move || EitherWAL::C(wal.clone())), + // Box::pin(ready(Ok(()))), + // )) + // } + // None => { + // tracing::info!("using sqlite3 wal"); + // Ok(( + // Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())), + // Box::pin(ready(Ok(()))), + // )) + // } + // } + // } + + async fn get_client_config(&self) -> anyhow::Result> { + match self.rpc_client_config { + Some(ref config) => Ok(Some(config.configure().await?)), + None => Ok(None), + } } } diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index a78e4f59b0..b2b5d33032 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -327,6 +327,7 @@ impl NamespaceStore { where Fun: FnOnce(&Namespace) -> R, { + dbg!(); if namespace != NamespaceName::default() && !self.inner.metadata.exists(&namespace) && !self.inner.allow_lazy_creation @@ -334,6 +335,7 @@ impl NamespaceStore { return Err(Error::NamespaceDoesntExist(namespace.to_string())); } + dbg!(); let f = { let name = namespace.clone(); move |ns: NamespaceEntry| async move { @@ -346,7 +348,9 @@ impl NamespaceStore { } }; + dbg!(); let handle = self.inner.metadata.handle(namespace.to_owned()); + dbg!(); f(self .load_namespace(&namespace, handle, RestoreOption::Latest) 
.await?) @@ -373,6 +377,7 @@ impl NamespaceStore { config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { + dbg!(); let ns = self .get_configurator(&config.get()) .setup( @@ -386,6 +391,7 @@ impl NamespaceStore { ) .await?; + dbg!(); Ok(ns) } @@ -395,13 +401,17 @@ impl NamespaceStore { db_config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { + dbg!(); let init = async { + dbg!(); let ns = self .make_namespace(namespace, db_config, restore_option) .await?; + dbg!(); Ok(Some(ns)) }; + dbg!(); let before_load = Instant::now(); let ns = self .inner @@ -410,7 +420,8 @@ impl NamespaceStore { namespace.clone(), init.map_ok(|ns| Arc::new(RwLock::new(ns))), ) - .await?; + .await.map_err(|e| dbg!(e))?; + dbg!(); NAMESPACE_LOAD_LATENCY.record(before_load.elapsed()); Ok(ns) diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index 1171d4a5d0..8f214bd05e 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ -149,23 +149,29 @@ fn sync_many_replica() { let mut sim = Builder::new() .simulation_duration(Duration::from_secs(1000)) .build(); + dbg!(); make_cluster(&mut sim, NUM_REPLICA, true); + dbg!(); sim.client("client", async { let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; + dbg!(); conn.execute("create table test (x)", ()).await?; + dbg!(); conn.execute("insert into test values (42)", ()).await?; + dbg!(); async fn get_frame_no(url: &str) -> Option { let client = Client::new(); + dbg!(); Some( - client - .get(url) - .await - .unwrap() - .json::() - .await + dbg!(client + .get(url) + .await + .unwrap() + .json::() + .await) .unwrap() .get("replication_index")? 
.as_u64() @@ -173,6 +179,7 @@ fn sync_many_replica() { ) } + dbg!(); let primary_fno = loop { if let Some(fno) = get_frame_no("http://primary:9090/v1/namespaces/default/stats").await { @@ -180,13 +187,15 @@ fn sync_many_replica() { } }; + dbg!(); // wait for all replicas to sync let mut join_set = JoinSet::new(); for i in 0..NUM_REPLICA { join_set.spawn(async move { let uri = format!("http://replica{i}:9090/v1/namespaces/default/stats"); + dbg!(); loop { - if let Some(replica_fno) = get_frame_no(&uri).await { + if let Some(replica_fno) = dbg!(get_frame_no(&uri).await) { if replica_fno == primary_fno { break; } @@ -196,8 +205,10 @@ fn sync_many_replica() { }); } + dbg!(); while join_set.join_next().await.is_some() {} + dbg!(); for i in 0..NUM_REPLICA { let db = Database::open_remote_with_connector( format!("http://replica{i}:8080"), @@ -212,8 +223,10 @@ fn sync_many_replica() { )); } + dbg!(); let client = Client::new(); + dbg!(); let stats = client .get("http://primary:9090/v1/namespaces/default/stats") .await? 
@@ -221,12 +234,14 @@ fn sync_many_replica() { .await .unwrap(); + dbg!(); let stat = stats .get("embedded_replica_frames_replicated") .unwrap() .as_u64() .unwrap(); + dbg!(); assert_eq!(stat, 0); Ok(()) From 76558e8df4d1927e7bd9fee7da4821f1a07a87f6 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 13:33:23 +0400 Subject: [PATCH 010/121] fix behaviour of VACUUM for vector indices to make rowid consistent between shadow tables and base table --- libsql-sqlite3/src/vacuum.c | 26 ++++++++++++++++++ libsql-sqlite3/src/vectorIndex.c | 28 +------------------- libsql-sqlite3/test/libsql_vector_index.test | 9 +++++-- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/libsql-sqlite3/src/vacuum.c b/libsql-sqlite3/src/vacuum.c index c0ae4bc1e1..d927a8d5a6 100644 --- a/libsql-sqlite3/src/vacuum.c +++ b/libsql-sqlite3/src/vacuum.c @@ -17,6 +17,10 @@ #include "sqliteInt.h" #include "vdbeInt.h" +#ifndef SQLITE_OMIT_VECTOR +#include "vectorIndexInt.h" +#endif + #if !defined(SQLITE_OMIT_VACUUM) && !defined(SQLITE_OMIT_ATTACH) /* @@ -294,6 +298,27 @@ SQLITE_NOINLINE int sqlite3RunVacuum( if( rc!=SQLITE_OK ) goto end_of_vacuum; db->init.iDb = 0; +#ifndef SQLITE_OMIT_VECTOR + // shadow tables for vector index will be populated automatically during CREATE INDEX command + // so we must skip them at this step + if( sqlite3FindTable(db, VECTOR_INDEX_GLOBAL_META_TABLE, zDbMain) != NULL ){ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name NOT IN (SELECT name||'_shadow' FROM " VECTOR_INDEX_GLOBAL_META_TABLE ")", + zDbMain + ); + }else{ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + zDbMain + ); + } +#else /* Loop through the tables in the 
main database. For each, do ** an "INSERT INTO vacuum_db.xxx SELECT * FROM main.xxx;" to copy ** the contents to the temporary database. @@ -305,6 +330,7 @@ SQLITE_NOINLINE int sqlite3RunVacuum( "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); +#endif assert( (db->mDbFlags & DBFLAG_Vacuum)!=0 ); db->mDbFlags &= ~DBFLAG_Vacuum; if( rc!=SQLITE_OK ) goto end_of_vacuum; diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 78266ed462..d520419a41 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -49,11 +49,6 @@ ** VectorIdxParams utilities ****************************************************************************/ -// VACUUM creates tables and indices first and only then populate data -// we need to ignore inserts from 'INSERT INTO vacuum.t SELECT * FROM t' statements because -// all shadow tables will be populated by VACUUM process during regular process of table copy -#define IsVacuum(db) ((db->mDbFlags&DBFLAG_Vacuum)!=0) - void vectorIdxParamsInit(VectorIdxParams *pParams, u8 *pBinBuf, int nBinSize) { assert( nBinSize <= VECTOR_INDEX_PARAMS_BUF_SIZE ); @@ -772,10 +767,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { // this is done to prevent unrecoverable situations where index were dropped but index parameters deletion failed and second attempt will fail on first step int rcIdx, rcParams; - if( IsVacuum(db) ){ - return SQLITE_OK; - } - assert( zDbSName != NULL ); rcIdx = diskAnnDropIndex(db, zDbSName, zIdxName); @@ -786,10 +777,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { assert( zDbSName != NULL ); - if( IsVacuum(db) ){ - return SQLITE_OK; - } - return diskAnnClearIndex(db, zDbSName, zIdxName); } @@ -799,7 +786,7 @@ int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { * this made intentionally in order to 
natively support upload of SQLite dumps * * dump populates tables first and create indices after - * so we must omit them because shadow tables already filled + * so we must omit index refill setp because shadow tables already filled * * 1. in case of any error :-1 returned (and pParse errMsg is populated with some error message) * 2. if vector index must not be created : 0 returned @@ -817,10 +804,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int hasLibsqlVectorIdxFn = 0, hasCollation = 0; const char *pzErrMsg; - if( IsVacuum(pParse->db) ){ - return CREATE_IGNORE; - } - assert( zDbSName != NULL ); sqlite3 *db = pParse->db; @@ -970,7 +953,6 @@ int vectorIndexSearch( VectorIdxParams idxParams; vectorIdxParamsInit(&idxParams, NULL, 0); - assert( !IsVacuum(db) ); assert( zDbSName != NULL ); if( argc != 3 ){ @@ -1055,10 +1037,6 @@ int vectorIndexInsert( int rc; VectorInRow vectorInRow; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - rc = vectorInRowAlloc(pCur->db, pRecord, &vectorInRow, pzErrMsg); if( rc != SQLITE_OK ){ return rc; @@ -1078,10 +1056,6 @@ int vectorIndexDelete( ){ VectorInRow payload; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - payload.pVector = NULL; payload.nKeys = r->nField - 1; payload.pKeyValues = r->aMem + 1; diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 19d31ba19c..7308b2d93f 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -236,12 +236,17 @@ do_execsql_test vector-attach { do_execsql_test vector-vacuum { CREATE TABLE t_vacuum ( emb FLOAT32(2) ); - INSERT INTO t_vacuum VALUES (vector('[1,2]')), (vector('[3,4]')); + INSERT INTO t_vacuum VALUES (vector('[1,2]')), (vector('[3,4]')), (vector('[5,6]')); CREATE INDEX t_vacuum_idx ON t_vacuum(libsql_vector_idx(emb)); VACUUM; SELECT COUNT(*) FROM t_vacuum; SELECT COUNT(*) FROM t_vacuum_idx_shadow; -} {2 2} + DELETE FROM 
t_vacuum WHERE rowid = 2; + VACUUM; + SELECT * FROM vector_top_k('t_vacuum_idx', vector('[1,2]'), 3); + SELECT * FROM vector_top_k('t_vacuum_idx', vector('[5,6]'), 3); + SELECT * FROM vector_top_k('t_vacuum_idx', vector('[3,4]'), 3); +} {3 3 1 2 2 1 2 1} do_execsql_test vector-many-columns { CREATE TABLE t_many ( i INTEGER PRIMARY KEY, e1 FLOAT32(2), e2 FLOAT32(2) ); From 853143d77b94b26653a87d45134e2b77476096b3 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 13:48:21 +0400 Subject: [PATCH 011/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 55 ++++++++++--------- libsql-ffi/bundled/bindings/bindgen.rs | 26 ++++++++- libsql-ffi/bundled/src/sqlite3.c | 55 ++++++++++--------- 3 files changed, 79 insertions(+), 57 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index c22f35046f..ec692baa53 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -69,6 +69,7 @@ ** src/test2.c ** src/test3.c ** src/test8.c +** src/vacuum.c ** src/vdbe.c ** src/vdbeInt.h ** src/vdbeapi.c @@ -155950,6 +155951,10 @@ SQLITE_PRIVATE void sqlite3UpsertDoUpdate( /* #include "sqliteInt.h" */ /* #include "vdbeInt.h" */ +#ifndef SQLITE_OMIT_VECTOR +/* #include "vectorIndexInt.h" */ +#endif + #if !defined(SQLITE_OMIT_VACUUM) && !defined(SQLITE_OMIT_ATTACH) /* @@ -156227,6 +156232,27 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( if( rc!=SQLITE_OK ) goto end_of_vacuum; db->init.iDb = 0; +#ifndef SQLITE_OMIT_VECTOR + // shadow tables for vector index will be populated automatically during CREATE INDEX command + // so we must skip them at this step + if( sqlite3FindTable(db, VECTOR_INDEX_GLOBAL_META_TABLE, zDbMain) != NULL ){ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE 
type='table'AND coalesce(rootpage,1)>0 AND name NOT IN (SELECT name||'_shadow' FROM " VECTOR_INDEX_GLOBAL_META_TABLE ")", + zDbMain + ); + }else{ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + zDbMain + ); + } +#else /* Loop through the tables in the main database. For each, do ** an "INSERT INTO vacuum_db.xxx SELECT * FROM main.xxx;" to copy ** the contents to the temporary database. @@ -156238,6 +156264,7 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); +#endif assert( (db->mDbFlags & DBFLAG_Vacuum)!=0 ); db->mDbFlags &= ~DBFLAG_Vacuum; if( rc!=SQLITE_OK ) goto end_of_vacuum; @@ -213656,11 +213683,6 @@ int vectorF64ParseSqliteBlob( ** VectorIdxParams utilities ****************************************************************************/ -// VACUUM creates tables and indices first and only then populate data -// we need to ignore inserts from 'INSERT INTO vacuum.t SELECT * FROM t' statements because -// all shadow tables will be populated by VACUUM process during regular process of table copy -#define IsVacuum(db) ((db->mDbFlags&DBFLAG_Vacuum)!=0) - void vectorIdxParamsInit(VectorIdxParams *pParams, u8 *pBinBuf, int nBinSize) { assert( nBinSize <= VECTOR_INDEX_PARAMS_BUF_SIZE ); @@ -214379,10 +214401,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { // this is done to prevent unrecoverable situations where index were dropped but index parameters deletion failed and second attempt will fail on first step int rcIdx, rcParams; - if( IsVacuum(db) ){ - return SQLITE_OK; - } - assert( zDbSName != NULL ); rcIdx = diskAnnDropIndex(db, zDbSName, zIdxName); @@ -214393,10 +214411,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { int vectorIndexClear(sqlite3 *db, const char 
*zDbSName, const char *zIdxName) { assert( zDbSName != NULL ); - if( IsVacuum(db) ){ - return SQLITE_OK; - } - return diskAnnClearIndex(db, zDbSName, zIdxName); } @@ -214406,7 +214420,7 @@ int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { * this made intentionally in order to natively support upload of SQLite dumps * * dump populates tables first and create indices after - * so we must omit them because shadow tables already filled + * so we must omit index refill setp because shadow tables already filled * * 1. in case of any error :-1 returned (and pParse errMsg is populated with some error message) * 2. if vector index must not be created : 0 returned @@ -214424,10 +214438,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int hasLibsqlVectorIdxFn = 0, hasCollation = 0; const char *pzErrMsg; - if( IsVacuum(pParse->db) ){ - return CREATE_IGNORE; - } - assert( zDbSName != NULL ); sqlite3 *db = pParse->db; @@ -214577,7 +214587,6 @@ int vectorIndexSearch( VectorIdxParams idxParams; vectorIdxParamsInit(&idxParams, NULL, 0); - assert( !IsVacuum(db) ); assert( zDbSName != NULL ); if( argc != 3 ){ @@ -214662,10 +214671,6 @@ int vectorIndexInsert( int rc; VectorInRow vectorInRow; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - rc = vectorInRowAlloc(pCur->db, pRecord, &vectorInRow, pzErrMsg); if( rc != SQLITE_OK ){ return rc; @@ -214685,10 +214690,6 @@ int vectorIndexDelete( ){ VectorInRow payload; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - payload.pVector = NULL; payload.nKeys = r->nField - 1; payload.pKeyValues = r->aMem + 1; diff --git a/libsql-ffi/bundled/bindings/bindgen.rs b/libsql-ffi/bundled/bindings/bindgen.rs index e11d453281..cc73807f33 100644 --- a/libsql-ffi/bundled/bindings/bindgen.rs +++ b/libsql-ffi/bundled/bindings/bindgen.rs @@ -24,10 +24,10 @@ extern "C" { } pub const __GNUC_VA_LIST: i32 = 1; -pub const SQLITE_VERSION: &[u8; 7] = b"3.44.0\0"; -pub const 
SQLITE_VERSION_NUMBER: i32 = 3044000; +pub const SQLITE_VERSION: &[u8; 7] = b"3.45.1\0"; +pub const SQLITE_VERSION_NUMBER: i32 = 3045001; pub const SQLITE_SOURCE_ID: &[u8; 85] = - b"2023-11-01 11:23:50 17129ba1ff7f0daf37100ee82d507aef7827cf38de1866e2633096ae6ad8alt1\0"; + b"2024-01-30 16:01:20 e876e51a0ed5c5b3126f52e532044363a014bc594cfefa87ffb5b82257ccalt1\0"; pub const LIBSQL_VERSION: &[u8; 6] = b"0.2.3\0"; pub const SQLITE_OK: i32 = 0; pub const SQLITE_ERROR: i32 = 1; @@ -356,6 +356,7 @@ pub const SQLITE_DETERMINISTIC: i32 = 2048; pub const SQLITE_DIRECTONLY: i32 = 524288; pub const SQLITE_SUBTYPE: i32 = 1048576; pub const SQLITE_INNOCUOUS: i32 = 2097152; +pub const SQLITE_RESULT_SUBTYPE: i32 = 16777216; pub const SQLITE_WIN32_DATA_DIRECTORY_TYPE: i32 = 1; pub const SQLITE_WIN32_TEMP_DIRECTORY_TYPE: i32 = 2; pub const SQLITE_TXN_NONE: i32 = 0; @@ -408,6 +409,7 @@ pub const SQLITE_TESTCTRL_PENDING_BYTE: i32 = 11; pub const SQLITE_TESTCTRL_ASSERT: i32 = 12; pub const SQLITE_TESTCTRL_ALWAYS: i32 = 13; pub const SQLITE_TESTCTRL_RESERVE: i32 = 14; +pub const SQLITE_TESTCTRL_JSON_SELFCHECK: i32 = 14; pub const SQLITE_TESTCTRL_OPTIMIZATIONS: i32 = 15; pub const SQLITE_TESTCTRL_ISKEYWORD: i32 = 16; pub const SQLITE_TESTCTRL_SCRATCHMALLOC: i32 = 17; @@ -3133,6 +3135,24 @@ pub struct Fts5ExtensionApi { piCol: *mut ::std::os::raw::c_int, ), >, + pub xQueryToken: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut Fts5Context, + iPhrase: ::std::os::raw::c_int, + iToken: ::std::os::raw::c_int, + ppToken: *mut *const ::std::os::raw::c_char, + pnToken: *mut ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int, + >, + pub xInstToken: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut Fts5Context, + iIdx: ::std::os::raw::c_int, + iToken: ::std::os::raw::c_int, + arg2: *mut *const ::std::os::raw::c_char, + arg3: *mut ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int, + >, } #[repr(C)] #[derive(Debug, Copy, Clone)] diff --git a/libsql-ffi/bundled/src/sqlite3.c 
b/libsql-ffi/bundled/src/sqlite3.c index c22f35046f..ec692baa53 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -69,6 +69,7 @@ ** src/test2.c ** src/test3.c ** src/test8.c +** src/vacuum.c ** src/vdbe.c ** src/vdbeInt.h ** src/vdbeapi.c @@ -155950,6 +155951,10 @@ SQLITE_PRIVATE void sqlite3UpsertDoUpdate( /* #include "sqliteInt.h" */ /* #include "vdbeInt.h" */ +#ifndef SQLITE_OMIT_VECTOR +/* #include "vectorIndexInt.h" */ +#endif + #if !defined(SQLITE_OMIT_VACUUM) && !defined(SQLITE_OMIT_ATTACH) /* @@ -156227,6 +156232,27 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( if( rc!=SQLITE_OK ) goto end_of_vacuum; db->init.iDb = 0; +#ifndef SQLITE_OMIT_VECTOR + // shadow tables for vector index will be populated automatically during CREATE INDEX command + // so we must skip them at this step + if( sqlite3FindTable(db, VECTOR_INDEX_GLOBAL_META_TABLE, zDbMain) != NULL ){ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name NOT IN (SELECT name||'_shadow' FROM " VECTOR_INDEX_GLOBAL_META_TABLE ")", + zDbMain + ); + }else{ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + zDbMain + ); + } +#else /* Loop through the tables in the main database. For each, do ** an "INSERT INTO vacuum_db.xxx SELECT * FROM main.xxx;" to copy ** the contents to the temporary database. 
@@ -156238,6 +156264,7 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); +#endif assert( (db->mDbFlags & DBFLAG_Vacuum)!=0 ); db->mDbFlags &= ~DBFLAG_Vacuum; if( rc!=SQLITE_OK ) goto end_of_vacuum; @@ -213656,11 +213683,6 @@ int vectorF64ParseSqliteBlob( ** VectorIdxParams utilities ****************************************************************************/ -// VACUUM creates tables and indices first and only then populate data -// we need to ignore inserts from 'INSERT INTO vacuum.t SELECT * FROM t' statements because -// all shadow tables will be populated by VACUUM process during regular process of table copy -#define IsVacuum(db) ((db->mDbFlags&DBFLAG_Vacuum)!=0) - void vectorIdxParamsInit(VectorIdxParams *pParams, u8 *pBinBuf, int nBinSize) { assert( nBinSize <= VECTOR_INDEX_PARAMS_BUF_SIZE ); @@ -214379,10 +214401,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { // this is done to prevent unrecoverable situations where index were dropped but index parameters deletion failed and second attempt will fail on first step int rcIdx, rcParams; - if( IsVacuum(db) ){ - return SQLITE_OK; - } - assert( zDbSName != NULL ); rcIdx = diskAnnDropIndex(db, zDbSName, zIdxName); @@ -214393,10 +214411,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { assert( zDbSName != NULL ); - if( IsVacuum(db) ){ - return SQLITE_OK; - } - return diskAnnClearIndex(db, zDbSName, zIdxName); } @@ -214406,7 +214420,7 @@ int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { * this made intentionally in order to natively support upload of SQLite dumps * * dump populates tables first and create indices after - * so we must omit them because shadow tables already filled + * so we must omit index refill setp because shadow tables already filled * * 1. 
in case of any error :-1 returned (and pParse errMsg is populated with some error message) * 2. if vector index must not be created : 0 returned @@ -214424,10 +214438,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int hasLibsqlVectorIdxFn = 0, hasCollation = 0; const char *pzErrMsg; - if( IsVacuum(pParse->db) ){ - return CREATE_IGNORE; - } - assert( zDbSName != NULL ); sqlite3 *db = pParse->db; @@ -214577,7 +214587,6 @@ int vectorIndexSearch( VectorIdxParams idxParams; vectorIdxParamsInit(&idxParams, NULL, 0); - assert( !IsVacuum(db) ); assert( zDbSName != NULL ); if( argc != 3 ){ @@ -214662,10 +214671,6 @@ int vectorIndexInsert( int rc; VectorInRow vectorInRow; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - rc = vectorInRowAlloc(pCur->db, pRecord, &vectorInRow, pzErrMsg); if( rc != SQLITE_OK ){ return rc; @@ -214685,10 +214690,6 @@ int vectorIndexDelete( ){ VectorInRow payload; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - payload.pVector = NULL; payload.nKeys = r->nField - 1; payload.pKeyValues = r->aMem + 1; From 2115277f8c6c8ab8494dc70fa1413d30b7f362f4 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 14:51:32 +0400 Subject: [PATCH 012/121] fix bug --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 2 +- libsql-ffi/bundled/src/sqlite3.c | 2 +- libsql-sqlite3/src/vacuum.c | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index ec692baa53..c25985af8a 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -156248,7 +156248,7 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( "SELECT'INSERT INTO vacuum_db.'||quote(name)" "||' SELECT*FROM\"%w\".'||quote(name)" "FROM vacuum_db.sqlite_schema " - "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + "WHERE 
type='table'AND coalesce(rootpage,1)>0", zDbMain ); } diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index ec692baa53..c25985af8a 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -156248,7 +156248,7 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( "SELECT'INSERT INTO vacuum_db.'||quote(name)" "||' SELECT*FROM\"%w\".'||quote(name)" "FROM vacuum_db.sqlite_schema " - "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); } diff --git a/libsql-sqlite3/src/vacuum.c b/libsql-sqlite3/src/vacuum.c index d927a8d5a6..f8e848aca6 100644 --- a/libsql-sqlite3/src/vacuum.c +++ b/libsql-sqlite3/src/vacuum.c @@ -314,7 +314,7 @@ SQLITE_NOINLINE int sqlite3RunVacuum( "SELECT'INSERT INTO vacuum_db.'||quote(name)" "||' SELECT*FROM\"%w\".'||quote(name)" "FROM vacuum_db.sqlite_schema " - "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); } From b12431c33a172ba0c96f35cfd35e36ce931b5c00 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 13:30:21 +0200 Subject: [PATCH 013/121] configure durable wal --- libsql-server/src/lib.rs | 98 ++++++++++++++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 4188365e03..9ee8e3b908 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -55,6 +55,7 @@ use url::Url; use utils::services::idle_shutdown::IdleShutdownKicker; use self::config::MetaStoreConfig; +use self::connection::connection_manager::InnerWalManager; use self::namespace::configurator::{ BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, ReplicaConfigurator, SchemaConfigurator, @@ -336,7 +337,8 @@ where config.heartbeat_url.as_deref().unwrap_or(""), config.heartbeat_period, ); - join_set.spawn({ + + 
self.spawn_until_shutdown_on(join_set, { let heartbeat_auth = config.heartbeat_auth.clone(); let heartbeat_period = config.heartbeat_period; let heartbeat_url = if let Some(url) = &config.heartbeat_url { @@ -428,7 +430,7 @@ where let (scripted_backup, script_backup_task) = ScriptBackupManager::new(&self.path, CommandHandler::new(command.to_string())) .await?; - self.spawn_until_shutdown(&mut join_set, script_backup_task.run()); + self.spawn_until_shutdown_on(&mut join_set, script_backup_task.run()); Some(scripted_backup) } None => None, @@ -484,7 +486,6 @@ where ) .await?; - self.spawn_monitoring_tasks(&mut join_set, stats_receiver)?; // if namespaces are enabled, then bottomless must have set DB ID @@ -501,7 +502,7 @@ where let proxy_service = ProxyService::new(namespace_store.clone(), None, self.disable_namespaces); // Garbage collect proxy clients every 30 seconds - self.spawn_until_shutdown(&mut join_set, { + self.spawn_until_shutdown_on(&mut join_set, { let clients = proxy_service.clients(); async move { loop { @@ -511,14 +512,17 @@ where } }); - self.spawn_until_shutdown(&mut join_set, run_rpc_server( - proxy_service, - config.acceptor, - config.tls_config, - idle_shutdown_kicker.clone(), - namespace_store.clone(), - self.disable_namespaces, - )); + self.spawn_until_shutdown_on( + &mut join_set, + run_rpc_server( + proxy_service, + config.acceptor, + config.tls_config, + idle_shutdown_kicker.clone(), + namespace_store.clone(), + self.disable_namespaces, + ), + ); } let shutdown_timeout = self.shutdown_timeout.clone(); @@ -530,7 +534,7 @@ where // The migration scheduler is only useful on the primary let meta_conn = metastore_conn_maker()?; let scheduler = Scheduler::new(namespace_store.clone(), meta_conn).await?; - self.spawn_until_shutdown(&mut join_set, async move { + self.spawn_until_shutdown_on(&mut join_set, async move { scheduler.run(scheduler_receiver).await; Ok(()) }); @@ -560,7 +564,7 @@ where ); // Garbage collect proxy clients every 30 seconds - 
self.spawn_until_shutdown(&mut join_set, { + self.spawn_until_shutdown_on(&mut join_set, { let clients = proxy_svc.clients(); async move { loop { @@ -583,8 +587,7 @@ where DatabaseKind::Replica => { dbg!(); let (channel, uri) = client_config.clone().unwrap(); - let replication_svc = - ReplicationLogProxyService::new(channel.clone(), uri.clone()); + let replication_svc = ReplicationLogProxyService::new(channel.clone(), uri.clone()); let proxy_svc = ReplicaProxyService::new( channel, uri, @@ -651,7 +654,12 @@ where match self.use_custom_wal { Some(CustomWAL::LibsqlWal) => self.libsql_wal_configurators(), #[cfg(feature = "durable-wal")] - Some(CustomWAL::DurableWal) => self.durable_wal_configurators(), + Some(CustomWAL::DurableWal) => self.durable_wal_configurators( + base_config, + scripted_backup, + migration_scheduler_handle, + client_config, + ), None => { self.legacy_configurators( base_config, @@ -669,11 +677,44 @@ where } #[cfg(feature = "durable-wal")] - fn durable_wal_configurators(&self) -> anyhow::Result { - todo!(); + fn durable_wal_configurators( + &self, + base_config: BaseNamespaceConfig, + scripted_backup: Option, + migration_scheduler_handle: SchedulerHandle, + client_config: Option<(Channel, Uri)>, + ) -> anyhow::Result { + tracing::info!("using durable wal"); + let lock_manager = Arc::new(std::sync::Mutex::new(LockManager::new())); + let namespace_resolver = |path: &Path| { + NamespaceName::from_string( + path.parent() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap() + .to_string(), + ) + .unwrap() + .into() + }; + let wal = DurableWalManager::new( + lock_manager, + namespace_resolver, + self.storage_server_address.clone(), + ); + let make_wal_manager = Arc::new(move || EitherWAL::C(wal.clone())); + self.configurators_common( + client_config, + base_config, + make_wal_manager, + scripted_backup, + migration_scheduler_handle, + ) } - fn spawn_until_shutdown(&self, join_set: &mut JoinSet>, fut: F) + fn spawn_until_shutdown_on(&self, 
join_set: &mut JoinSet>, fut: F) where F: Future> + Send + 'static, { @@ -694,6 +735,23 @@ where client_config: Option<(Channel, Uri)>, ) -> anyhow::Result { let make_wal_manager = Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())); + self.configurators_common( + client_config, + base_config, + make_wal_manager, + scripted_backup, + migration_scheduler_handle, + ) + } + + fn configurators_common( + &self, + client_config: Option<(Channel, Uri)>, + base_config: BaseNamespaceConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + scripted_backup: Option, + migration_scheduler_handle: SchedulerHandle, + ) -> anyhow::Result { let mut configurators = NamespaceConfigurators::empty(); match client_config { From 066f1527572d3e3012457e588e9d28e8e74656ec Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 14:37:10 +0200 Subject: [PATCH 014/121] configure libsql_wal --- libsql-server/src/lib.rs | 213 ++++++++++++++++++++++++++++++--------- 1 file changed, 165 insertions(+), 48 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 9ee8e3b908..f5788dcebb 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -28,6 +28,7 @@ use auth::Auth; use config::{ AdminApiConfig, DbConfig, HeartbeatConfig, RpcClientConfig, RpcServerConfig, UserApiConfig, }; +use futures::future::{pending, ready}; use futures::Future; use http::user::UserApi; use hyper::client::HttpConnector; @@ -40,6 +41,10 @@ use libsql_sys::wal::either::Either as EitherWAL; #[cfg(feature = "durable-wal")] use libsql_sys::wal::either::Either3 as EitherWAL; use libsql_sys::wal::Sqlite3WalManager; +use libsql_wal::checkpointer::LibsqlCheckpointer; +use libsql_wal::registry::WalRegistry; +use libsql_wal::storage::NoStorage; +use libsql_wal::wal::LibsqlWalManager; use namespace::meta_store::MetaStoreHandle; use namespace::NamespaceName; use net::Connector; @@ -458,9 +463,10 @@ where let configurators = self .make_configurators( base_config, - 
scripted_backup, - scheduler_sender.into(), client_config.clone(), + &mut join_set, + scheduler_sender.into(), + scripted_backup, ) .await?; @@ -596,7 +602,6 @@ where self.disable_namespaces, ); - dbg!(); self.make_services( namespace_store.clone(), idle_shutdown_kicker, @@ -647,42 +652,125 @@ where async fn make_configurators( &self, base_config: BaseNamespaceConfig, - scripted_backup: Option, - migration_scheduler_handle: SchedulerHandle, client_config: Option<(Channel, Uri)>, + join_set: &mut JoinSet>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, ) -> anyhow::Result { + let wal_path = base_config.base_path.join("wals"); + let enable_libsql_wal_test = { + let is_primary = self.rpc_server_config.is_some(); + let is_libsql_wal_test = std::env::var("LIBSQL_WAL_TEST").is_ok(); + is_primary && is_libsql_wal_test + }; + let use_libsql_wal = + self.use_custom_wal == Some(CustomWAL::LibsqlWal) || enable_libsql_wal_test; + if !use_libsql_wal { + if wal_path.try_exists()? 
{ + anyhow::bail!("database was previously setup to use libsql-wal"); + } + } + + if self.use_custom_wal.is_some() { + if self.db_config.bottomless_replication.is_some() { + anyhow::bail!("bottomless not supported with custom WAL"); + } + if self.rpc_client_config.is_some() { + anyhow::bail!("custom WAL not supported in replica mode"); + } + } + match self.use_custom_wal { - Some(CustomWAL::LibsqlWal) => self.libsql_wal_configurators(), + Some(CustomWAL::LibsqlWal) => self.libsql_wal_configurators( + base_config, + client_config, + join_set, + migration_scheduler_handle, + scripted_backup, + wal_path, + ), #[cfg(feature = "durable-wal")] Some(CustomWAL::DurableWal) => self.durable_wal_configurators( base_config, - scripted_backup, - migration_scheduler_handle, client_config, + migration_scheduler_handle, + scripted_backup, ), None => { self.legacy_configurators( base_config, - scripted_backup, - migration_scheduler_handle, client_config, + migration_scheduler_handle, + scripted_backup, ) .await } } } - fn libsql_wal_configurators(&self) -> anyhow::Result { - todo!() + fn libsql_wal_configurators( + &self, + base_config: BaseNamespaceConfig, + client_config: Option<(Channel, Uri)>, + join_set: &mut JoinSet>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, + wal_path: PathBuf, + ) -> anyhow::Result { + tracing::info!("using libsql wal"); + let (sender, receiver) = tokio::sync::mpsc::channel(64); + let registry = Arc::new(WalRegistry::new(wal_path, NoStorage, sender)?); + let checkpointer = LibsqlCheckpointer::new(registry.clone(), receiver, 8); + self.spawn_until_shutdown_on(join_set, async move { + checkpointer.run().await; + Ok(()) + }); + + let namespace_resolver = |path: &Path| { + NamespaceName::from_string( + path.parent() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap() + .to_string(), + ) + .unwrap() + .into() + }; + let wal = LibsqlWalManager::new(registry.clone(), Arc::new(namespace_resolver)); + + 
self.spawn_until_shutdown_with_teardown(join_set, pending(), async move { + registry.shutdown().await?; + Ok(()) + }); + + let make_wal_manager = Arc::new(move || EitherWAL::B(wal.clone())); + let mut configurators = NamespaceConfigurators::empty(); + + match client_config { + Some(_) => todo!("configure replica"), + // configure primary + None => self.configure_primary_common( + base_config, + &mut configurators, + make_wal_manager, + migration_scheduler_handle, + scripted_backup, + ), + } + + Ok(configurators) } #[cfg(feature = "durable-wal")] fn durable_wal_configurators( &self, base_config: BaseNamespaceConfig, - scripted_backup: Option, - migration_scheduler_handle: SchedulerHandle, client_config: Option<(Channel, Uri)>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, ) -> anyhow::Result { tracing::info!("using durable wal"); let lock_manager = Arc::new(std::sync::Mutex::new(LockManager::new())); @@ -706,22 +794,37 @@ where ); let make_wal_manager = Arc::new(move || EitherWAL::C(wal.clone())); self.configurators_common( - client_config, base_config, + client_config, make_wal_manager, - scripted_backup, migration_scheduler_handle, + scripted_backup, ) } fn spawn_until_shutdown_on(&self, join_set: &mut JoinSet>, fut: F) where F: Future> + Send + 'static, + { + self.spawn_until_shutdown_with_teardown(join_set, fut, ready(Ok(()))) + } + + /// run the passed future until shutdown is called, then call the passed teardown future + fn spawn_until_shutdown_with_teardown( + &self, + join_set: &mut JoinSet>, + fut: F, + teardown: T, + ) where + F: Future> + Send + 'static, + T: Future> + Send + 'static, { let shutdown = self.shutdown.clone(); join_set.spawn(async move { tokio::select! 
{ - _ = shutdown.notified() => Ok(()), + _ = shutdown.notified() => { + teardown.await + }, ret = fut => ret } }); @@ -730,30 +833,29 @@ where async fn legacy_configurators( &self, base_config: BaseNamespaceConfig, - scripted_backup: Option, - migration_scheduler_handle: SchedulerHandle, client_config: Option<(Channel, Uri)>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, ) -> anyhow::Result { let make_wal_manager = Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())); self.configurators_common( - client_config, base_config, + client_config, make_wal_manager, - scripted_backup, migration_scheduler_handle, + scripted_backup, ) } fn configurators_common( &self, - client_config: Option<(Channel, Uri)>, base_config: BaseNamespaceConfig, + client_config: Option<(Channel, Uri)>, make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, - scripted_backup: Option, migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, ) -> anyhow::Result { let mut configurators = NamespaceConfigurators::empty(); - match client_config { // replica mode Some((channel, uri)) => { @@ -762,34 +864,49 @@ where configurators.with_replica(replica_configurator); } // primary mode - None => { - let primary_config = PrimaryExtraConfig { - max_log_size: self.db_config.max_log_size, - max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), - bottomless_replication: self.db_config.bottomless_replication.clone(), - scripted_backup, - checkpoint_interval: self.db_config.checkpoint_interval, - }; + None => self.configure_primary_common( + base_config, + &mut configurators, + make_wal_manager, + migration_scheduler_handle, + scripted_backup, + ), + } - let primary_configurator = PrimaryConfigurator::new( - base_config.clone(), - primary_config.clone(), - make_wal_manager.clone(), - ); + Ok(configurators) + } - let schema_configurator = SchemaConfigurator::new( - base_config.clone(), - primary_config, - make_wal_manager.clone(), - 
migration_scheduler_handle, - ); + fn configure_primary_common( + &self, + base_config: BaseNamespaceConfig, + configurators: &mut NamespaceConfigurators, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, + ) { + let primary_config = PrimaryExtraConfig { + max_log_size: self.db_config.max_log_size, + max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), + bottomless_replication: self.db_config.bottomless_replication.clone(), + scripted_backup, + checkpoint_interval: self.db_config.checkpoint_interval, + }; - configurators.with_schema(schema_configurator); - configurators.with_primary(primary_configurator); - } - } + let primary_configurator = PrimaryConfigurator::new( + base_config.clone(), + primary_config.clone(), + make_wal_manager.clone(), + ); - Ok(configurators) + let schema_configurator = SchemaConfigurator::new( + base_config.clone(), + primary_config, + make_wal_manager.clone(), + migration_scheduler_handle, + ); + + configurators.with_schema(schema_configurator); + configurators.with_primary(primary_configurator); } fn setup_shutdown(&self) -> Option { From b5dba7241531d3c5bb58a093189b72b9ea47fb25 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 16:52:23 +0200 Subject: [PATCH 015/121] partial implmentation of LibsqlWalReplicationConfigurator --- .../configurator/libsql_wal_replica.rs | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 libsql-server/src/namespace/configurator/libsql_wal_replica.rs diff --git a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs new file mode 100644 index 0000000000..6b2519cf33 --- /dev/null +++ b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs @@ -0,0 +1,138 @@ +use std::pin::Pin; +use std::future::Future; +use std::sync::Arc; + +use chrono::prelude::NaiveDateTime; +use hyper::Uri; 
+use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; +use libsql_wal::io::StdIO; +use libsql_wal::registry::WalRegistry; +use libsql_wal::storage::NoStorage; +use tokio::task::JoinSet; +use tonic::transport::Channel; + +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; +use crate::connection::write_proxy::MakeWriteProxyConn; +use crate::connection::MakeConnection; +use crate::database::{Database, ReplicaDatabase}; +use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::configurator::helpers::make_stats; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{ + Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, + ResolveNamespacePathFn, RestoreOption, +}; +use crate::DEFAULT_AUTO_CHECKPOINT; + +use super::{BaseNamespaceConfig, ConfigureNamespace}; + +pub struct LibsqlWalReplicaConfigurator { + base: BaseNamespaceConfig, + registry: Arc>, + uri: Uri, + channel: Channel, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, +} + +impl ConfigureNamespace for LibsqlWalReplicaConfigurator { + fn setup<'a>( + &'a self, + db_config: MetaStoreHandle, + restore_option: RestoreOption, + name: &'a NamespaceName, + reset: ResetCb, + resolve_attach_path: ResolveNamespacePathFn, + store: NamespaceStore, + broadcaster: BroadcasterHandle, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + tracing::debug!("creating replica namespace"); + let db_path = self.base.base_path.join("dbs").join(name.as_str()); + let channel = self.channel.clone(); + let uri = self.uri.clone(); + + let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); + // TODO! 
setup replication + + let mut join_set = JoinSet::new(); + let namespace = name.clone(); + + let stats = make_stats( + &db_path, + &mut join_set, + db_config.clone(), + self.base.stats_sender.clone(), + name.clone(), + applied_frame_no_receiver.clone(), + ) + .await?; + + let connection_maker = MakeWriteProxyConn::new( + db_path.clone(), + self.base.extensions.clone(), + channel.clone(), + uri.clone(), + stats.clone(), + broadcaster, + db_config.clone(), + applied_frame_no_receiver, + self.base.max_response_size, + self.base.max_total_response_size, + primary_current_replication_index, + None, + resolve_attach_path, + self.make_wal_manager.clone(), + ) + .await? + .throttled( + self.base.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + self.base.max_total_response_size, + self.base.max_concurrent_requests, + ); + + Ok(Namespace { + tasks: join_set, + db: Database::Replica(ReplicaDatabase { + connection_maker: Arc::new(connection_maker), + }), + name: name.clone(), + stats, + db_config_store: db_config, + path: db_path.into(), + }) + }) + } + + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + _db_config: &DatabaseConfig, + _prune_all: bool, + _bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let ns_path = self.base.base_path.join("dbs").join(namespace.as_str()); + if ns_path.try_exists()? 
{ + tracing::debug!("removing database directory: {}", ns_path.display()); + tokio::fs::remove_dir_all(ns_path).await?; + } + Ok(()) + }) + } + + fn fork<'a>( + &'a self, + _from_ns: &'a Namespace, + _from_config: MetaStoreHandle, + _to_ns: NamespaceName, + _to_config: MetaStoreHandle, + _timestamp: Option, + _store: NamespaceStore, + ) -> Pin> + Send + 'a>> { + Box::pin(std::future::ready(Err(crate::Error::Fork( + super::fork::ForkError::ForkReplica, + )))) + } +} From ded5ba7f859f6f6b94f1a1b6614e31521a591f01 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 16:54:14 +0200 Subject: [PATCH 016/121] fmt + remove dbgs --- libsql-server/src/http/admin/stats.rs | 2 - libsql-server/src/lib.rs | 2 - .../src/namespace/configurator/fork.rs | 7 ++- .../src/namespace/configurator/helpers.rs | 60 +++++++++--------- .../configurator/libsql_wal_replica.rs | 18 +++--- .../src/namespace/configurator/mod.rs | 18 ++++-- .../src/namespace/configurator/primary.rs | 13 ++-- .../src/namespace/configurator/schema.rs | 29 ++++++--- libsql-server/src/namespace/mod.rs | 2 +- libsql-server/src/namespace/store.rs | 13 +--- libsql-server/src/schema/scheduler.rs | 63 ++++++------------- libsql-server/tests/cluster/mod.rs | 29 +++------ 12 files changed, 111 insertions(+), 145 deletions(-) diff --git a/libsql-server/src/http/admin/stats.rs b/libsql-server/src/http/admin/stats.rs index 5fce92ba0a..f2948d4d7b 100644 --- a/libsql-server/src/http/admin/stats.rs +++ b/libsql-server/src/http/admin/stats.rs @@ -140,12 +140,10 @@ pub(super) async fn handle_stats( State(app_state): State>>, Path(namespace): Path, ) -> crate::Result> { - dbg!(); let stats = app_state .namespaces .stats(NamespaceName::from_string(namespace)?) 
.await?; - dbg!(); let resp: StatsResponse = stats.as_ref().into(); Ok(Json(resp)) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index f5788dcebb..d26921dd00 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -591,7 +591,6 @@ where .configure(&mut join_set); } DatabaseKind::Replica => { - dbg!(); let (channel, uri) = client_config.clone().unwrap(); let replication_svc = ReplicationLogProxyService::new(channel.clone(), uri.clone()); let proxy_svc = ReplicaProxyService::new( @@ -611,7 +610,6 @@ where service_shutdown.clone(), ) .configure(&mut join_set); - dbg!(); } }; diff --git a/libsql-server/src/namespace/configurator/fork.rs b/libsql-server/src/namespace/configurator/fork.rs index 26a0b99b61..03f2ac03d8 100644 --- a/libsql-server/src/namespace/configurator/fork.rs +++ b/libsql-server/src/namespace/configurator/fork.rs @@ -58,7 +58,7 @@ pub(super) async fn fork( Database::Schema(db) => db.wal_wrapper.wrapper().logger(), _ => { return Err(crate::Error::Fork(ForkError::Internal(anyhow::Error::msg( - "Invalid source database type for fork", + "Invalid source database type for fork", )))); } }; @@ -114,7 +114,7 @@ pub struct ForkTask { pub to_namespace: NamespaceName, pub to_config: MetaStoreHandle, pub restore_to: Option, - pub store: NamespaceStore + pub store: NamespaceStore, } pub struct PointInTimeRestore { @@ -156,7 +156,8 @@ impl ForkTask { let dest_path = self.base_path.join("dbs").join(self.to_namespace.as_str()); tokio::fs::rename(temp_dir.path(), dest_path).await?; - self.store.make_namespace(&self.to_namespace, self.to_config, RestoreOption::Latest) + self.store + .make_namespace(&self.to_namespace, self.to_config, RestoreOption::Latest) .await .map_err(|e| ForkError::CreateNamespace(Box::new(e))) } diff --git a/libsql-server/src/namespace/configurator/helpers.rs b/libsql-server/src/namespace/configurator/helpers.rs index f43fa8a192..a5a4c5121d 100644 --- a/libsql-server/src/namespace/configurator/helpers.rs +++ 
b/libsql-server/src/namespace/configurator/helpers.rs @@ -6,26 +6,29 @@ use std::time::Duration; use anyhow::Context as _; use bottomless::replicator::Options; use bytes::Bytes; +use enclose::enclose; use futures::Stream; use libsql_sys::wal::Sqlite3WalManager; use tokio::io::AsyncBufReadExt as _; use tokio::sync::watch; use tokio::task::JoinSet; use tokio_util::io::StreamReader; -use enclose::enclose; use crate::connection::config::DatabaseConfig; use crate::connection::connection_manager::InnerWalManager; use crate::connection::libsql::{open_conn, MakeLibSqlConn}; use crate::connection::{Connection as _, MakeConnection as _}; +use crate::database::{PrimaryConnection, PrimaryConnectionMaker}; use crate::error::LoadDumpError; +use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; +use crate::namespace::{ + NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName, ResolveNamespacePathFn, + RestoreOption, +}; use crate::replication::{FrameNo, ReplicationLogger}; use crate::stats::Stats; -use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName, ResolveNamespacePathFn, RestoreOption}; -use crate::namespace::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; -use crate::namespace::meta_store::MetaStoreHandle; -use crate::namespace::broadcasters::BroadcasterHandle; -use crate::database::{PrimaryConnection, PrimaryConnectionMaker}; use crate::{StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT}; use super::{BaseNamespaceConfig, PrimaryExtraConfig}; @@ -74,8 +77,7 @@ pub(super) async fn make_primary_connection_maker( tracing::debug!("Checkpointed before initializing bottomless"); let options = make_bottomless_options(options, bottomless_db_id, name.clone()); let (replicator, did_recover) = - init_bottomless_replicator(db_path.join("data"), 
options, &restore_option) - .await?; + init_bottomless_replicator(db_path.join("data"), options, &restore_option).await?; tracing::debug!("Completed init of bottomless replicator"); is_dirty |= did_recover; Some(replicator) @@ -93,14 +95,14 @@ pub(super) async fn make_primary_connection_maker( }; let logger = Arc::new(ReplicationLogger::open( - &db_path, - primary_config.max_log_size, - primary_config.max_log_duration, - is_dirty, - auto_checkpoint, - primary_config.scripted_backup.clone(), - name.clone(), - None, + &db_path, + primary_config.max_log_size, + primary_config.max_log_duration, + is_dirty, + auto_checkpoint, + primary_config.scripted_backup.clone(), + name.clone(), + None, )?); tracing::debug!("sending stats"); @@ -113,7 +115,7 @@ pub(super) async fn make_primary_connection_maker( name.clone(), logger.new_frame_notifier.subscribe(), ) - .await?; + .await?; tracing::debug!("Making replication wal wrapper"); let wal_wrapper = make_replication_wal_wrapper(bottomless_replicator, logger.clone()); @@ -136,13 +138,13 @@ pub(super) async fn make_primary_connection_maker( resolve_attach_path, make_wal_manager.clone(), ) - .await? - .throttled( - base_config.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - base_config.max_total_response_size, - base_config.max_concurrent_requests, - ); + .await? + .throttled( + base_config.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + base_config.max_total_response_size, + base_config.max_concurrent_requests, + ); tracing::debug!("Completed opening libsql connection"); @@ -356,10 +358,7 @@ pub(super) async fn make_stats( } }); - join_set.spawn(run_storage_monitor( - db_path.into(), - Arc::downgrade(&stats), - )); + join_set.spawn(run_storage_monitor(db_path.into(), Arc::downgrade(&stats))); tracing::debug!("done sending stats, and creating bg tasks"); @@ -369,10 +368,7 @@ pub(super) async fn make_stats( // Periodically check the storage used by the database and save it in the Stats structure. 
// TODO: Once we have a separate fiber that does WAL checkpoints, running this routine // right after checkpointing is exactly where it should be done. -async fn run_storage_monitor( - db_path: PathBuf, - stats: Weak, -) -> anyhow::Result<()> { +async fn run_storage_monitor(db_path: PathBuf, stats: Weak) -> anyhow::Result<()> { // on initialization, the database file doesn't exist yet, so we wait a bit for it to be // created tokio::time::sleep(Duration::from_secs(1)).await; diff --git a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs index 6b2519cf33..f26738ec2a 100644 --- a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs +++ b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs @@ -1,5 +1,5 @@ -use std::pin::Pin; use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use chrono::prelude::NaiveDateTime; @@ -66,7 +66,7 @@ impl ConfigureNamespace for LibsqlWalReplicaConfigurator { name.clone(), applied_frame_no_receiver.clone(), ) - .await?; + .await?; let connection_maker = MakeWriteProxyConn::new( db_path.clone(), @@ -84,13 +84,13 @@ impl ConfigureNamespace for LibsqlWalReplicaConfigurator { resolve_attach_path, self.make_wal_manager.clone(), ) - .await? - .throttled( - self.base.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - self.base.max_total_response_size, - self.base.max_concurrent_requests, - ); + .await? 
+ .throttled( + self.base.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + self.base.max_total_response_size, + self.base.max_concurrent_requests, + ); Ok(Namespace { tasks: join_set, diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index e5db335ff6..9122fc18de 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -13,13 +13,17 @@ use crate::StatsSender; use super::broadcasters::BroadcasterHandle; use super::meta_store::MetaStoreHandle; -use super::{Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; +use super::{ + Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, + ResolveNamespacePathFn, RestoreOption, +}; +pub mod fork; mod helpers; +mod libsql_wal_replica; mod primary; mod replica; mod schema; -pub mod fork; pub use primary::PrimaryConfigurator; pub use replica::ReplicaConfigurator; @@ -68,12 +72,18 @@ impl NamespaceConfigurators { } } - pub fn with_primary(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { + pub fn with_primary( + &mut self, + c: impl ConfigureNamespace + Send + Sync + 'static, + ) -> &mut Self { self.primary_configurator = Some(Box::new(c)); self } - pub fn with_replica(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { + pub fn with_replica( + &mut self, + c: impl ConfigureNamespace + Send + Sync + 'static, + ) -> &mut Self { self.replica_configurator = Some(Box::new(c)); self } diff --git a/libsql-server/src/namespace/configurator/primary.rs b/libsql-server/src/namespace/configurator/primary.rs index 4351f6a3ac..6c245a6e8f 100644 --- a/libsql-server/src/namespace/configurator/primary.rs +++ b/libsql-server/src/namespace/configurator/primary.rs @@ -12,8 +12,8 @@ use crate::namespace::broadcasters::BroadcasterHandle; use 
crate::namespace::configurator::helpers::make_primary_connection_maker; use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::{ - Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, - ResetCb, ResolveNamespacePathFn, RestoreOption, + Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, + ResolveNamespacePathFn, RestoreOption, }; use crate::run_periodic_checkpoint; use crate::schema::{has_pending_migration_task, setup_migration_table}; @@ -168,7 +168,8 @@ impl ConfigureNamespace for PrimaryConfigurator { db_config, prune_all, bottomless_db_id_init, - ).await + ) + .await }) } @@ -186,10 +187,10 @@ impl ConfigureNamespace for PrimaryConfigurator { from_config, to_ns, to_config, - timestamp, + timestamp, store, &self.primary_config, - self.base.base_path.clone())) + self.base.base_path.clone(), + )) } } - diff --git a/libsql-server/src/namespace/configurator/schema.rs b/libsql-server/src/namespace/configurator/schema.rs index e55c706fec..98e679513a 100644 --- a/libsql-server/src/namespace/configurator/schema.rs +++ b/libsql-server/src/namespace/configurator/schema.rs @@ -6,12 +6,11 @@ use tokio::task::JoinSet; use crate::connection::config::DatabaseConfig; use crate::connection::connection_manager::InnerWalManager; use crate::database::{Database, SchemaDatabase}; +use crate::namespace::broadcasters::BroadcasterHandle; use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::{ - Namespace, NamespaceName, NamespaceStore, - ResetCb, ResolveNamespacePathFn, RestoreOption, + Namespace, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption, }; -use crate::namespace::broadcasters::BroadcasterHandle; use crate::schema::SchedulerHandle; use super::helpers::{cleanup_primary, make_primary_connection_maker}; @@ -25,8 +24,18 @@ pub struct SchemaConfigurator { } impl SchemaConfigurator { - pub fn new(base: BaseNamespaceConfig, primary_config: PrimaryExtraConfig, 
make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, migration_scheduler: SchedulerHandle) -> Self { - Self { base, primary_config, make_wal_manager, migration_scheduler } + pub fn new( + base: BaseNamespaceConfig, + primary_config: PrimaryExtraConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + migration_scheduler: SchedulerHandle, + ) -> Self { + Self { + base, + primary_config, + make_wal_manager, + migration_scheduler, + } } } @@ -58,7 +67,7 @@ impl ConfigureNamespace for SchemaConfigurator { &mut join_set, resolve_attach_path, broadcaster, - self.make_wal_manager.clone() + self.make_wal_manager.clone(), ) .await?; @@ -94,7 +103,8 @@ impl ConfigureNamespace for SchemaConfigurator { db_config, prune_all, bottomless_db_id_init, - ).await + ) + .await }) } @@ -112,9 +122,10 @@ impl ConfigureNamespace for SchemaConfigurator { from_config, to_ns, to_config, - timestamp, + timestamp, store, &self.primary_config, - self.base.base_path.clone())) + self.base.base_path.clone(), + )) } } diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 7cfa6b351c..2a2e3eb211 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -19,12 +19,12 @@ pub use self::name::NamespaceName; pub use self::store::NamespaceStore; pub mod broadcasters; +pub(crate) mod configurator; pub mod meta_store; mod name; pub mod replication_wal; mod schema_lock; mod store; -pub(crate) mod configurator; pub type ResetCb = Box; pub type ResolveNamespacePathFn = diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index b2b5d33032..a78e4f59b0 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -327,7 +327,6 @@ impl NamespaceStore { where Fun: FnOnce(&Namespace) -> R, { - dbg!(); if namespace != NamespaceName::default() && !self.inner.metadata.exists(&namespace) && !self.inner.allow_lazy_creation @@ -335,7 +334,6 @@ impl 
NamespaceStore { return Err(Error::NamespaceDoesntExist(namespace.to_string())); } - dbg!(); let f = { let name = namespace.clone(); move |ns: NamespaceEntry| async move { @@ -348,9 +346,7 @@ impl NamespaceStore { } }; - dbg!(); let handle = self.inner.metadata.handle(namespace.to_owned()); - dbg!(); f(self .load_namespace(&namespace, handle, RestoreOption::Latest) .await?) @@ -377,7 +373,6 @@ impl NamespaceStore { config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { - dbg!(); let ns = self .get_configurator(&config.get()) .setup( @@ -391,7 +386,6 @@ impl NamespaceStore { ) .await?; - dbg!(); Ok(ns) } @@ -401,17 +395,13 @@ impl NamespaceStore { db_config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { - dbg!(); let init = async { - dbg!(); let ns = self .make_namespace(namespace, db_config, restore_option) .await?; - dbg!(); Ok(Some(ns)) }; - dbg!(); let before_load = Instant::now(); let ns = self .inner @@ -420,8 +410,7 @@ impl NamespaceStore { namespace.clone(), init.map_ok(|ns| Arc::new(RwLock::new(ns))), ) - .await.map_err(|e| dbg!(e))?; - dbg!(); + .await?; NAMESPACE_LOAD_LATENCY.record(before_load.elapsed()); Ok(ns) diff --git a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index a8195cbbd0..57916bb9a5 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -830,16 +830,10 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new( - false, - false, - 10, - meta_store, - config, - DatabaseKind::Primary - ) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -961,16 +955,10 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = 
make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new( - false, - false, - 10, - meta_store, - config, - DatabaseKind::Primary - ) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -1044,16 +1032,10 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new( - false, - false, - 10, - meta_store, - config, - DatabaseKind::Primary, - ) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); store .with("ns".into(), |ns| { @@ -1078,9 +1060,10 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -1151,16 +1134,10 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new( - false, - false, - 10, - meta_store, - config, - DatabaseKind::Primary - ) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); let scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index 8f214bd05e..1171d4a5d0 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ 
-149,29 +149,23 @@ fn sync_many_replica() { let mut sim = Builder::new() .simulation_duration(Duration::from_secs(1000)) .build(); - dbg!(); make_cluster(&mut sim, NUM_REPLICA, true); - dbg!(); sim.client("client", async { let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; - dbg!(); conn.execute("create table test (x)", ()).await?; - dbg!(); conn.execute("insert into test values (42)", ()).await?; - dbg!(); async fn get_frame_no(url: &str) -> Option { let client = Client::new(); - dbg!(); Some( - dbg!(client - .get(url) - .await - .unwrap() - .json::() - .await) + client + .get(url) + .await + .unwrap() + .json::() + .await .unwrap() .get("replication_index")? .as_u64() @@ -179,7 +173,6 @@ fn sync_many_replica() { ) } - dbg!(); let primary_fno = loop { if let Some(fno) = get_frame_no("http://primary:9090/v1/namespaces/default/stats").await { @@ -187,15 +180,13 @@ fn sync_many_replica() { } }; - dbg!(); // wait for all replicas to sync let mut join_set = JoinSet::new(); for i in 0..NUM_REPLICA { join_set.spawn(async move { let uri = format!("http://replica{i}:9090/v1/namespaces/default/stats"); - dbg!(); loop { - if let Some(replica_fno) = dbg!(get_frame_no(&uri).await) { + if let Some(replica_fno) = get_frame_no(&uri).await { if replica_fno == primary_fno { break; } @@ -205,10 +196,8 @@ fn sync_many_replica() { }); } - dbg!(); while join_set.join_next().await.is_some() {} - dbg!(); for i in 0..NUM_REPLICA { let db = Database::open_remote_with_connector( format!("http://replica{i}:8080"), @@ -223,10 +212,8 @@ fn sync_many_replica() { )); } - dbg!(); let client = Client::new(); - dbg!(); let stats = client .get("http://primary:9090/v1/namespaces/default/stats") .await? 
@@ -234,14 +221,12 @@ fn sync_many_replica() { .await .unwrap(); - dbg!(); let stat = stats .get("embedded_replica_frames_replicated") .unwrap() .as_u64() .unwrap(); - dbg!(); assert_eq!(stat, 0); Ok(()) From e5b8c31005069982de27ad0b58b109929f45bf4b Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 17:01:34 +0200 Subject: [PATCH 017/121] comment out libsql-wal replica configurator --- libsql-server/src/lib.rs | 34 +++--- .../configurator/libsql_wal_replica.rs | 115 +++++++++--------- .../src/namespace/configurator/mod.rs | 2 +- 3 files changed, 79 insertions(+), 72 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index d26921dd00..9bf0419932 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -745,21 +745,27 @@ where }); let make_wal_manager = Arc::new(move || EitherWAL::B(wal.clone())); - let mut configurators = NamespaceConfigurators::empty(); + // let mut configurators = NamespaceConfigurators::empty(); + + // match client_config { + // Some(_) => todo!("configure replica"), + // // configure primary + // None => self.configure_primary_common( + // base_config, + // &mut configurators, + // make_wal_manager, + // migration_scheduler_handle, + // scripted_backup, + // ), + // } - match client_config { - Some(_) => todo!("configure replica"), - // configure primary - None => self.configure_primary_common( - base_config, - &mut configurators, - make_wal_manager, - migration_scheduler_handle, - scripted_backup, - ), - } - - Ok(configurators) + self.configurators_common( + base_config, + client_config, + make_wal_manager, + migration_scheduler_handle, + scripted_backup, + ) } #[cfg(feature = "durable-wal")] diff --git a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs index f26738ec2a..6ab6cc52ef 100644 --- a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs +++ 
b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs @@ -46,63 +46,64 @@ impl ConfigureNamespace for LibsqlWalReplicaConfigurator { store: NamespaceStore, broadcaster: BroadcasterHandle, ) -> Pin> + Send + 'a>> { - Box::pin(async move { - tracing::debug!("creating replica namespace"); - let db_path = self.base.base_path.join("dbs").join(name.as_str()); - let channel = self.channel.clone(); - let uri = self.uri.clone(); - - let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); - // TODO! setup replication - - let mut join_set = JoinSet::new(); - let namespace = name.clone(); - - let stats = make_stats( - &db_path, - &mut join_set, - db_config.clone(), - self.base.stats_sender.clone(), - name.clone(), - applied_frame_no_receiver.clone(), - ) - .await?; - - let connection_maker = MakeWriteProxyConn::new( - db_path.clone(), - self.base.extensions.clone(), - channel.clone(), - uri.clone(), - stats.clone(), - broadcaster, - db_config.clone(), - applied_frame_no_receiver, - self.base.max_response_size, - self.base.max_total_response_size, - primary_current_replication_index, - None, - resolve_attach_path, - self.make_wal_manager.clone(), - ) - .await? - .throttled( - self.base.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - self.base.max_total_response_size, - self.base.max_concurrent_requests, - ); - - Ok(Namespace { - tasks: join_set, - db: Database::Replica(ReplicaDatabase { - connection_maker: Arc::new(connection_maker), - }), - name: name.clone(), - stats, - db_config_store: db_config, - path: db_path.into(), - }) - }) + todo!() + // Box::pin(async move { + // tracing::debug!("creating replica namespace"); + // let db_path = self.base.base_path.join("dbs").join(name.as_str()); + // let channel = self.channel.clone(); + // let uri = self.uri.clone(); + // + // let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); + // // TODO! 
setup replication + // + // let mut join_set = JoinSet::new(); + // let namespace = name.clone(); + // + // let stats = make_stats( + // &db_path, + // &mut join_set, + // db_config.clone(), + // self.base.stats_sender.clone(), + // name.clone(), + // applied_frame_no_receiver.clone(), + // ) + // .await?; + // + // let connection_maker = MakeWriteProxyConn::new( + // db_path.clone(), + // self.base.extensions.clone(), + // channel.clone(), + // uri.clone(), + // stats.clone(), + // broadcaster, + // db_config.clone(), + // applied_frame_no_receiver, + // self.base.max_response_size, + // self.base.max_total_response_size, + // primary_current_replication_index, + // None, + // resolve_attach_path, + // self.make_wal_manager.clone(), + // ) + // .await? + // .throttled( + // self.base.max_concurrent_connections.clone(), + // Some(DB_CREATE_TIMEOUT), + // self.base.max_total_response_size, + // self.base.max_concurrent_requests, + // ); + // + // Ok(Namespace { + // tasks: join_set, + // db: Database::Replica(ReplicaDatabase { + // connection_maker: Arc::new(connection_maker), + // }), + // name: name.clone(), + // stats, + // db_config_store: db_config, + // path: db_path.into(), + // }) + // }) } fn cleanup<'a>( diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index 9122fc18de..0f8dcbd481 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -20,7 +20,7 @@ use super::{ pub mod fork; mod helpers; -mod libsql_wal_replica; +// mod libsql_wal_replica; mod primary; mod replica; mod schema; From 6e7fb9f06a901fe20a340cf9e54b1523c44f8eb5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 18:24:20 +0200 Subject: [PATCH 018/121] restore encryption config we don't actually care, but let's do it for completeness --- libsql-server/src/lib.rs | 1 + .../src/namespace/configurator/helpers.rs | 23 +++++++++++++++---- 
.../src/namespace/configurator/mod.rs | 2 ++ .../src/namespace/configurator/primary.rs | 8 ++++++- .../src/namespace/configurator/replica.rs | 1 + .../src/namespace/configurator/schema.rs | 1 + libsql-server/src/schema/scheduler.rs | 1 + 7 files changed, 31 insertions(+), 6 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 9bf0419932..4b97b442f5 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -458,6 +458,7 @@ where max_total_response_size: self.db_config.max_total_response_size, max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), max_concurrent_requests: self.db_config.max_concurrent_requests, + encryption_config: self.db_config.encryption_config.clone(), }; let configurators = self diff --git a/libsql-server/src/namespace/configurator/helpers.rs b/libsql-server/src/namespace/configurator/helpers.rs index a5a4c5121d..355b1b1472 100644 --- a/libsql-server/src/namespace/configurator/helpers.rs +++ b/libsql-server/src/namespace/configurator/helpers.rs @@ -9,6 +9,7 @@ use bytes::Bytes; use enclose::enclose; use futures::Stream; use libsql_sys::wal::Sqlite3WalManager; +use libsql_sys::EncryptionConfig; use tokio::io::AsyncBufReadExt as _; use tokio::sync::watch; use tokio::task::JoinSet; @@ -49,6 +50,7 @@ pub(super) async fn make_primary_connection_maker( resolve_attach_path: ResolveNamespacePathFn, broadcaster: BroadcasterHandle, make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + encryption_config: Option, ) -> crate::Result<(PrimaryConnectionMaker, ReplicationWalWrapper, Arc)> { let db_config = meta_store_handle.get(); let bottomless_db_id = NamespaceBottomlessDbId::from_config(&db_config); @@ -102,7 +104,7 @@ pub(super) async fn make_primary_connection_maker( auto_checkpoint, primary_config.scripted_backup.clone(), name.clone(), - None, + encryption_config.clone(), )?); tracing::debug!("sending stats"); @@ -114,6 +116,7 @@ pub(super) async fn 
make_primary_connection_maker( base_config.stats_sender.clone(), name.clone(), logger.new_frame_notifier.subscribe(), + base_config.encryption_config.clone(), ) .await?; @@ -133,7 +136,7 @@ pub(super) async fn make_primary_connection_maker( base_config.max_total_response_size, auto_checkpoint, logger.new_frame_notifier.subscribe(), - None, + encryption_config, block_writes, resolve_attach_path, make_wal_manager.clone(), @@ -332,6 +335,7 @@ pub(super) async fn make_stats( stats_sender: StatsSender, name: NamespaceName, mut current_frame_no: watch::Receiver>, + encryption_config: Option, ) -> anyhow::Result> { tracing::debug!("creating stats type"); let stats = Stats::new(name.clone(), db_path, join_set).await?; @@ -358,7 +362,11 @@ pub(super) async fn make_stats( } }); - join_set.spawn(run_storage_monitor(db_path.into(), Arc::downgrade(&stats))); + join_set.spawn(run_storage_monitor( + db_path.into(), + Arc::downgrade(&stats), + encryption_config, + )); tracing::debug!("done sending stats, and creating bg tasks"); @@ -368,7 +376,11 @@ pub(super) async fn make_stats( // Periodically check the storage used by the database and save it in the Stats structure. // TODO: Once we have a separate fiber that does WAL checkpoints, running this routine // right after checkpointing is exactly where it should be done. 
-async fn run_storage_monitor(db_path: PathBuf, stats: Weak) -> anyhow::Result<()> { +async fn run_storage_monitor( + db_path: PathBuf, + stats: Weak, + encryption_config: Option, +) -> anyhow::Result<()> { // on initialization, the database file doesn't exist yet, so we wait a bit for it to be // created tokio::time::sleep(Duration::from_secs(1)).await; @@ -381,11 +393,12 @@ async fn run_storage_monitor(db_path: PathBuf, stats: Weak) -> anyhow::Re return Ok(()); }; + let encryption_config = encryption_config.clone(); let _ = tokio::task::spawn_blocking(move || { // because closing the last connection interferes with opening a new one, we lazily // initialize a connection here, and keep it alive for the entirety of the program. If we // fail to open it, we wait for `duration` and try again later. - match open_conn(&db_path, Sqlite3WalManager::new(), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), None) { + match open_conn(&db_path, Sqlite3WalManager::new(), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), encryption_config) { Ok(mut conn) => { if let Ok(tx) = conn.transaction() { let page_count = tx.query_row("pragma page_count;", [], |row| { row.get::(0) }); diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index 0f8dcbd481..b96d5a3824 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -5,6 +5,7 @@ use std::time::Duration; use chrono::NaiveDateTime; use futures::Future; +use libsql_sys::EncryptionConfig; use tokio::sync::Semaphore; use crate::connection::config::DatabaseConfig; @@ -38,6 +39,7 @@ pub struct BaseNamespaceConfig { pub(crate) max_total_response_size: u64, pub(crate) max_concurrent_connections: Arc, pub(crate) max_concurrent_requests: u64, + pub(crate) encryption_config: Option, } #[derive(Clone)] diff --git a/libsql-server/src/namespace/configurator/primary.rs b/libsql-server/src/namespace/configurator/primary.rs index 
6c245a6e8f..03cdd2fd7b 100644 --- a/libsql-server/src/namespace/configurator/primary.rs +++ b/libsql-server/src/namespace/configurator/primary.rs @@ -1,7 +1,10 @@ +use std::path::Path; +use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; -use std::{path::Path, pin::Pin, sync::Arc}; +use std::sync::Arc; use futures::prelude::Future; +use libsql_sys::EncryptionConfig; use tokio::task::JoinSet; use crate::connection::config::DatabaseConfig; @@ -49,6 +52,7 @@ impl PrimaryConfigurator { resolve_attach_path: ResolveNamespacePathFn, db_path: Arc, broadcaster: BroadcasterHandle, + encryption_config: Option, ) -> crate::Result { let mut join_set = JoinSet::new(); @@ -67,6 +71,7 @@ impl PrimaryConfigurator { resolve_attach_path, broadcaster, self.make_wal_manager.clone(), + encryption_config, ) .await?; let connection_maker = Arc::new(connection_maker); @@ -135,6 +140,7 @@ impl ConfigureNamespace for PrimaryConfigurator { resolve_attach_path, db_path.clone(), broadcaster, + self.base.encryption_config.clone(), ) .await { diff --git a/libsql-server/src/namespace/configurator/replica.rs b/libsql-server/src/namespace/configurator/replica.rs index 61dd48b0bf..84ebadb897 100644 --- a/libsql-server/src/namespace/configurator/replica.rs +++ b/libsql-server/src/namespace/configurator/replica.rs @@ -169,6 +169,7 @@ impl ConfigureNamespace for ReplicaConfigurator { self.base.stats_sender.clone(), name.clone(), applied_frame_no_receiver.clone(), + self.base.encryption_config.clone(), ) .await?; diff --git a/libsql-server/src/namespace/configurator/schema.rs b/libsql-server/src/namespace/configurator/schema.rs index 98e679513a..f95c8abf51 100644 --- a/libsql-server/src/namespace/configurator/schema.rs +++ b/libsql-server/src/namespace/configurator/schema.rs @@ -68,6 +68,7 @@ impl ConfigureNamespace for SchemaConfigurator { resolve_attach_path, broadcaster, self.make_wal_manager.clone(), + self.base.encryption_config.clone(), ) .await?; diff --git 
a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index 57916bb9a5..01a3d795d8 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -917,6 +917,7 @@ mod test { max_total_response_size: 100000000000, max_concurrent_connections: Arc::new(Semaphore::new(10)), max_concurrent_requests: 10000, + encryption_config: None, }; let primary_config = PrimaryExtraConfig { From 71c50e198fe0e61880cd2d5917af351c34a9f21e Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Tue, 6 Aug 2024 08:54:23 -0400 Subject: [PATCH 019/121] enable more windows CI --- .github/workflows/rust.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 26eaba46cf..1903c0baff 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -159,8 +159,8 @@ jobs: target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo- - - name: check libsql remote - run: cargo check -p libsql --no-default-features -F remote + - name: build libsql all features + run: cargo build -p libsql --all-features # test-rust-wasm: # runs-on: ubuntu-latest From 07dc9b5b6d36e6eeae188c9a330df34bb99337ce Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 7 Aug 2024 13:04:09 +0200 Subject: [PATCH 020/121] add LibsqlWalFooter --- libsql-wal/src/lib.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/libsql-wal/src/lib.rs b/libsql-wal/src/lib.rs index df104eda49..d46ade0010 100644 --- a/libsql-wal/src/lib.rs +++ b/libsql-wal/src/lib.rs @@ -15,6 +15,22 @@ const LIBSQL_MAGIC: u64 = u64::from_be_bytes(*b"LIBSQL\0\0"); const LIBSQL_PAGE_SIZE: u16 = 4096; const LIBSQL_WAL_VERSION: u16 = 1; +use zerocopy::byteorder::big_endian::{U64 as bu64, U16 as bu16}; +/// LibsqlFooter is located at the end of the libsql file. 
I contains libsql specific metadata, +/// while remaining fully compatible with sqlite (which just ignores that footer) +/// +/// The fields are in big endian to remain coherent with sqlite +#[derive(Copy, Clone, Debug, zerocopy::FromBytes, zerocopy::FromZeroes, zerocopy::AsBytes)] +#[repr(C)] +pub struct LibsqlFooter { + magic: bu64, + version: bu16, + /// Replication index checkpointed into this file. + /// only valid if there are no outstanding segments to checkpoint, since a checkpoint could be + /// partial. + replication_index: bu64, +} + #[cfg(any(debug_assertions, test))] pub mod test { use std::fs::OpenOptions; From 4069036362f41cc9015dab4a5698f6790e57e1f1 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 7 Aug 2024 13:48:40 +0200 Subject: [PATCH 021/121] cancel query when request is dropped --- libsql-server/src/connection/libsql.rs | 49 ++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index aadff6190b..1acbbf2588 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -391,14 +391,43 @@ where ctx: RequestContext, builder: B, ) -> Result<(B, Program)> { + struct Bomb { + canceled: Arc, + defused: bool, + } + + impl Drop for Bomb { + fn drop(&mut self) { + if !self.defused { + tracing::debug!("cancelling request"); + self.canceled.store(true, Ordering::Relaxed); + } + } + } + + let canceled = { + let cancelled = self.inner.lock().canceled.clone(); + cancelled.store(false, Ordering::Relaxed); + cancelled + }; + + let mut bomb = Bomb { + canceled, + defused: false, + }; + PROGRAM_EXEC_COUNT.increment(1); check_program_auth(&ctx, &pgm, &self.inner.lock().config_store.get())?; let conn = self.inner.clone(); - BLOCKING_RT + let ret = BLOCKING_RT .spawn_blocking(move || Connection::run(conn, pgm, builder)) .await - .unwrap() + .unwrap(); + + bomb.defused = true; + + ret } } @@ -413,6 +442,7 @@ 
pub(super) struct Connection { forced_rollback: bool, broadcaster: BroadcasterHandle, hooked: bool, + canceled: Arc, } fn update_stats( @@ -475,6 +505,19 @@ impl Connection { ); } + let canceled = Arc::new(AtomicBool::new(false)); + + conn.progress_handler(100, { + let canceled = canceled.clone(); + Some(move || { + let canceled = canceled.load(Ordering::Relaxed); + if canceled { + tracing::debug!("request canceled"); + } + canceled + }) + }); + let this = Self { conn, stats, @@ -486,6 +529,7 @@ impl Connection { forced_rollback: false, broadcaster, hooked: false, + canceled, }; for ext in extensions.iter() { @@ -795,6 +839,7 @@ mod test { forced_rollback: false, broadcaster: Default::default(), hooked: false, + canceled: Arc::new(false.into()), }; let conn = Arc::new(Mutex::new(conn)); From 5924766712c783136b634ff325238a0fb1858cae Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 7 Aug 2024 13:15:12 +0200 Subject: [PATCH 022/121] write footer on checkpoint --- libsql-wal/src/lib.rs | 8 ++++---- libsql-wal/src/segment/list.rs | 18 +++++++++++++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/libsql-wal/src/lib.rs b/libsql-wal/src/lib.rs index d46ade0010..1c0dc63566 100644 --- a/libsql-wal/src/lib.rs +++ b/libsql-wal/src/lib.rs @@ -15,7 +15,7 @@ const LIBSQL_MAGIC: u64 = u64::from_be_bytes(*b"LIBSQL\0\0"); const LIBSQL_PAGE_SIZE: u16 = 4096; const LIBSQL_WAL_VERSION: u16 = 1; -use zerocopy::byteorder::big_endian::{U64 as bu64, U16 as bu16}; +use zerocopy::byteorder::big_endian::{U16 as bu16, U64 as bu64}; /// LibsqlFooter is located at the end of the libsql file. 
I contains libsql specific metadata, /// while remaining fully compatible with sqlite (which just ignores that footer) /// @@ -23,12 +23,12 @@ use zerocopy::byteorder::big_endian::{U64 as bu64, U16 as bu16}; #[derive(Copy, Clone, Debug, zerocopy::FromBytes, zerocopy::FromZeroes, zerocopy::AsBytes)] #[repr(C)] pub struct LibsqlFooter { - magic: bu64, - version: bu16, + pub magic: bu64, + pub version: bu16, /// Replication index checkpointed into this file. /// only valid if there are no outstanding segments to checkpoint, since a checkpoint could be /// partial. - replication_index: bu64, + pub replication_index: bu64, } #[cfg(any(debug_assertions, test))] diff --git a/libsql-wal/src/segment/list.rs b/libsql-wal/src/segment/list.rs index 25dfa3a32a..f1e3252161 100644 --- a/libsql-wal/src/segment/list.rs +++ b/libsql-wal/src/segment/list.rs @@ -15,6 +15,7 @@ use crate::error::Result; use crate::io::buf::{ZeroCopyBoxIoBuf, ZeroCopyBuf}; use crate::io::FileExt; use crate::segment::Frame; +use crate::{LibsqlFooter, LIBSQL_MAGIC, LIBSQL_PAGE_SIZE, LIBSQL_WAL_VERSION}; use super::Segment; @@ -157,6 +158,21 @@ where buf = read_buf.into_inner(); } + // update the footer at the end of the db file. + let footer = LibsqlFooter { + magic: LIBSQL_MAGIC.into(), + version: LIBSQL_WAL_VERSION.into(), + replication_index: last_replication_index.into(), + }; + + let footer_offset = size_after as usize * LIBSQL_PAGE_SIZE as usize; + let (_, ret) = db_file + .write_all_at_async(ZeroCopyBuf::new_init(footer), footer_offset as u64) + .await; + ret?; + + // todo: truncate if necessary + //// todo: make async db_file.sync_all()?; @@ -185,7 +201,7 @@ where Ok(Some(last_replication_index)) } - /// returnsstream pages from the sealed segment list, and what's the lowest replication index + /// returns a stream of pages from the sealed segment list, and what's the lowest replication index /// that was covered. 
If the returned index is less than start frame_no, the missing frames /// must be read somewhere else. pub async fn stream_pages_from<'a>( From d11ec010df9a049a17eb201785abaa6d8219f9d5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 7 Aug 2024 15:38:56 +0200 Subject: [PATCH 023/121] downgrade debug to trace --- libsql-server/src/connection/libsql.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index 1acbbf2588..d98d6d0f82 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -399,7 +399,7 @@ where impl Drop for Bomb { fn drop(&mut self) { if !self.defused { - tracing::debug!("cancelling request"); + tracing::trace!("cancelling request"); self.canceled.store(true, Ordering::Relaxed); } } @@ -512,7 +512,7 @@ impl Connection { Some(move || { let canceled = canceled.load(Ordering::Relaxed); if canceled { - tracing::debug!("request canceled"); + tracing::trace!("request canceled"); } canceled }) From fc178de41fe56fd3ba845a9dc09bf71b6bf1b2d7 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 7 Aug 2024 15:42:01 +0200 Subject: [PATCH 024/121] add query canceled metric --- libsql-server/src/connection/libsql.rs | 5 ++++- libsql-server/src/metrics.rs | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index d98d6d0f82..aa0604f03a 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -15,7 +15,9 @@ use tokio::sync::watch; use tokio::time::{Duration, Instant}; use crate::error::Error; -use crate::metrics::{DESCRIBE_COUNT, PROGRAM_EXEC_COUNT, VACUUM_COUNT, WAL_CHECKPOINT_COUNT}; +use crate::metrics::{ + DESCRIBE_COUNT, PROGRAM_EXEC_COUNT, QUERY_CANCELED, VACUUM_COUNT, WAL_CHECKPOINT_COUNT, +}; use crate::namespace::broadcasters::BroadcasterHandle; use 
crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::ResolveNamespacePathFn; @@ -512,6 +514,7 @@ impl Connection { Some(move || { let canceled = canceled.load(Ordering::Relaxed); if canceled { + QUERY_CANCELED.increment(1); tracing::trace!("request canceled"); } canceled diff --git a/libsql-server/src/metrics.rs b/libsql-server/src/metrics.rs index a71b5ca979..1ac97435b3 100644 --- a/libsql-server/src/metrics.rs +++ b/libsql-server/src/metrics.rs @@ -153,3 +153,8 @@ pub static LISTEN_EVENTS_DROPPED: Lazy = Lazy::new(|| { describe_counter!(NAME, "Number of listen events dropped"); register_counter!(NAME) }); +pub static QUERY_CANCELED: Lazy = Lazy::new(|| { + const NAME: &str = "libsql_server_query_canceled"; + describe_counter!(NAME, "Number of canceled queries"); + register_counter!(NAME) +}); From 351e6ebfbec97e0a2c12f2d056f2876cb0058e38 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Wed, 7 Aug 2024 18:52:57 +0400 Subject: [PATCH 025/121] add simple integration test --- libsql/tests/integration_tests.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs index 0f8e575949..cc239888d0 100644 --- a/libsql/tests/integration_tests.rs +++ b/libsql/tests/integration_tests.rs @@ -596,6 +596,22 @@ async fn debug_print_row() { ); } +#[tokio::test] +async fn fts5_invalid_tokenizer() { + let db = Database::open(":memory:").unwrap(); + let conn = db.connect().unwrap(); + assert!(conn.execute( + "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram case_sensitive ')", + (), + ) + .await.is_err()); + assert!(conn.execute( + "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram remove_diacritics ')", + (), + ) + .await.is_err()); +} + #[cfg(feature = "serde")] #[tokio::test] async fn deserialize_row() { From 3e56d28d8614a070dd632c6d54b0cdd1d2e08579 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Wed, 7 Aug 2024 16:57:49 +0400 Subject: [PATCH 026/121] fix potential 
crash in fts5 - see: https://sqlite.org/forum/forumpost/171bcc2bcd --- libsql-sqlite3/ext/fts5/fts5_tokenize.c | 60 ++++++++++++++----------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/libsql-sqlite3/ext/fts5/fts5_tokenize.c b/libsql-sqlite3/ext/fts5/fts5_tokenize.c index f12056170f..7e239b6ca5 100644 --- a/libsql-sqlite3/ext/fts5/fts5_tokenize.c +++ b/libsql-sqlite3/ext/fts5/fts5_tokenize.c @@ -1290,40 +1290,46 @@ static int fts5TriCreate( Fts5Tokenizer **ppOut ){ int rc = SQLITE_OK; - TrigramTokenizer *pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew)); - UNUSED_PARAM(pUnused); - if( pNew==0 ){ - rc = SQLITE_NOMEM; + TrigramTokenizer *pNew = 0; + + if( nArg%2 ){ + rc = SQLITE_ERROR; }else{ - int i; - pNew->bFold = 1; - pNew->iFoldParam = 0; - for(i=0; rc==SQLITE_OK && ibFold = 1; + pNew->iFoldParam = 0; + for(i=0; rc==SQLITE_OK && ibFold = (zArg[0]=='0'); + } + }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){ + if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){ + rc = SQLITE_ERROR; + }else{ + pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0; + } }else{ - pNew->bFold = (zArg[0]=='0'); - } - }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){ - if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){ rc = SQLITE_ERROR; - }else{ - pNew->iFoldParam = (zArg[0]!='0') ? 
2 : 0; } - }else{ - rc = SQLITE_ERROR; } - } - if( pNew->iFoldParam!=0 && pNew->bFold==0 ){ - rc = SQLITE_ERROR; - } + if( pNew->iFoldParam!=0 && pNew->bFold==0 ){ + rc = SQLITE_ERROR; + } - if( rc!=SQLITE_OK ){ - fts5TriDelete((Fts5Tokenizer*)pNew); - pNew = 0; + if( rc!=SQLITE_OK ){ + fts5TriDelete((Fts5Tokenizer*)pNew); + pNew = 0; + } } } *ppOut = (Fts5Tokenizer*)pNew; From 7ed14683a177ad913d63f06d5d6848846c6a0d00 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Wed, 7 Aug 2024 18:53:13 +0400 Subject: [PATCH 027/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 61 +++++++++++-------- libsql-ffi/bundled/bindings/bindgen.rs | 16 +++-- libsql-ffi/bundled/src/sqlite3.c | 61 +++++++++++-------- 3 files changed, 80 insertions(+), 58 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 529af0d52e..d7587cc38b 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -28,6 +28,7 @@ ** README.md ** configure ** configure.ac +** ext/fts5/fts5_tokenize.c ** ext/jni/src/org/sqlite/jni/capi/CollationNeededCallback.java ** ext/jni/src/org/sqlite/jni/capi/CommitHookCallback.java ** ext/jni/src/org/sqlite/jni/capi/PreupdateHookCallback.java @@ -259750,40 +259751,46 @@ static int fts5TriCreate( Fts5Tokenizer **ppOut ){ int rc = SQLITE_OK; - TrigramTokenizer *pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew)); - UNUSED_PARAM(pUnused); - if( pNew==0 ){ - rc = SQLITE_NOMEM; + TrigramTokenizer *pNew = 0; + + if( nArg%2 ){ + rc = SQLITE_ERROR; }else{ - int i; - pNew->bFold = 1; - pNew->iFoldParam = 0; - for(i=0; rc==SQLITE_OK && ibFold = 1; + pNew->iFoldParam = 0; + for(i=0; rc==SQLITE_OK && ibFold = (zArg[0]=='0'); + } + }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){ + if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){ + rc = SQLITE_ERROR; + }else{ + 
pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0; + } }else{ - pNew->bFold = (zArg[0]=='0'); - } - }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){ - if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){ rc = SQLITE_ERROR; - }else{ - pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0; } - }else{ - rc = SQLITE_ERROR; } - } - if( pNew->iFoldParam!=0 && pNew->bFold==0 ){ - rc = SQLITE_ERROR; - } + if( pNew->iFoldParam!=0 && pNew->bFold==0 ){ + rc = SQLITE_ERROR; + } - if( rc!=SQLITE_OK ){ - fts5TriDelete((Fts5Tokenizer*)pNew); - pNew = 0; + if( rc!=SQLITE_OK ){ + fts5TriDelete((Fts5Tokenizer*)pNew); + pNew = 0; + } } } *ppOut = (Fts5Tokenizer*)pNew; diff --git a/libsql-ffi/bundled/bindings/bindgen.rs b/libsql-ffi/bundled/bindings/bindgen.rs index 9dec505c10..cc73807f33 100644 --- a/libsql-ffi/bundled/bindings/bindgen.rs +++ b/libsql-ffi/bundled/bindings/bindgen.rs @@ -940,7 +940,7 @@ extern "C" { extern "C" { pub fn sqlite3_vmprintf( arg1: *const ::std::os::raw::c_char, - arg2: va_list, + arg2: *mut __va_list_tag, ) -> *mut ::std::os::raw::c_char; } extern "C" { @@ -956,7 +956,7 @@ extern "C" { arg1: ::std::os::raw::c_int, arg2: *mut ::std::os::raw::c_char, arg3: *const ::std::os::raw::c_char, - arg4: va_list, + arg4: *mut __va_list_tag, ) -> *mut ::std::os::raw::c_char; } extern "C" { @@ -2503,7 +2503,7 @@ extern "C" { pub fn sqlite3_str_vappendf( arg1: *mut sqlite3_str, zFormat: *const ::std::os::raw::c_char, - arg2: va_list, + arg2: *mut __va_list_tag, ); } extern "C" { @@ -3524,4 +3524,12 @@ extern "C" { extern "C" { pub static sqlite3_wal_manager: libsql_wal_manager; } -pub type __builtin_va_list = *mut ::std::os::raw::c_char; +pub type __builtin_va_list = [__va_list_tag; 1usize]; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct __va_list_tag { + pub gp_offset: ::std::os::raw::c_uint, + pub fp_offset: ::std::os::raw::c_uint, + pub overflow_arg_area: *mut ::std::os::raw::c_void, + pub reg_save_area: *mut ::std::os::raw::c_void, +} diff --git 
a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 529af0d52e..d7587cc38b 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -28,6 +28,7 @@ ** README.md ** configure ** configure.ac +** ext/fts5/fts5_tokenize.c ** ext/jni/src/org/sqlite/jni/capi/CollationNeededCallback.java ** ext/jni/src/org/sqlite/jni/capi/CommitHookCallback.java ** ext/jni/src/org/sqlite/jni/capi/PreupdateHookCallback.java @@ -259750,40 +259751,46 @@ static int fts5TriCreate( Fts5Tokenizer **ppOut ){ int rc = SQLITE_OK; - TrigramTokenizer *pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew)); - UNUSED_PARAM(pUnused); - if( pNew==0 ){ - rc = SQLITE_NOMEM; + TrigramTokenizer *pNew = 0; + + if( nArg%2 ){ + rc = SQLITE_ERROR; }else{ - int i; - pNew->bFold = 1; - pNew->iFoldParam = 0; - for(i=0; rc==SQLITE_OK && ibFold = 1; + pNew->iFoldParam = 0; + for(i=0; rc==SQLITE_OK && ibFold = (zArg[0]=='0'); + } + }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){ + if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){ + rc = SQLITE_ERROR; + }else{ + pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0; + } }else{ - pNew->bFold = (zArg[0]=='0'); - } - }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){ - if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){ rc = SQLITE_ERROR; - }else{ - pNew->iFoldParam = (zArg[0]!='0') ? 
2 : 0; } - }else{ - rc = SQLITE_ERROR; } - } - if( pNew->iFoldParam!=0 && pNew->bFold==0 ){ - rc = SQLITE_ERROR; - } + if( pNew->iFoldParam!=0 && pNew->bFold==0 ){ + rc = SQLITE_ERROR; + } - if( rc!=SQLITE_OK ){ - fts5TriDelete((Fts5Tokenizer*)pNew); - pNew = 0; + if( rc!=SQLITE_OK ){ + fts5TriDelete((Fts5Tokenizer*)pNew); + pNew = 0; + } } } *ppOut = (Fts5Tokenizer*)pNew; From 9595315d30ee56d9845e884fd3fae768352967b9 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 7 Aug 2024 20:42:27 +0200 Subject: [PATCH 028/121] init cancel bomb berfore query exec --- libsql-server/src/connection/libsql.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index aa0604f03a..9896164e55 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -413,14 +413,15 @@ where cancelled }; + PROGRAM_EXEC_COUNT.increment(1); + + check_program_auth(&ctx, &pgm, &self.inner.lock().config_store.get())?; + + // create the bomb right before spawning the blocking task. 
let mut bomb = Bomb { canceled, defused: false, }; - - PROGRAM_EXEC_COUNT.increment(1); - - check_program_auth(&ctx, &pgm, &self.inner.lock().config_store.get())?; let conn = self.inner.clone(); let ret = BLOCKING_RT .spawn_blocking(move || Connection::run(conn, pgm, builder)) From 0d411057ae5d42bffd1e451bfc2e5aa44ab72042 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 00:33:27 +0400 Subject: [PATCH 029/121] cargo fmt --- libsql/tests/integration_tests.rs | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs index cc239888d0..cdb0a985c3 100644 --- a/libsql/tests/integration_tests.rs +++ b/libsql/tests/integration_tests.rs @@ -600,16 +600,20 @@ async fn debug_print_row() { async fn fts5_invalid_tokenizer() { let db = Database::open(":memory:").unwrap(); let conn = db.connect().unwrap(); - assert!(conn.execute( - "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram case_sensitive ')", - (), - ) - .await.is_err()); - assert!(conn.execute( - "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram remove_diacritics ')", - (), - ) - .await.is_err()); + assert!(conn + .execute( + "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram case_sensitive ')", + (), + ) + .await + .is_err()); + assert!(conn + .execute( + "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram remove_diacritics ')", + (), + ) + .await + .is_err()); } #[cfg(feature = "serde")] From 4085a0d35edd59c0c75bb917f4236dcdeeb0f574 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Wed, 7 Aug 2024 17:40:39 -0400 Subject: [PATCH 030/121] libsql: downgrade failed prefetch log to debug --- libsql/src/replication/remote_client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql/src/replication/remote_client.rs b/libsql/src/replication/remote_client.rs index dbab056938..d0052f50d9 100644 --- a/libsql/src/replication/remote_client.rs +++ 
b/libsql/src/replication/remote_client.rs @@ -135,7 +135,7 @@ impl RemoteClient { (hello_fut.await, None) }; self.prefetched_batch_log_entries = if let Ok(true) = hello.0 { - tracing::warn!( + tracing::debug!( "Frames prefetching failed because of new session token returned by handshake" ); None From b0bc6eb2f5686b2e1706dd06b51405e8f0257ecd Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 12:13:13 +0400 Subject: [PATCH 031/121] publish sqld debug builds to the separate image name --- .github/workflows/publish-server.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-server.yml b/.github/workflows/publish-server.yml index 10820457b8..d957195e40 100644 --- a/.github/workflows/publish-server.yml +++ b/.github/workflows/publish-server.yml @@ -118,7 +118,7 @@ jobs: context: . platforms: ${{ env.platform }} labels: ${{ steps.meta.outputs.labels }} - outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true + outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}-debug,push-by-digest=true,name-canonical=true,push=true build-args: | BUILD_DEBUG=true - From 51b1b490545524e846e52479743054fbbe2e1660 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:23:09 +0400 Subject: [PATCH 032/121] remove digests artifacts from debug build step --- .github/workflows/publish-server.yml | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/.github/workflows/publish-server.yml b/.github/workflows/publish-server.yml index d957195e40..e1973fe47c 100644 --- a/.github/workflows/publish-server.yml +++ b/.github/workflows/publish-server.yml @@ -121,20 +121,6 @@ jobs: outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}-debug,push-by-digest=true,name-canonical=true,push=true build-args: | BUILD_DEBUG=true - - - name: Export digest - run: | - mkdir -p /tmp/digests - digest="${{ steps.build.outputs.digest }}" - touch 
"/tmp/digests/${digest#sha256:}" - - - name: Upload digest - uses: actions/upload-artifact@v4 - with: - name: digests-debug-${{ env.PLATFORM_PAIR }} - path: /tmp/digests/* - if-no-files-found: error - retention-days: 1 build-arm64: permissions: write-all From ec7bca5a20eb413ba658eef04e643d94f6fa562a Mon Sep 17 00:00:00 2001 From: wyhaya Date: Thu, 8 Aug 2024 12:24:36 +0800 Subject: [PATCH 033/121] Fix JSON f64 precision --- libsql/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index efae2abea3..3d65f71c73 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -20,7 +20,7 @@ hyper = { workspace = true, features = ["client", "stream"], optional = true } hyper-rustls = { version = "0.25", features = ["webpki-roots"], optional = true } base64 = { version = "0.21", optional = true } serde = { version = "1", features = ["derive"], optional = true } -serde_json = { version = "1", optional = true } +serde_json = { version = "1", features = ["float_roundtrip"], optional = true } async-trait = "0.1" bitflags = { version = "2.4.0", optional = true } tower = { workspace = true, features = ["util"], optional = true } From 5eeba4330901f022963bf50cf35e3e7120cf1a09 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:18:51 +0400 Subject: [PATCH 034/121] improve random row selection --- libsql-sqlite3/src/vectordiskann.c | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 95d473b630..8804aee119 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -442,6 +442,7 @@ int diskAnnCreateIndex( int type, dims; u64 maxNeighborsParam, blockSizeBytes; char *zSql; + const char *zRowidColumnName; char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) 
char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...) if( vectorIdxKeyDefsRender(pKey, "index_key", columnSqlDefs, sizeof(columnSqlDefs)) != 0 ){ @@ -509,6 +510,7 @@ int diskAnnCreateIndex( columnSqlDefs, columnSqlNames ); + zRowidColumnName = "index_key"; }else{ zSql = sqlite3MPrintf( db, @@ -518,9 +520,31 @@ int diskAnnCreateIndex( columnSqlDefs, columnSqlNames ); + zRowidColumnName = "rowid"; } rc = sqlite3_exec(db, zSql, 0, 0, 0); sqlite3DbFree(db, zSql); + if( rc != SQLITE_OK ){ + return rc; + } + /* + * vector blobs are usually pretty huge (more than a page size, for example, node block for 1024d f32 embeddings with 1bit compression will occupy ~20KB) + * in this case, main table B-Tree takes on redundant shape where all leaf nodes has only 1 cell + * + * as we have a query which selects random row using OFFSET/LIMIT trick - we will need to read all these leaf nodes pages just to skip them + * so, in order to remove this overhead for random row selection - we creating an index with just single column used + * in this case B-Tree leafs will be full of rowids and the overhead for page reads will be very small + */ + zSql = sqlite3MPrintf( + db, + "CREATE INDEX IF NOT EXISTS \"%w\".%s_shadow_idx ON %s_shadow (%s)", + zDbSName, + zIdxName, + zIdxName, + zRowidColumnName + ); + rc = sqlite3_exec(db, zSql, 0, 0, 0); + sqlite3DbFree(db, zSql); return rc; } From 4b3e7e7544e68a03ec04f38a686791bb76886fd3 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 14:20:41 +0400 Subject: [PATCH 035/121] fix random row selection query to have db name --- libsql-sqlite3/src/vectordiskann.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 8804aee119..fc39e00d30 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -574,8 +574,8 @@ static int 
diskAnnSelectRandomShadowRow(const DiskAnnIndex *pIndex, u64 *pRowid) zSql = sqlite3MPrintf( pIndex->db, - "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM %s), 1)", - pIndex->zDbSName, pIndex->zShadow, pIndex->zShadow + "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM \"%w\".%s), 1)", + pIndex->zDbSName, pIndex->zShadow, pIndex->zDbSName, pIndex->zShadow ); if( zSql == NULL ){ rc = SQLITE_NOMEM_BKPT; From 0b41b5aa7c93639218caf6f765d6c4d26e0b0567 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 14:33:31 +0400 Subject: [PATCH 036/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 28 +++++++++++++++++-- libsql-ffi/bundled/src/sqlite3.c | 28 +++++++++++++++++-- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index d7587cc38b..15d09606fb 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -212002,6 +212002,7 @@ int diskAnnCreateIndex( int type, dims; u64 maxNeighborsParam, blockSizeBytes; char *zSql; + const char *zRowidColumnName; char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...) 
if( vectorIdxKeyDefsRender(pKey, "index_key", columnSqlDefs, sizeof(columnSqlDefs)) != 0 ){ @@ -212069,6 +212070,7 @@ int diskAnnCreateIndex( columnSqlDefs, columnSqlNames ); + zRowidColumnName = "index_key"; }else{ zSql = sqlite3MPrintf( db, @@ -212078,9 +212080,31 @@ int diskAnnCreateIndex( columnSqlDefs, columnSqlNames ); + zRowidColumnName = "rowid"; } rc = sqlite3_exec(db, zSql, 0, 0, 0); sqlite3DbFree(db, zSql); + if( rc != SQLITE_OK ){ + return rc; + } + /* + * vector blobs are usually pretty huge (more than a page size, for example, node block for 1024d f32 embeddings with 1bit compression will occupy ~20KB) + * in this case, main table B-Tree takes on redundant shape where all leaf nodes has only 1 cell + * + * as we have a query which selects random row using OFFSET/LIMIT trick - we will need to read all these leaf nodes pages just to skip them + * so, in order to remove this overhead for random row selection - we creating an index with just single column used + * in this case B-Tree leafs will be full of rowids and the overhead for page reads will be very small + */ + zSql = sqlite3MPrintf( + db, + "CREATE INDEX IF NOT EXISTS \"%w\".%s_shadow_idx ON %s_shadow (%s)", + zDbSName, + zIdxName, + zIdxName, + zRowidColumnName + ); + rc = sqlite3_exec(db, zSql, 0, 0, 0); + sqlite3DbFree(db, zSql); return rc; } @@ -212110,8 +212134,8 @@ static int diskAnnSelectRandomShadowRow(const DiskAnnIndex *pIndex, u64 *pRowid) zSql = sqlite3MPrintf( pIndex->db, - "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM %s), 1)", - pIndex->zDbSName, pIndex->zShadow, pIndex->zShadow + "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM \"%w\".%s), 1)", + pIndex->zDbSName, pIndex->zShadow, pIndex->zDbSName, pIndex->zShadow ); if( zSql == NULL ){ rc = SQLITE_NOMEM_BKPT; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index d7587cc38b..15d09606fb 100644 --- 
a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -212002,6 +212002,7 @@ int diskAnnCreateIndex( int type, dims; u64 maxNeighborsParam, blockSizeBytes; char *zSql; + const char *zRowidColumnName; char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...) if( vectorIdxKeyDefsRender(pKey, "index_key", columnSqlDefs, sizeof(columnSqlDefs)) != 0 ){ @@ -212069,6 +212070,7 @@ int diskAnnCreateIndex( columnSqlDefs, columnSqlNames ); + zRowidColumnName = "index_key"; }else{ zSql = sqlite3MPrintf( db, @@ -212078,9 +212080,31 @@ int diskAnnCreateIndex( columnSqlDefs, columnSqlNames ); + zRowidColumnName = "rowid"; } rc = sqlite3_exec(db, zSql, 0, 0, 0); sqlite3DbFree(db, zSql); + if( rc != SQLITE_OK ){ + return rc; + } + /* + * vector blobs are usually pretty huge (more than a page size, for example, node block for 1024d f32 embeddings with 1bit compression will occupy ~20KB) + * in this case, main table B-Tree takes on redundant shape where all leaf nodes has only 1 cell + * + * as we have a query which selects random row using OFFSET/LIMIT trick - we will need to read all these leaf nodes pages just to skip them + * so, in order to remove this overhead for random row selection - we creating an index with just single column used + * in this case B-Tree leafs will be full of rowids and the overhead for page reads will be very small + */ + zSql = sqlite3MPrintf( + db, + "CREATE INDEX IF NOT EXISTS \"%w\".%s_shadow_idx ON %s_shadow (%s)", + zDbSName, + zIdxName, + zIdxName, + zRowidColumnName + ); + rc = sqlite3_exec(db, zSql, 0, 0, 0); + sqlite3DbFree(db, zSql); return rc; } @@ -212110,8 +212134,8 @@ static int diskAnnSelectRandomShadowRow(const DiskAnnIndex *pIndex, u64 *pRowid) zSql = sqlite3MPrintf( pIndex->db, - "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET 
ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM %s), 1)", - pIndex->zDbSName, pIndex->zShadow, pIndex->zShadow + "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM \"%w\".%s), 1)", + pIndex->zDbSName, pIndex->zShadow, pIndex->zDbSName, pIndex->zShadow ); if( zSql == NULL ){ rc = SQLITE_NOMEM_BKPT; From f80444ac450d19754bf11de98753be6f43ca8332 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 14:50:29 +0400 Subject: [PATCH 037/121] fix test --- libsql-sqlite3/test/libsql_vector_index.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 7308b2d93f..c1a270e4da 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -140,7 +140,7 @@ do_execsql_test vector-sql { INSERT INTO t_sql VALUES(vector('[1,2,3]')), (vector('[2,3,4]')); SELECT sql FROM sqlite_master WHERE name LIKE '%t_sql%'; SELECT name FROM libsql_vector_meta_shadow WHERE name = 't_sql_idx'; -} {{CREATE TABLE t_sql( v FLOAT32(3))} {CREATE TABLE t_sql_idx_shadow (index_key INTEGER , data BLOB, PRIMARY KEY (index_key))} {CREATE INDEX t_sql_idx ON t_sql( libsql_vector_idx(v) )} {t_sql_idx}} +} {{CREATE TABLE t_sql( v FLOAT32(3))} {CREATE TABLE t_sql_idx_shadow (index_key INTEGER , data BLOB, PRIMARY KEY (index_key))} {CREATE INDEX t_sql_idx_shadow_idx ON t_sql_idx_shadow (index_key)} {CREATE INDEX t_sql_idx ON t_sql( libsql_vector_idx(v) )} {t_sql_idx}} do_execsql_test vector-drop-index { CREATE TABLE t_index_drop( v FLOAT32(3)); From 83d029d7e922618c845f85db5e8a2d09d7f3b188 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 5 Aug 2024 17:13:19 +0400 Subject: [PATCH 038/121] cleanup code a bit in order to simplify working with vector of different types --- libsql-sqlite3/src/vector.c | 69 ++++++++++++++++++++++-------- libsql-sqlite3/src/vectorInt.h | 16 ++----- 
libsql-sqlite3/src/vectordiskann.c | 16 +++++-- libsql-sqlite3/src/vectorfloat32.c | 38 ++-------------- libsql-sqlite3/src/vectorfloat64.c | 69 ++---------------------------- 5 files changed, 76 insertions(+), 132 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index d32819cd00..bef34140bf 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -252,11 +252,29 @@ int vectorParseSqliteBlob( Vector *pVector, char **pzErrMsg ){ + const unsigned char *pBlob; + size_t nBlobSize; + + assert( sqlite3_value_type(arg) == SQLITE_BLOB ); + + pBlob = sqlite3_value_blob(arg); + nBlobSize = sqlite3_value_bytes(arg); + if( nBlobSize % 2 == 1 ){ + nBlobSize--; + } + + if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize); + return SQLITE_ERROR; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - return vectorF32ParseSqliteBlob(arg, pVector, pzErrMsg); + vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); + return 0; case VECTOR_TYPE_FLOAT64: - return vectorF64ParseSqliteBlob(arg, pVector, pzErrMsg); + vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + return 0; default: assert(0); } @@ -384,20 +402,47 @@ void vectorMarshalToText( } } -void vectorSerialize( +void vectorSerializeWithType( sqlite3_context *context, const Vector *pVector ){ + unsigned char *pBlob; + size_t nBlobSize, nDataSize; + + assert( pVector->dims <= MAX_VECTOR_SZ ); + + nDataSize = vectorDataSize(pVector->type, pVector->dims); + nBlobSize = nDataSize; + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + nBlobSize += (nBlobSize % 2 == 0 ? 
1 : 2); + } + + if( nBlobSize == 0 ){ + sqlite3_result_zeroblob(context, 0); + return; + } + + pBlob = sqlite3_malloc64(nBlobSize); + if( pBlob == NULL ){ + sqlite3_result_error_nomem(context); + return; + } + + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + pBlob[nBlobSize - 1] = pVector->type; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - vectorF32Serialize(context, pVector); + vectorF32SerializeToBlob(pVector, pBlob, nDataSize); break; case VECTOR_TYPE_FLOAT64: - vectorF64Serialize(context, pVector); + vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; default: assert(0); } + sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ @@ -412,18 +457,6 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return 0; } -size_t vectorDeserializeFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - return vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); - case VECTOR_TYPE_FLOAT64: - return vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); - default: - assert(0); - } - return 0; -} - void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ switch (pVector->type) { case VECTOR_TYPE_FLOAT32: @@ -470,7 +503,7 @@ static void vectorFuncHintedType( sqlite3_free(pzErrMsg); goto out_free_vec; } - vectorSerialize(context, pVector); + vectorSerializeWithType(context, pVector); out_free_vec: vectorFree(pVector); } diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index 8c9138b94f..64703b447f 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -65,13 +65,6 @@ size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); size_t vectorF32SerializeToBlob(const Vector *, unsigned char *, size_t); size_t vectorF64SerializeToBlob(const Vector *, 
unsigned char *, size_t); -/* - * Deserializes vector from the blob in little-endian format according to the IEEE-754 standard -*/ -size_t vectorDeserializeFromBlob (Vector *, const unsigned char *, size_t); -size_t vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -size_t vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); - /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) */ @@ -91,16 +84,15 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerialize (sqlite3_context *, const Vector *); -void vectorF32Serialize(sqlite3_context *, const Vector *); -void vectorF64Serialize(sqlite3_context *, const Vector *); +void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); -int vectorF32ParseSqliteBlob(sqlite3_value *, Vector *, char **); -int vectorF64ParseSqliteBlob(sqlite3_value *, Vector *, char **); + +void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 95d473b630..54f3fba0b3 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -47,6 +47,7 @@ ** diskAnnInsert() Insert single new(!) 
vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ +#include "vectorInt.h" #ifndef SQLITE_OMIT_VECTOR #include "math.h" @@ -1490,6 +1491,7 @@ int diskAnnOpenIndex( ){ DiskAnnIndex *pIndex; u64 nBlockSize; + int compressNeighbours; pIndex = sqlite3DbMallocRaw(db, sizeof(DiskAnnIndex)); if( pIndex == NULL ){ return SQLITE_NOMEM; @@ -1536,9 +1538,17 @@ int diskAnnOpenIndex( pIndex->searchL = VECTOR_SEARCH_L_DEFAULT; } pIndex->nNodeVectorSize = vectorDataSize(pIndex->nNodeVectorType, pIndex->nVectorDims); - // will change in future when we will support compression of edges vectors - pIndex->nEdgeVectorType = pIndex->nNodeVectorType; - pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + + compressNeighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( compressNeighbours == 0 ){ + pIndex->nEdgeVectorType = pIndex->nNodeVectorType; + pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + pIndex->nEdgeVectorType = VECTOR_TYPE_1BIT; + pIndex->nEdgeVectorSize = vectorDataSize(VECTOR_TYPE_1BIT, pIndex->nVectorDims); + }else{ + return SQLITE_ERROR; + } *ppIndex = pIndex; return SQLITE_OK; diff --git a/libsql-sqlite3/src/vectorfloat32.c b/libsql-sqlite3/src/vectorfloat32.c index 8aeae2eb23..5d6641991c 100644 --- a/libsql-sqlite3/src/vectorfloat32.c +++ b/libsql-sqlite3/src/vectorfloat32.c @@ -94,26 +94,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -size_t vectorF32DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - float *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT32; - pVector->dims = nBlobSize / sizeof(float); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 0 || pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT32 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF32(pBlob); - pBlob += sizeof(float); - } - return 
vectorDataSize(pVector->type, pVector->dims); -} - void vectorF32Serialize( sqlite3_context *context, const Vector *pVector @@ -220,32 +200,22 @@ void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF32ParseSqliteBlob( - sqlite3_value *arg, +void vectorF32DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; float *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(float) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f32 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(float) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); pBlob += sizeof(float); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ diff --git a/libsql-sqlite3/src/vectorfloat64.c b/libsql-sqlite3/src/vectorfloat64.c index ced2be1843..1d29c9c3d6 100644 --- a/libsql-sqlite3/src/vectorfloat64.c +++ b/libsql-sqlite3/src/vectorfloat64.c @@ -98,57 +98,6 @@ size_t vectorF64SerializeToBlob( return sizeof(double) * pVector->dims; } -size_t vectorF64DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - double *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT64; - pVector->dims = nBlobSize / sizeof(double); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 1 && pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT64 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF64(pBlob); - pBlob += sizeof(double); - } - return vectorDataSize(pVector->type, pVector->dims); -} - -void vectorF64Serialize( - sqlite3_context *context, - const 
Vector *pVector -){ - double *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT64 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - // allocate one extra trailing byte with vector blob type metadata - nBlobSize = vectorDataSize(pVector->type, pVector->dims) + 1; - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF64SerializeToBlob(pVector, pBlob, nBlobSize - 1); - pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_DOUBLE_CHAR_LIMIT 32 void vectorF64MarshalToText( sqlite3_context *context, @@ -227,32 +176,22 @@ void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF64ParseSqliteBlob( - sqlite3_value *arg, +void vectorF64DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; double *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(double) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f64 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(double) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); pBlob += sizeof(double); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ From 39e30ead5e755c25971c8d05731e5778bbfaece8 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 5 Aug 2024 18:14:58 +0400 Subject: [PATCH 039/121] add 1bit vector type --- libsql-sqlite3/Makefile.in | 6 +- 
libsql-sqlite3/src/vector.c | 9 +++ libsql-sqlite3/src/vector1bit.c | 110 ++++++++++++++++++++++++++++ libsql-sqlite3/src/vectorIndex.c | 25 ++++--- libsql-sqlite3/src/vectorIndexInt.h | 42 ++++++----- libsql-sqlite3/src/vectorInt.h | 18 +++-- libsql-sqlite3/src/vectordiskann.c | 21 ++++-- libsql-sqlite3/tool/mksqlite3c.tcl | 1 + 8 files changed, 190 insertions(+), 42 deletions(-) create mode 100644 libsql-sqlite3/src/vector1bit.c diff --git a/libsql-sqlite3/Makefile.in b/libsql-sqlite3/Makefile.in index 4520dda0d2..0afadd458f 100644 --- a/libsql-sqlite3/Makefile.in +++ b/libsql-sqlite3/Makefile.in @@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \ sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \ table.lo threads.lo tokenize.lo treeview.lo trigger.lo \ update.lo userauth.lo upsert.lo util.lo vacuum.lo \ - vector.lo vectorfloat32.lo vectorfloat64.lo \ + vector.lo vectorfloat32.lo vectorfloat64.lo vector1bit.lo \ vectorIndex.lo vectordiskann.lo vectorvtab.lo \ vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \ vdbetrace.lo vdbevtab.lo \ @@ -302,6 +302,7 @@ SRC = \ $(TOP)/src/util.c \ $(TOP)/src/vacuum.c \ $(TOP)/src/vector.c \ + $(TOP)/src/vector1bit.c \ $(TOP)/src/vectorInt.h \ $(TOP)/src/vectorfloat32.c \ $(TOP)/src/vectorfloat64.c \ @@ -1138,6 +1139,9 @@ vacuum.lo: $(TOP)/src/vacuum.c $(HDR) vector.lo: $(TOP)/src/vector.c $(HDR) $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector.c +vector1bit.lo: $(TOP)/src/vector1bit.c $(HDR) + $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector1bit.c + vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR) $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index bef34140bf..9d37dbb4a8 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -41,6 +41,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); 
+ case VECTOR_TYPE_1BIT: + return (dims + 7) / 8; default: assert(0); } @@ -111,6 +113,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceCos(pVector1, pVector2); + case VECTOR_TYPE_1BIT: + return vector1BitDistanceHamming(pVector1, pVector2); default: assert(0); } @@ -381,6 +385,9 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT64: vectorF64Dump(pVector); break; + case VECTOR_TYPE_1BIT: + vector1BitDump(pVector); + break; default: assert(0); } @@ -451,6 +458,8 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); + case VECTOR_TYPE_1BIT: + return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); } diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c new file mode 100644 index 0000000000..c8da5496d2 --- /dev/null +++ b/libsql-sqlite3/src/vector1bit.c @@ -0,0 +1,110 @@ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. 
+** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +** +****************************************************************************** +** +** 1-bit vector format utilities. +*/ +#ifndef SQLITE_OMIT_VECTOR +#include "sqliteInt.h" + +#include "vectorInt.h" + +#include + +/************************************************************************** +** Utility routines for debugging +**************************************************************************/ + +void vector1BitDump(const Vector *pVec){ + u8 *elems = pVec->data; + unsigned i; + + assert( pVec->type == VECTOR_TYPE_1BIT ); + + for(i = 0; i < pVec->dims; i++){ + printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? 
+1 : -1); + } + printf("\n"); +} + +/************************************************************************** +** Utility routines for vector serialization and deserialization +**************************************************************************/ + +size_t vector1BitSerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + float *elems = pVector->data; + unsigned char *pPtr = pBlob; + size_t len = 0; + unsigned i; + + assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + for(i = 0; i < pVector->dims; i++){ + elems[i] = pPtr[i]; + } + return (pVector->dims + 7) / 8; +} + +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ + int sum = 0; + u8 *e1 = v1->data; + u8 *e2 = v2->data; + int i; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_1BIT ); + assert( v2->type == VECTOR_TYPE_1BIT ); + + for(i = 0; i < v1->dims; i++){ + sum += BitsCount[e1[i]&e2[i]]; + } + return sum; +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ diff --git 
a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index d8b3497781..7ad42a00ba 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -396,13 +396,14 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, - { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, - { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, - { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_METRIC_TYPE_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, + { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, + { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, + { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, }; static int parseVectorIdxParam(const char *zParam, VectorIdxParams *pParams, const char **pErrMsg) { @@ -802,7 +803,7 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int i, rc = SQLITE_OK; int dims, type; int hasLibsqlVectorIdxFn = 0, hasCollation = 0; - const char *pzErrMsg; + const char *pzErrMsg = NULL; assert( zDbSName != NULL ); @@ -914,9 +915,13 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: unsupported for tables without ROWID and composite primary key"); return CREATE_FAIL; } - rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams); + rc = 
diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams, &pzErrMsg); if( rc != SQLITE_OK ){ - sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + if( pzErrMsg != NULL ){ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann: %s", pzErrMsg); + }else{ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + } return CREATE_FAIL; } rc = insertIndexParameters(db, zDbSName, pIdx->zName, &idxParams); diff --git a/libsql-sqlite3/src/vectorIndexInt.h b/libsql-sqlite3/src/vectorIndexInt.h index 8f73091bb1..bb6c085e94 100644 --- a/libsql-sqlite3/src/vectorIndexInt.h +++ b/libsql-sqlite3/src/vectorIndexInt.h @@ -100,43 +100,45 @@ typedef u8 MetricType; */ /* format version which can help to upgrade vector on-disk format without breaking older version of the db */ -#define VECTOR_FORMAT_PARAM_ID 1 +#define VECTOR_FORMAT_PARAM_ID 1 /* * 1 - initial version */ -#define VECTOR_FORMAT_DEFAULT 1 +#define VECTOR_FORMAT_DEFAULT 1 /* type of the vector index */ -#define VECTOR_INDEX_TYPE_PARAM_ID 2 -#define VECTOR_INDEX_TYPE_DISKANN 1 +#define VECTOR_INDEX_TYPE_PARAM_ID 2 +#define VECTOR_INDEX_TYPE_DISKANN 1 /* type of the underlying vector for the vector index */ -#define VECTOR_TYPE_PARAM_ID 3 +#define VECTOR_TYPE_PARAM_ID 3 /* dimension of the underlying vector for the vector index */ -#define VECTOR_DIM_PARAM_ID 4 +#define VECTOR_DIM_PARAM_ID 4 /* metric type used for comparing two vectors */ -#define VECTOR_METRIC_TYPE_PARAM_ID 5 -#define VECTOR_METRIC_TYPE_COS 1 -#define VECTOR_METRIC_TYPE_L2 2 +#define VECTOR_METRIC_TYPE_PARAM_ID 5 +#define VECTOR_METRIC_TYPE_COS 1 +#define VECTOR_METRIC_TYPE_L2 2 /* block size */ -#define VECTOR_BLOCK_SIZE_PARAM_ID 6 -#define VECTOR_BLOCK_SIZE_DEFAULT 128 +#define VECTOR_BLOCK_SIZE_PARAM_ID 6 +#define VECTOR_BLOCK_SIZE_DEFAULT 128 -#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 -#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 +#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 +#define 
VECTOR_PRUNING_ALPHA_DEFAULT 1.2 -#define VECTOR_INSERT_L_PARAM_ID 8 -#define VECTOR_INSERT_L_DEFAULT 70 +#define VECTOR_INSERT_L_PARAM_ID 8 +#define VECTOR_INSERT_L_DEFAULT 70 -#define VECTOR_SEARCH_L_PARAM_ID 9 -#define VECTOR_SEARCH_L_DEFAULT 200 +#define VECTOR_SEARCH_L_PARAM_ID 9 +#define VECTOR_SEARCH_L_DEFAULT 200 -#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 +#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 + +#define VECTOR_COMPRESS_NEIGHBORS_PARAM_ID 11 /* total amount of vector index parameters */ -#define VECTOR_PARAM_IDS_COUNT 9 +#define VECTOR_PARAM_IDS_COUNT 11 /* * Vector index parameters are stored in simple binary format (1 byte tag + 8 byte u64 integer / f64 float) @@ -218,7 +220,7 @@ int vectorOutRowsPut(VectorOutRows *, int, int, const u64 *, sqlite3_value *); void vectorOutRowsGet(sqlite3_context *, const VectorOutRows *, int, int); void vectorOutRowsFree(sqlite3 *, VectorOutRows *); -int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *); +int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *, const char **); int diskAnnClearIndex(sqlite3 *, const char *, const char *); int diskAnnDropIndex(sqlite3 *, const char *, const char *); int diskAnnOpenIndex(sqlite3 *, const char *, const char *, const VectorIdxParams *, DiskAnnIndex **); diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index 64703b447f..d6f9f36635 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -24,6 +24,7 @@ typedef u32 VectorDims; */ #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_1BIT 3 #define VECTOR_FLAGS_STATIC 1 @@ -48,8 +49,9 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); * Dumps vector on the console (used only for debugging) */ void vectorDump (const Vector *v); -void vectorF32Dump(const Vector *v); -void vectorF64Dump(const Vector *v); +void vectorF32Dump (const Vector *v); +void 
vectorF64Dump (const Vector *v); +void vector1BitDump(const Vector *v); /* * Converts vector to the text representation and write the result to the sqlite3_context @@ -61,9 +63,10 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF32SerializeToBlob(const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob(const Vector *, unsigned char *, size_t); +size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) @@ -72,6 +75,11 @@ float vectorDistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); +/* + * Calculates hamming distance between two 1-bit vectors (vector must have same dimensions) +*/ +int vector1BitDistanceHamming(const Vector *, const Vector *); + /* * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 54f3fba0b3..cce2f090ff 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -437,10 +437,11 @@ int diskAnnCreateIndex( const char *zDbSName, const char *zIdxName, const VectorIdxKey *pKey, - VectorIdxParams *pParams + VectorIdxParams *pParams, + const char **pzErrMsg ){ int rc; - int type, dims; + int type, dims, metric, neighbours; u64 maxNeighborsParam, blockSizeBytes; char *zSql; char 
columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) @@ -477,11 +478,19 @@ int diskAnnCreateIndex( if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){ return SQLITE_ERROR; } - if( vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID) == 0 ){ - if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, VECTOR_METRIC_TYPE_COS) != 0 ){ + metric = vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID); + if( metric == 0 ){ + metric = VECTOR_METRIC_TYPE_COS; + if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, metric) != 0 ){ return SQLITE_ERROR; } } + neighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){ + *pzErrMsg = "1-bit compression available only for cosine metric"; + return SQLITE_ERROR; + } + if( vectorIdxParamsGetF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID) == 0 ){ if( vectorIdxParamsPutF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID, VECTOR_PRUNING_ALPHA_DEFAULT) != 0 ){ return SQLITE_ERROR; @@ -1544,8 +1553,8 @@ int diskAnnOpenIndex( pIndex->nEdgeVectorType = pIndex->nNodeVectorType; pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ - pIndex->nEdgeVectorType = VECTOR_TYPE_1BIT; - pIndex->nEdgeVectorSize = vectorDataSize(VECTOR_TYPE_1BIT, pIndex->nVectorDims); + pIndex->nEdgeVectorType = compressNeighbours; + pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); }else{ return SQLITE_ERROR; } diff --git a/libsql-sqlite3/tool/mksqlite3c.tcl b/libsql-sqlite3/tool/mksqlite3c.tcl index 31e6d84a57..3a04459e31 100644 --- a/libsql-sqlite3/tool/mksqlite3c.tcl +++ b/libsql-sqlite3/tool/mksqlite3c.tcl @@ -468,6 +468,7 @@ set flist { json.c vector.c + vector1bit.c vectordiskann.c vectorfloat32.c vectorfloat64.c From 
2e696fe5d58614559fcf976ab7e526bebc2c1065 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 11:02:22 +0400 Subject: [PATCH 040/121] restructure search a bit in order to support compressed edges --- libsql-sqlite3/src/vectordiskann.c | 168 ++++++++++++++++++++--------- 1 file changed, 118 insertions(+), 50 deletions(-) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index cce2f090ff..9b302a4dda 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -98,14 +98,19 @@ struct DiskAnnNode { * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this) */ struct DiskAnnSearchCtx { - const Vector *pQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ - DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */ - double *aDistances; /* array of distances to the query vector */ - unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ - unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ - DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */ - unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */ - int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */ + const Vector *pNodeQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ + const Vector *pEdgeQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ + DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */ + float *aDistances; /* array of distances to the query vector */ + unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ + unsigned int maxCandidates; /* max size of 
aCandidates/aDistances arrays */ + DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */ + float *aTopDistances; /* top candidates exact distances */ + int nTopCandidates; /* current size of aTopCandidates/aTopDistances arrays */ + int maxTopCandidates; /* max size of aTopCandidates/aTopDistances arrays */ + DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */ + unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */ + int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */ }; /************************************************************************** @@ -805,6 +810,53 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){ return rc; } +/************************************************************************** +** Generic utilities +**************************************************************************/ + +int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){ + int i; +#ifdef SQLITE_DEBUG + for(i = 0; i < nSize - 1; i++){ + assert(aDistances[i] <= aDistances[i + 1]); + } +#endif + for(i = 0; i < nSize; i++){ + if( distance < aDistances[i] ){ + return i; + } + } + return nSize < nMaxSize ? 
nSize : -1; +} + +void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { + int itemsToMove; + + assert( nMaxSize > 0 && nItemSize > 0 ); + assert( nSize <= nMaxSize ); + assert( 0 <= iInsert && iInsert <= nSize && iInsert < nMaxSize ); + + if( nSize == nMaxSize ){ + if( pLast != NULL ){ + memcpy(pLast, aBuffer + (nSize - 1) * nItemSize, nItemSize); + } + nSize--; + } + itemsToMove = nSize - iInsert; + memmove(aBuffer + (iInsert + 1) * nItemSize, aBuffer + iInsert * nItemSize, itemsToMove * nItemSize); + memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); +} + +void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { + int itemsToMove; + + assert( nItemSize > 0 ); + assert( 0 <= iDelete && iDelete < nSize ); + + itemsToMove = nSize - iDelete - 1; + memmove(aBuffer + iDelete * nItemSize, aBuffer + (iDelete + 1) * nItemSize, itemsToMove * nItemSize); +} + /************************************************************************** ** DiskANN internals **************************************************************************/ @@ -841,16 +893,21 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ sqlite3_free(pNode); } -static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, unsigned int maxCandidates, int blobMode){ - pCtx->pQuery = pQuery; +static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ + pCtx->pNodeQuery = pQuery; + pCtx->pEdgeQuery = pQuery; pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; pCtx->maxCandidates = maxCandidates; + pCtx->aTopDistances = sqlite3_malloc(topCandidates * sizeof(double)); + pCtx->aTopCandidates = sqlite3_malloc(topCandidates * sizeof(DiskAnnNode*)); + pCtx->nTopCandidates = 0; + pCtx->maxTopCandidates = topCandidates; pCtx->visitedList = 
NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL ){ + if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ goto out_oom; } return SQLITE_OK; @@ -861,6 +918,12 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, un if( pCtx->aCandidates != NULL ){ sqlite3_free(pCtx->aCandidates); } + if( pCtx->aTopDistances != NULL ){ + sqlite3_free(pCtx->aTopDistances); + } + if( pCtx->aTopCandidates != NULL ){ + sqlite3_free(pCtx->aTopCandidates); + } return SQLITE_NOMEM_BKPT; } @@ -884,6 +947,8 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){ } sqlite3_free(pCtx->aCandidates); sqlite3_free(pCtx->aDistances); + sqlite3_free(pCtx->aTopCandidates); + sqlite3_free(pCtx->aTopDistances); } // check if we visited this node earlier @@ -925,7 +990,9 @@ static int diskAnnSearchCtxShouldAddCandidate(const DiskAnnIndex *pIndex, const } // mark node as visited and put it in the head of visitedList -static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode){ +static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode, float distance){ + int iInsert; + assert( pCtx->nUnvisited > 0 ); assert( pNode->visited == 0 ); @@ -934,56 +1001,51 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo pNode->pNext = pCtx->visitedList; pCtx->visitedList = pNode; + + iInsert = distanceBufferInsertIdx(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, distance); + if( iInsert < 0 ){ + return; + } + bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); + bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } static int 
diskAnnSearchCtxHasUnvisited(const DiskAnnSearchCtx *pCtx){ return pCtx->nUnvisited > 0; } -static DiskAnnNode* diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i){ +static void diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i, DiskAnnNode **ppNode, float *pDistance){ assert( 0 <= i && i < pCtx->nCandidates ); - return pCtx->aCandidates[i]; + *ppNode = pCtx->aCandidates[i]; + *pDistance = pCtx->aDistances[i]; } static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete){ int i; - assert( 0 <= iDelete && iDelete < pCtx->nCandidates ); assert( pCtx->nUnvisited > 0 ); assert( !pCtx->aCandidates[iDelete]->visited ); assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); + bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); - for(i = iDelete + 1; i < pCtx->nCandidates; i++){ - pCtx->aCandidates[i - 1] = pCtx->aCandidates[i]; - pCtx->aDistances[i - 1] = pCtx->aDistances[i]; - } pCtx->nCandidates--; pCtx->nUnvisited--; } -static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float candidateDist){ - int i; - assert( 0 <= iInsert && iInsert <= pCtx->nCandidates && iInsert < pCtx->maxCandidates ); - if( pCtx->nCandidates < pCtx->maxCandidates ){ - pCtx->nCandidates++; - } else { - DiskAnnNode *pLast = pCtx->aCandidates[pCtx->nCandidates - 1]; - if( !pLast->visited ){ - // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node - assert( pLast->pBlobSpot == NULL ); - pCtx->nUnvisited--; - diskAnnNodeFree(pLast); - } - } - // Shift the candidates to the right to make space for the new one. 
- for(i = pCtx->nCandidates - 1; i > iInsert; i--){ - pCtx->aCandidates[i] = pCtx->aCandidates[i - 1]; - pCtx->aDistances[i] = pCtx->aDistances[i - 1]; +static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ + DiskAnnNode *pLast = NULL; + bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); + bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); + if( pLast != NULL && !pLast->visited ){ + // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node + assert( pLast->pBlobSpot == NULL ); + pCtx->nUnvisited--; + diskAnnNodeFree(pLast); } - // Insert the new candidate. - pCtx->aCandidates[iInsert] = pCandidate; - pCtx->aDistances[iInsert] = candidateDist; pCtx->nUnvisited++; } @@ -1131,7 +1193,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u } nodeBinVector(pIndex, start->pBlobSpot, &startVector); - startDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &startVector); + startDistance = diskAnnVectorDistance(pIndex, pCtx->pNodeQuery, &startVector); if( pCtx->blobMode == DISKANN_BLOB_READONLY ){ assert( start->pBlobSpot != NULL ); @@ -1148,8 +1210,9 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u Vector vCandidate; DiskAnnNode *pCandidate; BlobSpot *pCandidateBlob; + float distance; int iCandidate = diskAnnSearchCtxFindClosestCandidateIdx(pCtx); - pCandidate = diskAnnSearchCtxGetCandidate(pCtx, iCandidate); + diskAnnSearchCtxGetCandidate(pCtx, iCandidate, &pCandidate, &distance); rc = SQLITE_OK; if( pReusableBlobSpot != NULL ){ @@ -1177,13 +1240,18 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u goto out; } - diskAnnSearchCtxMarkVisited(pCtx, 
pCandidate); - nVisited += 1; DiskAnnTrace(("visiting candidate(%d): id=%lld\n", nVisited, pCandidate->nRowid)); nodeBinVector(pIndex, pCandidateBlob, &vCandidate); nEdges = nodeBinEdges(pIndex, pCandidateBlob); + // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it + if( pCtx->pNodeQuery != pCtx->pEdgeQuery ){ + distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->pNodeQuery); + } + + diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance); + for(i = 0; i < nEdges; i++){ u64 edgeRowid; Vector edgeVector; @@ -1195,7 +1263,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u continue; } - edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &edgeVector); + edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pEdgeQuery, &edgeVector); iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance); if( iInsert < 0 ){ continue; @@ -1272,7 +1340,7 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search"); return rc; } - rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, DISKANN_BLOB_READONLY); + rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context"); goto out; @@ -1281,7 +1349,7 @@ int diskAnnSearch( if( rc != SQLITE_OK ){ goto out; } - nOutRows = MIN(k, ctx.nCandidates); + nOutRows = MIN(k, ctx.nTopCandidates); rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, vectorIdxKeyRowidLike(pKey)); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to allocate output rows"); @@ -1289,9 +1357,9 @@ int diskAnnSearch( } for(i = 0; i < nOutRows; i++){ if( pRows->aIntValues != NULL ){ - rc = vectorOutRowsPut(pRows, i, 0, &ctx.aCandidates[i]->nRowid, NULL); + rc = vectorOutRowsPut(pRows, i, 0, &ctx.aTopCandidates[i]->nRowid, 
NULL); }else{ - rc = diskAnnGetShadowRowKeys(pIndex, ctx.aCandidates[i]->nRowid, pKey, pRows, i); + rc = diskAnnGetShadowRowKeys(pIndex, ctx.aTopCandidates[i]->nRowid, pKey, pRows, i); } if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to put result in the output row"); @@ -1327,7 +1395,7 @@ int diskAnnInsert( DiskAnnTrace(("diskAnnInset started\n")); - rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, DISKANN_BLOB_WRITABLE); + rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context"); return rc; From 1a9cab9163a53f66136c6ce5aab3fa3bb7c71ae9 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 18:27:39 +0400 Subject: [PATCH 041/121] 1bit quantized embeddings search: somehow working version --- libsql-sqlite3/src/vector.c | 32 +++++- libsql-sqlite3/src/vector1bit.c | 17 ++- libsql-sqlite3/src/vectorIndex.c | 8 +- libsql-sqlite3/src/vectorInt.h | 4 +- libsql-sqlite3/src/vectordiskann.c | 166 +++++++++++++++++++++-------- 5 files changed, 166 insertions(+), 61 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 9d37dbb4a8..6c3c8d83fe 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -74,10 +74,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){ ** Note that the vector object points to the blob so if ** you free the blob, the vector becomes invalid. 
**/ -void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){ - pVector->type = type; +void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){ pVector->flags = VECTOR_FLAGS_STATIC; - vectorInitFromBlob(pVector, pBlob, nBlobSize); + pVector->type = type; + pVector->dims = dims; + pVector->data = pBlob; } /* @@ -479,6 +480,31 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo } } +void vectorConvert(const Vector *pFrom, Vector *pTo){ + int i; + u8 *bitData; + float *floatData; + + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ + floatData = pFrom->data; + bitData = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + bitData[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( floatData[i] < 0 ){ + bitData[i / 8] &= ~(1 << (i & 7)); + }else{ + bitData[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert(0); + } +} + /************************************************************************** ** SQL function implementations ****************************************************************************/ diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index c8da5496d2..76fc964d7c 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -56,17 +56,16 @@ size_t vector1BitSerializeToBlob( unsigned char *pBlob, size_t nBlobSize ){ - float *elems = pVector->data; - unsigned char *pPtr = pBlob; - size_t len = 0; + u8 *elems = pVector->data; + u8 *pPtr = pBlob; unsigned i; assert( pVector->type == VECTOR_TYPE_1BIT ); assert( pVector->dims <= MAX_VECTOR_SZ ); assert( nBlobSize >= (pVector->dims + 7) / 8 ); - for(i = 0; i < pVector->dims; i++){ - elems[i] = pPtr[i]; + for(i = 0; i < (pVector->dims + 7) / 8; i++){ + pPtr[i] = elems[i]; } return (pVector->dims + 7) / 8; } @@ -92,7 +91,7 @@ static int BitsCount[256] = { }; int 
vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ - int sum = 0; + int diff = 0; u8 *e1 = v1->data; u8 *e2 = v2->data; int i; @@ -101,10 +100,10 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ assert( v1->type == VECTOR_TYPE_1BIT ); assert( v2->type == VECTOR_TYPE_1BIT ); - for(i = 0; i < v1->dims; i++){ - sum += BitsCount[e1[i]&e2[i]]; + for(i = 0; i < v1->dims; i += 8){ + diff += BitsCount[e1[i/8] ^ e2[i/8]]; } - return sum; + return diff; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 7ad42a00ba..5a13b4ea60 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -396,10 +396,10 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "compress_neighbors", VECTOR_METRIC_TYPE_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index d6f9f36635..84cf9c0d1f 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -102,11 +102,13 @@ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); void 
vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); -void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t); +void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); +void vectorConvert(const Vector *, Vector *); + /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 9b302a4dda..d9bb37ba35 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -80,9 +80,18 @@ #define VECTOR_NODE_METADATA_SIZE (sizeof(u64) + sizeof(u16)) #define VECTOR_EDGE_METADATA_SIZE (sizeof(u64) + sizeof(u64)) +typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +struct VectorPair { + int nodeType; + int edgeType; + Vector *pNode; + Vector *pEdge; +}; + // DiskAnnNode represents single node in the DiskAnn graph struct DiskAnnNode { u64 nRowid; /* node id */ @@ -98,8 +107,7 @@ struct DiskAnnNode { * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this) */ struct DiskAnnSearchCtx { - const Vector *pNodeQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ - const Vector *pEdgeQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ + VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */ DiskAnnNode **aCandidates; /* array of 
candidates ordered by distance to the query (ascending) */ float *aDistances; /* array of distances to the query vector */ unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ @@ -316,7 +324,7 @@ void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Ve void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector) { assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize <= pBlobSpot->nBufferSize ); - vectorInitStatic(pVector, pIndex->nNodeVectorType, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE, pIndex->nNodeVectorSize); + vectorInitStatic(pVector, pIndex->nNodeVectorType, pIndex->nVectorDims, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE); } u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { @@ -337,8 +345,8 @@ void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdg vectorInitStatic( pVector, pIndex->nEdgeVectorType, - pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nNodeVectorSize, - pIndex->nEdgeVectorSize + pIndex->nVectorDims, + pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize ); } } @@ -470,19 +478,6 @@ int diskAnnCreateIndex( } assert( 0 < dims && dims <= MAX_VECTOR_SZ ); - maxNeighborsParam = vectorIdxParamsGetU64(pParams, VECTOR_MAX_NEIGHBORS_PARAM_ID); - if( maxNeighborsParam == 0 ){ - // 3 D**(1/2) gives good recall values (90%+) - // we also want to keep disk overhead at moderate level - 50x of the disk size increase is the current upper bound - maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(type, dims)) + 1); - } - blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(type, dims)); - if( blockSizeBytes > DISKANN_MAX_BLOCK_SZ ){ - return SQLITE_ERROR; - } - if( 
vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){ - return SQLITE_ERROR; - } metric = vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID); if( metric == 0 ){ metric = VECTOR_METRIC_TYPE_COS; @@ -495,6 +490,23 @@ int diskAnnCreateIndex( *pzErrMsg = "1-bit compression available only for cosine metric"; return SQLITE_ERROR; } + if( neighbours == 0 ){ + neighbours = type; + } + + maxNeighborsParam = vectorIdxParamsGetU64(pParams, VECTOR_MAX_NEIGHBORS_PARAM_ID); + if( maxNeighborsParam == 0 ){ + // 3 D**(1/2) gives good recall values (90%+) + // we also want to keep disk overhead at moderate level - 50x of the disk size increase is the current upper bound + maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(neighbours, dims)) + 1); + } + blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(neighbours, dims)); + if( blockSizeBytes > DISKANN_MAX_BLOCK_SZ ){ + return SQLITE_ERROR; + } + if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){ + return SQLITE_ERROR; + } if( vectorIdxParamsGetF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID) == 0 ){ if( vectorIdxParamsPutF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID, VECTOR_PRUNING_ALPHA_DEFAULT) != 0 ){ @@ -814,6 +826,36 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){ ** Generic utilities **************************************************************************/ +int initVectorPair(int nodeType, int edgeType, int dims, VectorPair *pPair){ + pPair->nodeType = nodeType; + pPair->edgeType = edgeType; + pPair->pNode = NULL; + pPair->pEdge = NULL; + if( pPair->nodeType == pPair->edgeType ){ + return 0; + } + pPair->pEdge = vectorAlloc(edgeType, dims); + if( pPair->pEdge == NULL ){ + return SQLITE_NOMEM_BKPT; + } + return 0; +} + +void loadVectorPair(VectorPair *pPair, 
const Vector *pVector){ + pPair->pNode = (Vector*)pVector; + if( pPair->edgeType != pPair->nodeType ){ + vectorConvert(pPair->pNode, pPair->pEdge); + }else{ + pPair->pEdge = pPair->pNode; + } +} + +void deinitVectorPair(VectorPair *pPair) { + if( pPair->pEdge != NULL && pPair->pNode != pPair->pEdge ){ + vectorFree(pPair->pEdge); + } +} + int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){ int i; #ifdef SQLITE_DEBUG @@ -893,9 +935,7 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ sqlite3_free(pNode); } -static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ - pCtx->pNodeQuery = pQuery; - pCtx->pEdgeQuery = pQuery; +static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; @@ -907,6 +947,11 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, in pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + goto out_oom; + } + loadVectorPair(&pCtx->query, pQuery); + if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ goto out_oom; } @@ -949,6 +994,7 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){ sqlite3_free(pCtx->aDistances); sqlite3_free(pCtx->aTopCandidates); sqlite3_free(pCtx->aTopDistances); + deinitVectorPair(&pCtx->query); } // check if we visited this node earlier @@ -1075,7 +1121,13 @@ static int diskAnnSearchCtxFindClosestCandidateIdx(const DiskAnnSearchCtx *pCtx) // return position for new edge(C) which will replace previous edge on 
that position or -1 if we should ignore it // we also check that no current edge(B) will "prune" new vertex: i.e. dist(B, C) >= (means worse than) alpha * dist(node, C) for all current edges // if any edge(B) will "prune" new edge(C) we will ignore it (return -1) -static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, u64 newRowid, const Vector *pNewVector) { +static int diskAnnReplaceEdgeIdx( + const DiskAnnIndex *pIndex, + BlobSpot *pNodeBlob, + u64 newRowid, + VectorPair *pNewVector, + VectorPair *pPlaceholder +) { int i, nEdges, nMaxEdges, iReplace = -1; Vector nodeVector, edgeVector; float nodeToNew, nodeToReplace; @@ -1083,7 +1135,10 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob nEdges = nodeBinEdges(pIndex, pNodeBlob); nMaxEdges = nodeEdgesMaxCount(pIndex); nodeBinVector(pIndex, pNodeBlob, &nodeVector); - nodeToNew = diskAnnVectorDistance(pIndex, &nodeVector, pNewVector); + loadVectorPair(pPlaceholder, &nodeVector); + + // we need to evaluate potentially approximate distance here in order to correctly compare it with edge distances + nodeToNew = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, pNewVector->pEdge); for(i = nEdges - 1; i >= 0; i--){ u64 edgeRowid; @@ -1095,8 +1150,8 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob return i; } - edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector); - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); + edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector->pEdge); + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); if( nodeToNew > pIndex->pruningAlpha * edgeToNew ){ return -1; } @@ -1114,12 +1169,14 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob // prune edges after we inserted new edge at position iInserted // we only need to check for edges which will be pruned by new vertex // no need to check for 
other pairs as we checked them on previous insertions -static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted) { +static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted, VectorPair *pPlaceholder) { int i, s, nEdges; - Vector nodeVector, hintVector; + Vector nodeVector, hintEdgeVector; u64 hintRowid; nodeBinVector(pIndex, pNodeBlob, &nodeVector); + loadVectorPair(pPlaceholder, &nodeVector); + nEdges = nodeBinEdges(pIndex, pNodeBlob); assert( 0 <= iInserted && iInserted < nEdges ); @@ -1129,7 +1186,7 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i nodeBinDebug(pIndex, pNodeBlob); #endif - nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintVector); + nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintEdgeVector); // remove edges which is no longer interesting due to the addition of iInserted i = 0; @@ -1143,8 +1200,8 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i i++; continue; } - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); - hintToEdge = diskAnnVectorDistance(pIndex, &hintVector, &edgeVector); + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + hintToEdge = diskAnnVectorDistance(pIndex, &hintEdgeVector, &edgeVector); if( nodeToEdge > pIndex->pruningAlpha * hintToEdge ){ nodeBinDeleteEdge(pIndex, pNodeBlob, i); nEdges--; @@ -1193,7 +1250,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u } nodeBinVector(pIndex, start->pBlobSpot, &startVector); - startDistance = diskAnnVectorDistance(pIndex, pCtx->pNodeQuery, &startVector); + startDistance = diskAnnVectorDistance(pIndex, pCtx->query.pNode, &startVector); if( pCtx->blobMode == DISKANN_BLOB_READONLY ){ assert( start->pBlobSpot != NULL ); @@ -1246,8 +1303,8 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u nEdges = nodeBinEdges(pIndex, 
pCandidateBlob); // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it - if( pCtx->pNodeQuery != pCtx->pEdgeQuery ){ - distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->pNodeQuery); + if( pCtx->query.pNode != pCtx->query.pEdge ){ + distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->query.pNode); } diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance); @@ -1263,7 +1320,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u continue; } - edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pEdgeQuery, &edgeVector); + edgeDistance = diskAnnVectorDistance(pIndex, pCtx->query.pEdge, &edgeVector); iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance); if( iInsert < 0 ){ continue; @@ -1340,7 +1397,7 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search"); return rc; } - rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context"); goto out; @@ -1383,6 +1440,9 @@ int diskAnnInsert( BlobSpot *pBlobSpot = NULL; DiskAnnNode *pVisited; DiskAnnSearchCtx ctx; + VectorPair vInsert, vCandidate; + vInsert.pNode = NULL; vInsert.pEdge = NULL; + vCandidate.pNode = NULL; vCandidate.pEdge = NULL; if( pVectorInRow->pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); @@ -1395,12 +1455,24 @@ int diskAnnInsert( DiskAnnTrace(("diskAnnInset started\n")); - rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); if( rc 
!= SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context"); return rc; } + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vInsert) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for node VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vCandidate) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for candidate VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + // note: we must select random row before we will insert new row in the shadow table rc = diskAnnSelectRandomShadowRow(pIndex, &nStartRowid); if( rc == SQLITE_DONE ){ @@ -1438,28 +1510,31 @@ int diskAnnInsert( } // first pass - add all visited nodes as a potential neighbours of new node for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ - Vector vector; + Vector nodeVector; int iReplace; - nodeBinVector(pIndex, pVisited->pBlobSpot, &vector); - iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vector); + nodeBinVector(pIndex, pVisited->pBlobSpot, &nodeVector); + loadVectorPair(&vCandidate, &nodeVector); + + iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, &vector); - diskAnnPruneEdges(pIndex, pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, vCandidate.pEdge); + diskAnnPruneEdges(pIndex, pBlobSpot, iReplace, &vInsert); } // second pass - add new node as a potential neighbour of all visited nodes + loadVectorPair(&vInsert, pVectorInRow->pVector); for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ int iReplace; - iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, 
pVectorInRow->pVector); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, pVectorInRow->pVector); - diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, vInsert.pEdge); + diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace, &vCandidate); rc = blobSpotFlush(pIndex, pVisited->pBlobSpot); if( rc != SQLITE_OK ){ @@ -1470,6 +1545,8 @@ int diskAnnInsert( rc = SQLITE_OK; out: + deinitVectorPair(&vInsert); + deinitVectorPair(&vCandidate); if( rc == SQLITE_OK ){ rc = blobSpotFlush(pIndex, pBlobSpot); if( rc != SQLITE_OK ){ @@ -1628,6 +1705,7 @@ int diskAnnOpenIndex( } *ppIndex = pIndex; + DiskAnnTrace(("opened index %s: max edges %d\n", zIdxName, nodeEdgesMaxCount(pIndex))); return SQLITE_OK; } From cd9cea34ec7944880bd4bbe3a54f39fbc5f3cc07 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 19:17:48 +0400 Subject: [PATCH 042/121] move intrinsics to the utils --- libsql-sqlite3/src/sqliteInt.h | 2 ++ libsql-sqlite3/src/util.c | 37 ++++++++++++++++++++++++++++ libsql-sqlite3/src/vector1bit.c | 39 +++++++++++------------------- libsql-sqlite3/src/vectordiskann.c | 2 +- 4 files changed, 54 insertions(+), 26 deletions(-) diff --git a/libsql-sqlite3/src/sqliteInt.h b/libsql-sqlite3/src/sqliteInt.h index e2fd32d3c4..891cf79fe4 100644 --- a/libsql-sqlite3/src/sqliteInt.h +++ b/libsql-sqlite3/src/sqliteInt.h @@ -5321,6 +5321,8 @@ int sqlite3AddInt64(i64*,i64); int sqlite3SubInt64(i64*,i64); int sqlite3MulInt64(i64*,i64); int sqlite3AbsInt32(int); +int sqlite3PopCount32(u32); +int sqlite3PopCount64(u64); #ifdef SQLITE_ENABLE_8_3_NAMES void sqlite3FileSuffix3(const char*, char*); #else diff --git a/libsql-sqlite3/src/util.c b/libsql-sqlite3/src/util.c index 207b901bad..0bce77e65a 100644 --- a/libsql-sqlite3/src/util.c +++ 
b/libsql-sqlite3/src/util.c @@ -1542,6 +1542,43 @@ int sqlite3SafetyCheckSickOrOk(sqlite3 *db){ } } + +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + +int sqlite3PopCount64(u64 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcountll(a); +#else + return sqlite3PopCount32(a >> 32) + sqlite3PopCount32(a & 0xffffffff); +#endif +} + /* ** Attempt to add, subtract, or multiply the 64-bit signed value iB against ** the other 64-bit signed integer at *pA and store the result in *pA. 
diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index 76fc964d7c..89b0d09d90 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -70,38 +70,27 @@ size_t vector1BitSerializeToBlob( return (pVector->dims + 7) / 8; } -// [sum(map(int, bin(i)[2:])) for i in range(256)] -static int BitsCount[256] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, -}; - int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ int diff = 0; - u8 *e1 = v1->data; - u8 *e2 = v2->data; - int i; + u8 *e1U8 = v1->data; + u32 *e1U32 = v1->data; + u8 *e2U8 = v2->data; + u32 *e2U32 = v2->data; + int i, len8, len32, offset8; assert( v1->dims == v2->dims ); assert( v1->type == VECTOR_TYPE_1BIT ); assert( v2->type == VECTOR_TYPE_1BIT ); - for(i = 0; i < v1->dims; i += 8){ - diff += BitsCount[e1[i/8] ^ e2[i/8]]; + len8 = (v1->dims + 7) / 8; + len32 = v1->dims / 32; + offset8 = len32 * 4; + + for(i = 0; i < len32; i++){ + diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); + } + for(i = offset8; i < len8; i++){ + diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); } return diff; } diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index d9bb37ba35..c29b16e35d 100644 
--- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -54,7 +54,7 @@ #include "sqliteInt.h" #include "vectorIndexInt.h" -#define SQLITE_VECTOR_TRACE +// #define SQLITE_VECTOR_TRACE #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) #define DiskAnnTrace(X) sqlite3DebugPrintf X; #else From eac5d90807181de75d7e8f8c0e5d3b26b0633d52 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:15:39 +0400 Subject: [PATCH 043/121] move utils back to the vector1bit to simplify inlining for compiler --- libsql-sqlite3/src/sqliteInt.h | 2 -- libsql-sqlite3/src/util.c | 37 --------------------------------- libsql-sqlite3/src/vector.c | 4 +--- libsql-sqlite3/src/vector1bit.c | 28 +++++++++++++++++++++++++ 4 files changed, 29 insertions(+), 42 deletions(-) diff --git a/libsql-sqlite3/src/sqliteInt.h b/libsql-sqlite3/src/sqliteInt.h index 891cf79fe4..e2fd32d3c4 100644 --- a/libsql-sqlite3/src/sqliteInt.h +++ b/libsql-sqlite3/src/sqliteInt.h @@ -5321,8 +5321,6 @@ int sqlite3AddInt64(i64*,i64); int sqlite3SubInt64(i64*,i64); int sqlite3MulInt64(i64*,i64); int sqlite3AbsInt32(int); -int sqlite3PopCount32(u32); -int sqlite3PopCount64(u64); #ifdef SQLITE_ENABLE_8_3_NAMES void sqlite3FileSuffix3(const char*, char*); #else diff --git a/libsql-sqlite3/src/util.c b/libsql-sqlite3/src/util.c index 0bce77e65a..207b901bad 100644 --- a/libsql-sqlite3/src/util.c +++ b/libsql-sqlite3/src/util.c @@ -1542,43 +1542,6 @@ int sqlite3SafetyCheckSickOrOk(sqlite3 *db){ } } - -// [sum(map(int, bin(i)[2:])) for i in range(256)] -static int BitsCount[256] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 1, 
2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, -}; - -int sqlite3PopCount32(u32 a){ -#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) - return __builtin_popcount(a); -#else - return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; -#endif -} - -int sqlite3PopCount64(u64 a){ -#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) - return __builtin_popcountll(a); -#else - return sqlite3PopCount32(a >> 32) + sqlite3PopCount32(a & 0xffffffff); -#endif -} - /* ** Attempt to add, subtract, or multiply the 64-bit signed value iB against ** the other 64-bit signed integer at *pA and store the result in *pA. 
diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 6c3c8d83fe..73b84b047e 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -494,9 +494,7 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){ bitData[i / 8] = 0; } for(i = 0; i < pFrom->dims; i++){ - if( floatData[i] < 0 ){ - bitData[i / 8] &= ~(1 << (i & 7)); - }else{ + if( floatData[i] > 0 ){ bitData[i / 8] |= (1 << (i & 7)); } } diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index 89b0d09d90..66da59f76a 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -70,6 +70,34 @@ size_t vector1BitSerializeToBlob( return (pVector->dims + 7) / 8; } +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +static inline int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ int diff = 0; u8 *e1U8 = v1->data; From 
0ec0147979434c4712a2f6c225bec7335f6af973 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:18:34 +0400 Subject: [PATCH 044/121] extend binary format and store distance to edges in node blocks --- libsql-sqlite3/src/vectorIndexInt.h | 10 +++-- libsql-sqlite3/src/vectordiskann.c | 62 +++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/libsql-sqlite3/src/vectorIndexInt.h b/libsql-sqlite3/src/vectorIndexInt.h index bb6c085e94..e65df4d515 100644 --- a/libsql-sqlite3/src/vectorIndexInt.h +++ b/libsql-sqlite3/src/vectorIndexInt.h @@ -73,10 +73,10 @@ int nodeEdgesMetadataOffset(const DiskAnnIndex *pIndex); void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Vector *pVector); void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector); u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector); +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *distance, Vector *pVector); int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u64 nRowid); void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPruned); -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector); +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector); void nodeBinDeleteEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iDelete); void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); @@ -102,9 +102,11 @@ typedef u8 MetricType; /* format version which can help to upgrade vector on-disk format without breaking older version of the db */ #define VECTOR_FORMAT_PARAM_ID 1 /* - * 1 - initial version + * 1 - v1 
version; node block format: [node meta] [node vector] [edge vectors] ... [ [u64 unused ] [u64 edge rowid] ] ... + * 2 - v2 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u32 unused] [f32 distance] [u64 edge rowid] ] ... */ -#define VECTOR_FORMAT_DEFAULT 1 +#define VECTOR_FORMAT_V1 1 +#define VECTOR_FORMAT_DEFAULT 2 /* type of the vector index */ #define VECTOR_INDEX_TYPE_PARAM_ID 2 diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index c29b16e35d..0853120797 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -129,6 +129,10 @@ static inline u16 readLE16(const unsigned char *p){ return (u16)p[0] | (u16)p[1] << 8; } +static inline u32 readLE32(const unsigned char *p){ + return (u32)p[0] | (u32)p[1] << 8 | (u32)p[2] << 16 | (u32)p[3] << 24; +} + static inline u64 readLE64(const unsigned char *p){ return (u64)p[0] | (u64)p[1] << 8 @@ -145,6 +149,13 @@ static inline void writeLE16(unsigned char *p, u16 v){ p[1] = v >> 8; } +static inline void writeLE32(unsigned char *p, u32 v){ + p[0] = v; + p[1] = v >> 8; + p[2] = v >> 16; + p[3] = v >> 24; +} + static inline void writeLE64(unsigned char *p, u64 v){ p[0] = v; p[1] = v >> 8; @@ -333,13 +344,18 @@ u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { return readLE16(pBlobSpot->pBuffer + sizeof(u64)); } -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector) { +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *pDistance, Vector *pVector) { + u32 distance; int offset = nodeEdgesMetadataOffset(pIndex); if( pRowid != NULL ){ assert( offset + (iEdge + 1) * VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); *pRowid = readLE64(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u64)); } + if( pIndex->nFormatVersion != VECTOR_FORMAT_V1 && pDistance != NULL ){ + distance = 
readLE32(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u32)); + *pDistance = *((float*)&distance); + } if( pVector != NULL ){ assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize < offset ); vectorInitStatic( @@ -356,7 +372,7 @@ int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u6 // todo: if edges will be sorted by identifiers we can use binary search here (although speed up will be visible only on pretty loaded nodes: >128 edges) for(i = 0; i < nEdges; i++){ u64 edgeId; - nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL); + nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL, NULL); if( edgeId == nRowid ){ return i; } @@ -371,7 +387,7 @@ void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPru } // replace edge at position iReplace or add new one if iReplace == nEdges -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector) { +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector) { int nMaxEdges = nodeEdgesMaxCount(pIndex); int nEdges = nodeBinEdges(pIndex, pBlobSpot); int edgeVectorOffset, edgeMetaOffset, itemsToMove; @@ -390,6 +406,7 @@ void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iRe assert( edgeMetaOffset + VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); vectorSerializeToBlob(pVector, pBlobSpot->pBuffer + edgeVectorOffset, pIndex->nEdgeVectorSize); + writeLE32(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u32), *((u32*)&distance)); writeLE64(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u64), nRowid); writeLE16(pBlobSpot->pBuffer + sizeof(u64), nEdges); @@ -424,6 +441,7 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) int nEdges, nMaxEdges, i; u64 nRowid; + float distance = 0; Vector 
vector; nEdges = nodeBinEdges(pIndex, pBlobSpot); @@ -434,8 +452,8 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { DiskAnnTrace((" nEdges=%d, nMaxEdges=%d, vector=", nEdges, nMaxEdges)); vectorDump(&vector); for(i = 0; i < nEdges; i++){ - nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &vector); - DiskAnnTrace((" to=%lld, vector=", nRowid, nRowid)); + nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &distance, &vector); + DiskAnnTrace((" to=%lld, distance=%f, vector=", nRowid, distance)); vectorDump(&vector); } #endif @@ -1126,7 +1144,8 @@ static int diskAnnReplaceEdgeIdx( BlobSpot *pNodeBlob, u64 newRowid, VectorPair *pNewVector, - VectorPair *pPlaceholder + VectorPair *pPlaceholder, + float *pNodeToNew ) { int i, nEdges, nMaxEdges, iReplace = -1; Vector nodeVector, edgeVector; @@ -1139,19 +1158,23 @@ static int diskAnnReplaceEdgeIdx( // we need to evaluate potentially approximate distance here in order to correctly compare it with edge distances nodeToNew = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, pNewVector->pEdge); + *pNodeToNew = nodeToNew; for(i = nEdges - 1; i >= 0; i--){ u64 edgeRowid; float edgeToNew, nodeToEdge; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( edgeRowid == newRowid ){ // deletes can leave "zombie" edges in the graph and we must override them and not store duplicate edges in the node return i; } + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector->pEdge); - nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); if( nodeToNew > pIndex->pruningAlpha * edgeToNew ){ return -1; } @@ -1186,7 +1209,7 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i nodeBinDebug(pIndex, pNodeBlob); #endif - nodeBinEdge(pIndex, pNodeBlob, 
iInserted, &hintRowid, &hintEdgeVector); + nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, NULL, &hintEdgeVector); // remove edges which is no longer interesting due to the addition of iInserted i = 0; @@ -1194,13 +1217,16 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i Vector edgeVector; float nodeToEdge, hintToEdge; u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( hintRowid == edgeRowid ){ i++; continue; } - nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + hintToEdge = diskAnnVectorDistance(pIndex, &hintEdgeVector, &edgeVector); if( nodeToEdge > pIndex->pruningAlpha * hintToEdge ){ nodeBinDeleteEdge(pIndex, pNodeBlob, i); @@ -1315,7 +1341,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u float edgeDistance; int iInsert; DiskAnnNode *pNewCandidate; - nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, NULL, &edgeVector); if( diskAnnSearchCtxIsVisited(pCtx, edgeRowid) || diskAnnSearchCtxHasCandidate(pCtx, edgeRowid) ){ continue; } @@ -1512,15 +1538,16 @@ int diskAnnInsert( for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ Vector nodeVector; int iReplace; + float nodeToNew; nodeBinVector(pIndex, pVisited->pBlobSpot, &nodeVector); loadVectorPair(&vCandidate, &nodeVector); - iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, vCandidate.pEdge); + nodeBinReplaceEdge(pIndex, 
pBlobSpot, iReplace, pVisited->nRowid, nodeToNew, vCandidate.pEdge); diskAnnPruneEdges(pIndex, pBlobSpot, iReplace, &vInsert); } @@ -1528,12 +1555,13 @@ int diskAnnInsert( loadVectorPair(&vInsert, pVectorInRow->pVector); for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ int iReplace; + float nodeToNew; - iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, vInsert.pEdge); + nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, nodeToNew, vInsert.pEdge); diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace, &vCandidate); rc = blobSpotFlush(pIndex, pVisited->pBlobSpot); @@ -1598,7 +1626,7 @@ int diskAnnDelete( nNeighbours = nodeBinEdges(pIndex, pNodeBlob); for(i = 0; i < nNeighbours; i++){ u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL, NULL); rc = blobSpotReload(pIndex, pEdgeBlob, edgeRowid, pIndex->nBlockSize); if( rc == DISKANN_ROW_NOT_FOUND ){ continue; From 657ce07ee40e62a50945d2e927d113c993908b19 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:38:18 +0400 Subject: [PATCH 045/121] fix comment a little bit --- libsql-sqlite3/src/vectordiskann.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 0853120797..88151bf1df 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -107,9 +107,9 @@ struct DiskAnnNode { * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this) */ struct DiskAnnSearchCtx { - VectorPair query; /* initial query vector; user query for SELECT and row vector 
for INSERT */ - DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */ - float *aDistances; /* array of distances to the query vector */ + VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */ + DiskAnnNode **aCandidates; /* array of unvisited candidates ordered by distance (possibly approximate) to the query (ascending) */ + float *aDistances; /* array of distances (possibly approximate) to the query vector */ unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */ From 11e6a9326449af84a3e940d89753ec349e8f2cd9 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 14:24:38 +0400 Subject: [PATCH 046/121] add simple test --- libsql-sqlite3/test/libsql_vector_index.test | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 7308b2d93f..8054dfaf0b 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -268,9 +268,19 @@ do_execsql_test vector-transaction { SELECT * FROM vector_top_k('t_transaction_idx', vector('[1,2]'), 2); } {3 4 1 2} +do_execsql_test vector-1bit { + CREATE TABLE t_1bit( v FLOAT32(3) ); + CREATE INDEX t_1bit_idx ON t_1bit( libsql_vector_idx(v, 'compress_neighbors=1bit') ); + INSERT INTO t_1bit VALUES (vector('[-1,-1,1]')); + INSERT INTO t_1bit VALUES (vector('[-1,1,-1.5]')); + INSERT INTO t_1bit VALUES (vector('[1,-1,-1]')); + INSERT INTO t_1bit VALUES (vector('[-1,-1,-1]')); + SELECT rowid FROM vector_top_k('t_1bit_idx', vector('[1,-1,-1]'), 4); +} {3 4 2 1} + do_execsql_test vector-all-params { CREATE TABLE t_all_params ( emb FLOAT32(2) ); - CREATE INDEX t_all_params_idx ON
t_all_params(libsql_vector_idx(emb, 'type=diskann', 'metric=cos', 'alpha=1.2', 'search_l=200', 'insert_l=70', 'max_neighbors=6')); + CREATE INDEX t_all_params_idx ON t_all_params(libsql_vector_idx(emb, 'type=diskann', 'metric=cos', 'alpha=1.2', 'search_l=200', 'insert_l=70', 'max_neighbors=6', 'compress_neighbors=1bit')); INSERT INTO t_all_params VALUES (vector('[1,2]')), (vector('[3,4]')); SELECT * FROM vector_top_k('t_all_params_idx', vector('[1,2]'), 2); } {1 2} From 93caa2736702714e1df07152e550ca52202feadc Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 14:55:51 +0400 Subject: [PATCH 047/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 836 ++++++++++++------ libsql-ffi/bundled/src/sqlite3.c | 836 ++++++++++++------ 2 files changed, 1144 insertions(+), 528 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index d7587cc38b..eff02cf87a 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -85246,6 +85246,7 @@ typedef u32 VectorDims; */ #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_1BIT 3 #define VECTOR_FLAGS_STATIC 1 @@ -85270,8 +85271,9 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); * Dumps vector on the console (used only for debugging) */ void vectorDump (const Vector *v); -void vectorF32Dump(const Vector *v); -void vectorF64Dump(const Vector *v); +void vectorF32Dump (const Vector *v); +void vectorF64Dump (const Vector *v); +void vector1BitDump(const Vector *v); /* * Converts vector to the text representation and write the result to the sqlite3_context @@ -85283,16 +85285,10 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); 
-size_t vectorF32SerializeToBlob(const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob(const Vector *, unsigned char *, size_t); - -/* - * Deserializes vector from the blob in little-endian format according to the IEEE-754 standard -*/ -size_t vectorDeserializeFromBlob (Vector *, const unsigned char *, size_t); -size_t vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -size_t vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); +size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) @@ -85301,6 +85297,11 @@ float vectorDistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); +/* + * Calculates hamming distance between two 1-bit vectors (vector must have same dimensions) +*/ +int vector1BitDistanceHamming(const Vector *, const Vector *); + /* * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ @@ -85313,22 +85314,23 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. 
This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerialize (sqlite3_context *, const Vector *); -void vectorF32Serialize(sqlite3_context *, const Vector *); -void vectorF64Serialize(sqlite3_context *, const Vector *); +void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); -int vectorF32ParseSqliteBlob(sqlite3_value *, Vector *, char **); -int vectorF64ParseSqliteBlob(sqlite3_value *, Vector *, char **); -void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t); +void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); + +void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); +void vectorConvert(const Vector *, Vector *); + /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); @@ -85410,10 +85412,10 @@ int nodeEdgesMetadataOffset(const DiskAnnIndex *pIndex); void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Vector *pVector); void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector); u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector); +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *distance, Vector *pVector); int 
nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u64 nRowid); void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPruned); -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector); +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector); void nodeBinDeleteEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iDelete); void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); @@ -85437,43 +85439,47 @@ typedef u8 MetricType; */ /* format version which can help to upgrade vector on-disk format without breaking older version of the db */ -#define VECTOR_FORMAT_PARAM_ID 1 +#define VECTOR_FORMAT_PARAM_ID 1 /* - * 1 - initial version + * 1 - v1 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u64 unused ] [u64 edge rowid] ] ... + * 2 - v2 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u32 unused] [f32 distance] [u64 edge rowid] ] ... 
*/ -#define VECTOR_FORMAT_DEFAULT 1 +#define VECTOR_FORMAT_V1 1 +#define VECTOR_FORMAT_DEFAULT 2 /* type of the vector index */ -#define VECTOR_INDEX_TYPE_PARAM_ID 2 -#define VECTOR_INDEX_TYPE_DISKANN 1 +#define VECTOR_INDEX_TYPE_PARAM_ID 2 +#define VECTOR_INDEX_TYPE_DISKANN 1 /* type of the underlying vector for the vector index */ -#define VECTOR_TYPE_PARAM_ID 3 +#define VECTOR_TYPE_PARAM_ID 3 /* dimension of the underlying vector for the vector index */ -#define VECTOR_DIM_PARAM_ID 4 +#define VECTOR_DIM_PARAM_ID 4 /* metric type used for comparing two vectors */ -#define VECTOR_METRIC_TYPE_PARAM_ID 5 -#define VECTOR_METRIC_TYPE_COS 1 -#define VECTOR_METRIC_TYPE_L2 2 +#define VECTOR_METRIC_TYPE_PARAM_ID 5 +#define VECTOR_METRIC_TYPE_COS 1 +#define VECTOR_METRIC_TYPE_L2 2 /* block size */ -#define VECTOR_BLOCK_SIZE_PARAM_ID 6 -#define VECTOR_BLOCK_SIZE_DEFAULT 128 +#define VECTOR_BLOCK_SIZE_PARAM_ID 6 +#define VECTOR_BLOCK_SIZE_DEFAULT 128 + +#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 +#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 -#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 -#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 +#define VECTOR_INSERT_L_PARAM_ID 8 +#define VECTOR_INSERT_L_DEFAULT 70 -#define VECTOR_INSERT_L_PARAM_ID 8 -#define VECTOR_INSERT_L_DEFAULT 70 +#define VECTOR_SEARCH_L_PARAM_ID 9 +#define VECTOR_SEARCH_L_DEFAULT 200 -#define VECTOR_SEARCH_L_PARAM_ID 9 -#define VECTOR_SEARCH_L_DEFAULT 200 +#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 -#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 +#define VECTOR_COMPRESS_NEIGHBORS_PARAM_ID 11 /* total amount of vector index parameters */ -#define VECTOR_PARAM_IDS_COUNT 9 +#define VECTOR_PARAM_IDS_COUNT 11 /* * Vector index parameters are stored in simple binary format (1 byte tag + 8 byte u64 integer / f64 float) @@ -85555,7 +85561,7 @@ int vectorOutRowsPut(VectorOutRows *, int, int, const u64 *, sqlite3_value *); void vectorOutRowsGet(sqlite3_context *, const VectorOutRows *, int, int); void vectorOutRowsFree(sqlite3 *, VectorOutRows *); 
-int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *); +int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *, const char **); int diskAnnClearIndex(sqlite3 *, const char *, const char *); int diskAnnDropIndex(sqlite3 *, const char *, const char *); int diskAnnOpenIndex(sqlite3 *, const char *, const char *, const VectorIdxParams *, DiskAnnIndex **); @@ -210981,6 +210987,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); + case VECTOR_TYPE_1BIT: + return (dims + 7) / 8; default: assert(0); } @@ -211012,10 +211020,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){ ** Note that the vector object points to the blob so if ** you free the blob, the vector becomes invalid. **/ -void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){ - pVector->type = type; +void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){ pVector->flags = VECTOR_FLAGS_STATIC; - vectorInitFromBlob(pVector, pBlob, nBlobSize); + pVector->type = type; + pVector->dims = dims; + pVector->data = pBlob; } /* @@ -211051,6 +211060,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceCos(pVector1, pVector2); + case VECTOR_TYPE_1BIT: + return vector1BitDistanceHamming(pVector1, pVector2); default: assert(0); } @@ -211192,11 +211203,29 @@ int vectorParseSqliteBlob( Vector *pVector, char **pzErrMsg ){ + const unsigned char *pBlob; + size_t nBlobSize; + + assert( sqlite3_value_type(arg) == SQLITE_BLOB ); + + pBlob = sqlite3_value_blob(arg); + nBlobSize = sqlite3_value_bytes(arg); + if( nBlobSize % 2 == 1 ){ + nBlobSize--; + } + + if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){ + *pzErrMsg = 
sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize); + return SQLITE_ERROR; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - return vectorF32ParseSqliteBlob(arg, pVector, pzErrMsg); + vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); + return 0; case VECTOR_TYPE_FLOAT64: - return vectorF64ParseSqliteBlob(arg, pVector, pzErrMsg); + vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + return 0; default: assert(0); } @@ -211303,6 +211332,9 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT64: vectorF64Dump(pVector); break; + case VECTOR_TYPE_1BIT: + vector1BitDump(pVector); + break; default: assert(0); } @@ -211324,20 +211356,47 @@ void vectorMarshalToText( } } -void vectorSerialize( +void vectorSerializeWithType( sqlite3_context *context, const Vector *pVector ){ + unsigned char *pBlob; + size_t nBlobSize, nDataSize; + + assert( pVector->dims <= MAX_VECTOR_SZ ); + + nDataSize = vectorDataSize(pVector->type, pVector->dims); + nBlobSize = nDataSize; + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + nBlobSize += (nBlobSize % 2 == 0 ? 
1 : 2); + } + + if( nBlobSize == 0 ){ + sqlite3_result_zeroblob(context, 0); + return; + } + + pBlob = sqlite3_malloc64(nBlobSize); + if( pBlob == NULL ){ + sqlite3_result_error_nomem(context); + return; + } + + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + pBlob[nBlobSize - 1] = pVector->type; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - vectorF32Serialize(context, pVector); + vectorF32SerializeToBlob(pVector, pBlob, nDataSize); break; case VECTOR_TYPE_FLOAT64: - vectorF64Serialize(context, pVector); + vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; default: assert(0); } + sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ @@ -211346,18 +211405,8 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); - default: - assert(0); - } - return 0; -} - -size_t vectorDeserializeFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - return vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); - case VECTOR_TYPE_FLOAT64: - return vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + case VECTOR_TYPE_1BIT: + return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); } @@ -211377,6 +211426,29 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo } } +void vectorConvert(const Vector *pFrom, Vector *pTo){ + int i; + u8 *bitData; + float *floatData; + + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ + floatData = pFrom->data; + bitData = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + bitData[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( floatData[i] > 0 ){ 
+ bitData[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert(0); + } +} + /************************************************************************** ** SQL function implementations ****************************************************************************/ @@ -211410,7 +211482,7 @@ static void vectorFuncHintedType( sqlite3_free(pzErrMsg); goto out_free_vec; } - vectorSerialize(context, pVector); + vectorSerializeWithType(context, pVector); out_free_vec: vectorFree(pVector); } @@ -211557,6 +211629,135 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vector.c **********************************************/ +/************** Begin file vector1bit.c **************************************/ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+** +****************************************************************************** +** +** 1-bit vector format utilities. +*/ +#ifndef SQLITE_OMIT_VECTOR +/* #include "sqliteInt.h" */ + +/* #include "vectorInt.h" */ + +/* #include */ + +/************************************************************************** +** Utility routines for debugging +**************************************************************************/ + +void vector1BitDump(const Vector *pVec){ + u8 *elems = pVec->data; + unsigned i; + + assert( pVec->type == VECTOR_TYPE_1BIT ); + + for(i = 0; i < pVec->dims; i++){ + printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + } + printf("\n"); +} + +/************************************************************************** +** Utility routines for vector serialization and deserialization +**************************************************************************/ + +size_t vector1BitSerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + u8 *pPtr = pBlob; + unsigned i; + + assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + for(i = 0; i < (pVector->dims + 7) / 8; i++){ + pPtr[i] = elems[i]; + } + return (pVector->dims + 7) / 8; +} + +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 
5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +static inline int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + +int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ + int diff = 0; + u8 *e1U8 = v1->data; + u32 *e1U32 = v1->data; + u8 *e2U8 = v2->data; + u32 *e2U32 = v2->data; + int i, len8, len32, offset8; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_1BIT ); + assert( v2->type == VECTOR_TYPE_1BIT ); + + len8 = (v1->dims + 7) / 8; + len32 = v1->dims / 32; + offset8 = len32 * 4; + + for(i = 0; i < len32; i++){ + diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); + } + for(i = offset8; i < len8; i++){ + diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); + } + return diff; +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ + +/************** End of vector1bit.c ******************************************/ /************** Begin file vectordiskann.c ***********************************/ /* ** 2024-03-23 @@ -211607,13 +211808,14 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ ** diskAnnInsert() Insert single new(!) 
vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ +/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "math.h" */ /* #include "sqliteInt.h" */ /* #include "vectorIndexInt.h" */ -#define SQLITE_VECTOR_TRACE +// #define SQLITE_VECTOR_TRACE #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) #define DiskAnnTrace(X) sqlite3DebugPrintf X; #else @@ -211639,9 +211841,18 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #define VECTOR_NODE_METADATA_SIZE (sizeof(u64) + sizeof(u16)) #define VECTOR_EDGE_METADATA_SIZE (sizeof(u64) + sizeof(u64)) +typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +struct VectorPair { + int nodeType; + int edgeType; + Vector *pNode; + Vector *pEdge; +}; + // DiskAnnNode represents single node in the DiskAnn graph struct DiskAnnNode { u64 nRowid; /* node id */ @@ -211657,14 +211868,18 @@ struct DiskAnnNode { * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this) */ struct DiskAnnSearchCtx { - const Vector *pQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ - DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */ - double *aDistances; /* array of distances to the query vector */ - unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ - unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ - DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */ - unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */ - int blobMode; /* DISKANN_BLOB_READONLY if 
we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */ + VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */ + DiskAnnNode **aCandidates; /* array of unvisited candidates ordered by distance (possibly approximate) to the query (ascending) */ + float *aDistances; /* array of distances (possible approximate) to the query vector */ + unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ + unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ + DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */ + float *aTopDistances; /* top candidates exact distances */ + int nTopCandidates; /* current size of aTopCandidates/aTopDistances arrays */ + int maxTopCandidates; /* max size of aTopCandidates/aTopDistances arrays */ + DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */ + unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */ + int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */ }; /************************************************************************** @@ -211675,6 +211890,10 @@ static inline u16 readLE16(const unsigned char *p){ return (u16)p[0] | (u16)p[1] << 8; } +static inline u32 readLE32(const unsigned char *p){ + return (u32)p[0] | (u32)p[1] << 8 | (u32)p[2] << 16 | (u32)p[3] << 24; +} + static inline u64 readLE64(const unsigned char *p){ return (u64)p[0] | (u64)p[1] << 8 @@ -211691,6 +211910,13 @@ static inline void writeLE16(unsigned char *p, u16 v){ p[1] = v >> 8; } +static inline void writeLE32(unsigned char *p, u32 v){ + p[0] = v; + p[1] = v >> 8; + p[2] = v >> 16; + p[3] = v >> 24; +} + static inline void writeLE64(unsigned char *p, u64 v){ p[0] = v; p[1] = v >> 8; @@ -211870,7 +212096,7 @@ void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot 
*pBlobSpot, u64 nRowid, Ve void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector) { assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize <= pBlobSpot->nBufferSize ); - vectorInitStatic(pVector, pIndex->nNodeVectorType, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE, pIndex->nNodeVectorSize); + vectorInitStatic(pVector, pIndex->nNodeVectorType, pIndex->nVectorDims, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE); } u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { @@ -211879,20 +212105,25 @@ u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { return readLE16(pBlobSpot->pBuffer + sizeof(u64)); } -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector) { +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *pDistance, Vector *pVector) { + u32 distance; int offset = nodeEdgesMetadataOffset(pIndex); if( pRowid != NULL ){ assert( offset + (iEdge + 1) * VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); *pRowid = readLE64(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u64)); } + if( pIndex->nFormatVersion != VECTOR_FORMAT_V1 && pDistance != NULL ){ + distance = readLE32(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u32)); + *pDistance = *((float*)&distance); + } if( pVector != NULL ){ assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize < offset ); vectorInitStatic( pVector, pIndex->nEdgeVectorType, - pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nNodeVectorSize, - pIndex->nEdgeVectorSize + pIndex->nVectorDims, + pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize ); } } @@ -211902,7 +212133,7 @@ int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u6 // todo: 
if edges will be sorted by identifiers we can use binary search here (although speed up will be visible only on pretty loaded nodes: >128 edges) for(i = 0; i < nEdges; i++){ u64 edgeId; - nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL); + nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL, NULL); if( edgeId == nRowid ){ return i; } @@ -211917,7 +212148,7 @@ void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPru } // replace edge at position iReplace or add new one if iReplace == nEdges -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector) { +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector) { int nMaxEdges = nodeEdgesMaxCount(pIndex); int nEdges = nodeBinEdges(pIndex, pBlobSpot); int edgeVectorOffset, edgeMetaOffset, itemsToMove; @@ -211936,6 +212167,7 @@ void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iRe assert( edgeMetaOffset + VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); vectorSerializeToBlob(pVector, pBlobSpot->pBuffer + edgeVectorOffset, pIndex->nEdgeVectorSize); + writeLE32(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u32), *((u32*)&distance)); writeLE64(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u64), nRowid); writeLE16(pBlobSpot->pBuffer + sizeof(u64), nEdges); @@ -211970,6 +212202,7 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) int nEdges, nMaxEdges, i; u64 nRowid; + float distance = 0; Vector vector; nEdges = nodeBinEdges(pIndex, pBlobSpot); @@ -211980,8 +212213,8 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { DiskAnnTrace((" nEdges=%d, nMaxEdges=%d, vector=", nEdges, nMaxEdges)); vectorDump(&vector); for(i = 0; i < nEdges; i++){ - nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &vector); - DiskAnnTrace((" to=%lld, vector=", nRowid, 
nRowid)); + nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &distance, &vector); + DiskAnnTrace((" to=%lld, distance=%f, vector=", nRowid, distance)); vectorDump(&vector); } #endif @@ -211996,10 +212229,11 @@ int diskAnnCreateIndex( const char *zDbSName, const char *zIdxName, const VectorIdxKey *pKey, - VectorIdxParams *pParams + VectorIdxParams *pParams, + const char **pzErrMsg ){ int rc; - int type, dims; + int type, dims, metric, neighbours; u64 maxNeighborsParam, blockSizeBytes; char *zSql; char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) @@ -212023,24 +212257,36 @@ int diskAnnCreateIndex( } assert( 0 < dims && dims <= MAX_VECTOR_SZ ); + metric = vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID); + if( metric == 0 ){ + metric = VECTOR_METRIC_TYPE_COS; + if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, metric) != 0 ){ + return SQLITE_ERROR; + } + } + neighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){ + *pzErrMsg = "1-bit compression available only for cosine metric"; + return SQLITE_ERROR; + } + if( neighbours == 0 ){ + neighbours = type; + } + maxNeighborsParam = vectorIdxParamsGetU64(pParams, VECTOR_MAX_NEIGHBORS_PARAM_ID); if( maxNeighborsParam == 0 ){ // 3 D**(1/2) gives good recall values (90%+) // we also want to keep disk overhead at moderate level - 50x of the disk size increase is the current upper bound - maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(type, dims)) + 1); + maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(neighbours, dims)) + 1); } - blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(type, dims)); + 
blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(neighbours, dims)); if( blockSizeBytes > DISKANN_MAX_BLOCK_SZ ){ return SQLITE_ERROR; } if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){ return SQLITE_ERROR; } - if( vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID) == 0 ){ - if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, VECTOR_METRIC_TYPE_COS) != 0 ){ - return SQLITE_ERROR; - } - } + if( vectorIdxParamsGetF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID) == 0 ){ if( vectorIdxParamsPutF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID, VECTOR_PRUNING_ALPHA_DEFAULT) != 0 ){ return SQLITE_ERROR; @@ -212355,6 +212601,83 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){ return rc; } +/************************************************************************** +** Generic utilities +**************************************************************************/ + +int initVectorPair(int nodeType, int edgeType, int dims, VectorPair *pPair){ + pPair->nodeType = nodeType; + pPair->edgeType = edgeType; + pPair->pNode = NULL; + pPair->pEdge = NULL; + if( pPair->nodeType == pPair->edgeType ){ + return 0; + } + pPair->pEdge = vectorAlloc(edgeType, dims); + if( pPair->pEdge == NULL ){ + return SQLITE_NOMEM_BKPT; + } + return 0; +} + +void loadVectorPair(VectorPair *pPair, const Vector *pVector){ + pPair->pNode = (Vector*)pVector; + if( pPair->edgeType != pPair->nodeType ){ + vectorConvert(pPair->pNode, pPair->pEdge); + }else{ + pPair->pEdge = pPair->pNode; + } +} + +void deinitVectorPair(VectorPair *pPair) { + if( pPair->pEdge != NULL && pPair->pNode != pPair->pEdge ){ + vectorFree(pPair->pEdge); + } +} + +int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){ + int i; +#ifdef SQLITE_DEBUG + for(i = 0; i < nSize - 1; i++){ + assert(aDistances[i] <= aDistances[i + 1]); + } +#endif + for(i = 
0; i < nSize; i++){ + if( distance < aDistances[i] ){ + return i; + } + } + return nSize < nMaxSize ? nSize : -1; +} + +void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { + int itemsToMove; + + assert( nMaxSize > 0 && nItemSize > 0 ); + assert( nSize <= nMaxSize ); + assert( 0 <= iInsert && iInsert <= nSize && iInsert < nMaxSize ); + + if( nSize == nMaxSize ){ + if( pLast != NULL ){ + memcpy(pLast, aBuffer + (nSize - 1) * nItemSize, nItemSize); + } + nSize--; + } + itemsToMove = nSize - iInsert; + memmove(aBuffer + (iInsert + 1) * nItemSize, aBuffer + iInsert * nItemSize, itemsToMove * nItemSize); + memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); +} + +void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { + int itemsToMove; + + assert( nItemSize > 0 ); + assert( 0 <= iDelete && iDelete < nSize ); + + itemsToMove = nSize - iDelete - 1; + memmove(aBuffer + iDelete * nItemSize, aBuffer + (iDelete + 1) * nItemSize, itemsToMove * nItemSize); +} + /************************************************************************** ** DiskANN internals **************************************************************************/ @@ -212391,16 +212714,24 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ sqlite3_free(pNode); } -static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, unsigned int maxCandidates, int blobMode){ - pCtx->pQuery = pQuery; +static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; pCtx->maxCandidates = maxCandidates; + pCtx->aTopDistances = sqlite3_malloc(topCandidates * sizeof(double)); + pCtx->aTopCandidates = sqlite3_malloc(topCandidates * sizeof(DiskAnnNode*)); + 
pCtx->nTopCandidates = 0; + pCtx->maxTopCandidates = topCandidates; pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL ){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + goto out_oom; + } + loadVectorPair(&pCtx->query, pQuery); + + if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ goto out_oom; } return SQLITE_OK; @@ -212411,6 +212742,12 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, un if( pCtx->aCandidates != NULL ){ sqlite3_free(pCtx->aCandidates); } + if( pCtx->aTopDistances != NULL ){ + sqlite3_free(pCtx->aTopDistances); + } + if( pCtx->aTopCandidates != NULL ){ + sqlite3_free(pCtx->aTopCandidates); + } return SQLITE_NOMEM_BKPT; } @@ -212434,6 +212771,9 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){ } sqlite3_free(pCtx->aCandidates); sqlite3_free(pCtx->aDistances); + sqlite3_free(pCtx->aTopCandidates); + sqlite3_free(pCtx->aTopDistances); + deinitVectorPair(&pCtx->query); } // check if we visited this node earlier @@ -212475,7 +212815,9 @@ static int diskAnnSearchCtxShouldAddCandidate(const DiskAnnIndex *pIndex, const } // mark node as visited and put it in the head of visitedList -static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode){ +static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode, float distance){ + int iInsert; + assert( pCtx->nUnvisited > 0 ); assert( pNode->visited == 0 ); @@ -212484,56 +212826,51 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo pNode->pNext = pCtx->visitedList; pCtx->visitedList = pNode; + + iInsert = distanceBufferInsertIdx(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, distance); + if( iInsert < 0 ){ + return; + } + 
bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); + bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } static int diskAnnSearchCtxHasUnvisited(const DiskAnnSearchCtx *pCtx){ return pCtx->nUnvisited > 0; } -static DiskAnnNode* diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i){ +static void diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i, DiskAnnNode **ppNode, float *pDistance){ assert( 0 <= i && i < pCtx->nCandidates ); - return pCtx->aCandidates[i]; + *ppNode = pCtx->aCandidates[i]; + *pDistance = pCtx->aDistances[i]; } static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete){ int i; - assert( 0 <= iDelete && iDelete < pCtx->nCandidates ); assert( pCtx->nUnvisited > 0 ); assert( !pCtx->aCandidates[iDelete]->visited ); assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); + bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); - for(i = iDelete + 1; i < pCtx->nCandidates; i++){ - pCtx->aCandidates[i - 1] = pCtx->aCandidates[i]; - pCtx->aDistances[i - 1] = pCtx->aDistances[i]; - } pCtx->nCandidates--; pCtx->nUnvisited--; } -static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float candidateDist){ - int i; - assert( 0 <= iInsert && iInsert <= pCtx->nCandidates && iInsert < pCtx->maxCandidates ); - if( pCtx->nCandidates < pCtx->maxCandidates ){ - pCtx->nCandidates++; - } else { - DiskAnnNode *pLast = pCtx->aCandidates[pCtx->nCandidates - 1]; - if( !pLast->visited ){ - // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node - assert( 
pLast->pBlobSpot == NULL ); - pCtx->nUnvisited--; - diskAnnNodeFree(pLast); - } - } - // Shift the candidates to the right to make space for the new one. - for(i = pCtx->nCandidates - 1; i > iInsert; i--){ - pCtx->aCandidates[i] = pCtx->aCandidates[i - 1]; - pCtx->aDistances[i] = pCtx->aDistances[i - 1]; - } - // Insert the new candidate. - pCtx->aCandidates[iInsert] = pCandidate; - pCtx->aDistances[iInsert] = candidateDist; +static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ + DiskAnnNode *pLast = NULL; + bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); + bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); + if( pLast != NULL && !pLast->visited ){ + // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node + assert( pLast->pBlobSpot == NULL ); + pCtx->nUnvisited--; + diskAnnNodeFree(pLast); + } pCtx->nUnvisited++; } @@ -212563,7 +212900,14 @@ static int diskAnnSearchCtxFindClosestCandidateIdx(const DiskAnnSearchCtx *pCtx) // return position for new edge(C) which will replace previous edge on that position or -1 if we should ignore it // we also check that no current edge(B) will "prune" new vertex: i.e. 
dist(B, C) >= (means worse than) alpha * dist(node, C) for all current edges // if any edge(B) will "prune" new edge(C) we will ignore it (return -1) -static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, u64 newRowid, const Vector *pNewVector) { +static int diskAnnReplaceEdgeIdx( + const DiskAnnIndex *pIndex, + BlobSpot *pNodeBlob, + u64 newRowid, + VectorPair *pNewVector, + VectorPair *pPlaceholder, + float *pNodeToNew +) { int i, nEdges, nMaxEdges, iReplace = -1; Vector nodeVector, edgeVector; float nodeToNew, nodeToReplace; @@ -212571,20 +212915,27 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob nEdges = nodeBinEdges(pIndex, pNodeBlob); nMaxEdges = nodeEdgesMaxCount(pIndex); nodeBinVector(pIndex, pNodeBlob, &nodeVector); - nodeToNew = diskAnnVectorDistance(pIndex, &nodeVector, pNewVector); + loadVectorPair(pPlaceholder, &nodeVector); + + // we need to evaluate potentially approximate distance here in order to correctly compare it with edge distances + nodeToNew = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, pNewVector->pEdge); + *pNodeToNew = nodeToNew; for(i = nEdges - 1; i >= 0; i--){ u64 edgeRowid; float edgeToNew, nodeToEdge; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( edgeRowid == newRowid ){ // deletes can leave "zombie" edges in the graph and we must override them and not store duplicate edges in the node return i; } - edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector); - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + + edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector->pEdge); if( nodeToNew > pIndex->pruningAlpha * edgeToNew ){ return -1; } @@ -212602,12 +212953,14 @@ static int 
diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob // prune edges after we inserted new edge at position iInserted // we only need to check for edges which will be pruned by new vertex // no need to check for other pairs as we checked them on previous insertions -static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted) { +static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted, VectorPair *pPlaceholder) { int i, s, nEdges; - Vector nodeVector, hintVector; + Vector nodeVector, hintEdgeVector; u64 hintRowid; nodeBinVector(pIndex, pNodeBlob, &nodeVector); + loadVectorPair(pPlaceholder, &nodeVector); + nEdges = nodeBinEdges(pIndex, pNodeBlob); assert( 0 <= iInserted && iInserted < nEdges ); @@ -212617,7 +212970,7 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i nodeBinDebug(pIndex, pNodeBlob); #endif - nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintVector); + nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, NULL, &hintEdgeVector); // remove edges which is no longer interesting due to the addition of iInserted i = 0; @@ -212625,14 +212978,17 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i Vector edgeVector; float nodeToEdge, hintToEdge; u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( hintRowid == edgeRowid ){ i++; continue; } - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); - hintToEdge = diskAnnVectorDistance(pIndex, &hintVector, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + + hintToEdge = diskAnnVectorDistance(pIndex, &hintEdgeVector, &edgeVector); if( nodeToEdge > pIndex->pruningAlpha * hintToEdge ){ nodeBinDeleteEdge(pIndex, pNodeBlob, i); nEdges--; @@ 
-212681,7 +213037,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u } nodeBinVector(pIndex, start->pBlobSpot, &startVector); - startDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &startVector); + startDistance = diskAnnVectorDistance(pIndex, pCtx->query.pNode, &startVector); if( pCtx->blobMode == DISKANN_BLOB_READONLY ){ assert( start->pBlobSpot != NULL ); @@ -212698,8 +213054,9 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u Vector vCandidate; DiskAnnNode *pCandidate; BlobSpot *pCandidateBlob; + float distance; int iCandidate = diskAnnSearchCtxFindClosestCandidateIdx(pCtx); - pCandidate = diskAnnSearchCtxGetCandidate(pCtx, iCandidate); + diskAnnSearchCtxGetCandidate(pCtx, iCandidate, &pCandidate, &distance); rc = SQLITE_OK; if( pReusableBlobSpot != NULL ){ @@ -212727,25 +213084,30 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u goto out; } - diskAnnSearchCtxMarkVisited(pCtx, pCandidate); - nVisited += 1; DiskAnnTrace(("visiting candidate(%d): id=%lld\n", nVisited, pCandidate->nRowid)); nodeBinVector(pIndex, pCandidateBlob, &vCandidate); nEdges = nodeBinEdges(pIndex, pCandidateBlob); + // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it + if( pCtx->query.pNode != pCtx->query.pEdge ){ + distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->query.pNode); + } + + diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance); + for(i = 0; i < nEdges; i++){ u64 edgeRowid; Vector edgeVector; float edgeDistance; int iInsert; DiskAnnNode *pNewCandidate; - nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, NULL, &edgeVector); if( diskAnnSearchCtxIsVisited(pCtx, edgeRowid) || diskAnnSearchCtxHasCandidate(pCtx, edgeRowid) ){ continue; } - edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &edgeVector); + edgeDistance = 
diskAnnVectorDistance(pIndex, pCtx->query.pEdge, &edgeVector); iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance); if( iInsert < 0 ){ continue; @@ -212822,7 +213184,7 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search"); return rc; } - rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, DISKANN_BLOB_READONLY); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context"); goto out; @@ -212831,7 +213193,7 @@ int diskAnnSearch( if( rc != SQLITE_OK ){ goto out; } - nOutRows = MIN(k, ctx.nCandidates); + nOutRows = MIN(k, ctx.nTopCandidates); rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, vectorIdxKeyRowidLike(pKey)); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to allocate output rows"); @@ -212839,9 +213201,9 @@ int diskAnnSearch( } for(i = 0; i < nOutRows; i++){ if( pRows->aIntValues != NULL ){ - rc = vectorOutRowsPut(pRows, i, 0, &ctx.aCandidates[i]->nRowid, NULL); + rc = vectorOutRowsPut(pRows, i, 0, &ctx.aTopCandidates[i]->nRowid, NULL); }else{ - rc = diskAnnGetShadowRowKeys(pIndex, ctx.aCandidates[i]->nRowid, pKey, pRows, i); + rc = diskAnnGetShadowRowKeys(pIndex, ctx.aTopCandidates[i]->nRowid, pKey, pRows, i); } if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to put result in the output row"); @@ -212865,6 +213227,9 @@ int diskAnnInsert( BlobSpot *pBlobSpot = NULL; DiskAnnNode *pVisited; DiskAnnSearchCtx ctx; + VectorPair vInsert, vCandidate; + vInsert.pNode = NULL; vInsert.pEdge = NULL; + vCandidate.pNode = NULL; vCandidate.pEdge = NULL; if( pVectorInRow->pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); @@ 
-212877,12 +213242,24 @@ int diskAnnInsert( DiskAnnTrace(("diskAnnInset started\n")); - rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, DISKANN_BLOB_WRITABLE); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context"); return rc; } + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vInsert) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for node VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vCandidate) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for candidate VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + // note: we must select random row before we will insert new row in the shadow table rc = diskAnnSelectRandomShadowRow(pIndex, &nStartRowid); if( rc == SQLITE_DONE ){ @@ -212920,28 +213297,33 @@ int diskAnnInsert( } // first pass - add all visited nodes as a potential neighbours of new node for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ - Vector vector; + Vector nodeVector; int iReplace; + float nodeToNew; - nodeBinVector(pIndex, pVisited->pBlobSpot, &vector); - iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vector); + nodeBinVector(pIndex, pVisited->pBlobSpot, &nodeVector); + loadVectorPair(&vCandidate, &nodeVector); + + iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, &vector); - diskAnnPruneEdges(pIndex, pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, nodeToNew, vCandidate.pEdge); + 
diskAnnPruneEdges(pIndex, pBlobSpot, iReplace, &vInsert); } // second pass - add new node as a potential neighbour of all visited nodes + loadVectorPair(&vInsert, pVectorInRow->pVector); for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ int iReplace; + float nodeToNew; - iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, pVectorInRow->pVector); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, pVectorInRow->pVector); - diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, nodeToNew, vInsert.pEdge); + diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace, &vCandidate); rc = blobSpotFlush(pIndex, pVisited->pBlobSpot); if( rc != SQLITE_OK ){ @@ -212952,6 +213334,8 @@ int diskAnnInsert( rc = SQLITE_OK; out: + deinitVectorPair(&vInsert); + deinitVectorPair(&vCandidate); if( rc == SQLITE_OK ){ rc = blobSpotFlush(pIndex, pBlobSpot); if( rc != SQLITE_OK ){ @@ -213003,7 +213387,7 @@ int diskAnnDelete( nNeighbours = nodeBinEdges(pIndex, pNodeBlob); for(i = 0; i < nNeighbours; i++){ u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL, NULL); rc = blobSpotReload(pIndex, pEdgeBlob, edgeRowid, pIndex->nBlockSize); if( rc == DISKANN_ROW_NOT_FOUND ){ continue; @@ -213050,6 +213434,7 @@ int diskAnnOpenIndex( ){ DiskAnnIndex *pIndex; u64 nBlockSize; + int compressNeighbours; pIndex = sqlite3DbMallocRaw(db, sizeof(DiskAnnIndex)); if( pIndex == NULL ){ return SQLITE_NOMEM; @@ -213096,11 +213481,20 @@ int diskAnnOpenIndex( pIndex->searchL = VECTOR_SEARCH_L_DEFAULT; } pIndex->nNodeVectorSize = vectorDataSize(pIndex->nNodeVectorType, pIndex->nVectorDims); - // will change in future when we will support compression of edges 
vectors - pIndex->nEdgeVectorType = pIndex->nNodeVectorType; - pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + + compressNeighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( compressNeighbours == 0 ){ + pIndex->nEdgeVectorType = pIndex->nNodeVectorType; + pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + pIndex->nEdgeVectorType = compressNeighbours; + pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); + }else{ + return SQLITE_ERROR; + } *ppIndex = pIndex; + DiskAnnTrace(("opened index %s: max edges %d\n", zIdxName, nodeEdgesMaxCount(pIndex))); return SQLITE_OK; } @@ -213216,26 +213610,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -size_t vectorF32DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - float *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT32; - pVector->dims = nBlobSize / sizeof(float); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 0 || pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT32 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF32(pBlob); - pBlob += sizeof(float); - } - return vectorDataSize(pVector->type, pVector->dims); -} - void vectorF32Serialize( sqlite3_context *context, const Vector *pVector @@ -213342,32 +213716,22 @@ void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF32ParseSqliteBlob( - sqlite3_value *arg, +void vectorF32DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; float *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( 
sqlite3_value_bytes(arg) < sizeof(float) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f32 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(float) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); pBlob += sizeof(float); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ @@ -213474,57 +213838,6 @@ size_t vectorF64SerializeToBlob( return sizeof(double) * pVector->dims; } -size_t vectorF64DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - double *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT64; - pVector->dims = nBlobSize / sizeof(double); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 1 && pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT64 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF64(pBlob); - pBlob += sizeof(double); - } - return vectorDataSize(pVector->type, pVector->dims); -} - -void vectorF64Serialize( - sqlite3_context *context, - const Vector *pVector -){ - double *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT64 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - // allocate one extra trailing byte with vector blob type metadata - nBlobSize = vectorDataSize(pVector->type, pVector->dims) + 1; - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF64SerializeToBlob(pVector, pBlob, nBlobSize - 1); - pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_DOUBLE_CHAR_LIMIT 32 void vectorF64MarshalToText( sqlite3_context *context, @@ -213603,32 +213916,22 @@ void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n 
pVector->data = (void*)pBlob; } -int vectorF64ParseSqliteBlob( - sqlite3_value *arg, +void vectorF64DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; double *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(double) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f64 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(double) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); pBlob += sizeof(double); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ @@ -214033,13 +214336,14 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, - { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, - { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, - { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, + { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, + { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, + { "max_neighbors", 
VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, }; static int parseVectorIdxParam(const char *zParam, VectorIdxParams *pParams, const char **pErrMsg) { @@ -214439,7 +214743,7 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int i, rc = SQLITE_OK; int dims, type; int hasLibsqlVectorIdxFn = 0, hasCollation = 0; - const char *pzErrMsg; + const char *pzErrMsg = NULL; assert( zDbSName != NULL ); @@ -214551,9 +214855,13 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: unsupported for tables without ROWID and composite primary key"); return CREATE_FAIL; } - rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams); + rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams, &pzErrMsg); if( rc != SQLITE_OK ){ - sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + if( pzErrMsg != NULL ){ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann: %s", pzErrMsg); + }else{ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + } return CREATE_FAIL; } rc = insertIndexParameters(db, zDbSName, pIdx->zName, &idxParams); diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index d7587cc38b..eff02cf87a 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -85246,6 +85246,7 @@ typedef u32 VectorDims; */ #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_1BIT 3 #define VECTOR_FLAGS_STATIC 1 @@ -85270,8 +85271,9 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); * Dumps vector on the console (used only for debugging) */ void vectorDump (const Vector *v); -void vectorF32Dump(const Vector *v); -void vectorF64Dump(const Vector *v); +void vectorF32Dump (const Vector *v); +void vectorF64Dump (const Vector *v); +void vector1BitDump(const Vector *v); /* * Converts vector to the text representation and write the 
result to the sqlite3_context @@ -85283,16 +85285,10 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF32SerializeToBlob(const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob(const Vector *, unsigned char *, size_t); - -/* - * Deserializes vector from the blob in little-endian format according to the IEEE-754 standard -*/ -size_t vectorDeserializeFromBlob (Vector *, const unsigned char *, size_t); -size_t vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -size_t vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); +size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) @@ -85301,6 +85297,11 @@ float vectorDistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); +/* + * Calculates hamming distance between two 1-bit vectors (vector must have same dimensions) +*/ +int vector1BitDistanceHamming(const Vector *, const Vector *); + /* * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ @@ -85313,22 +85314,23 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. 
This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerialize (sqlite3_context *, const Vector *); -void vectorF32Serialize(sqlite3_context *, const Vector *); -void vectorF64Serialize(sqlite3_context *, const Vector *); +void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); -int vectorF32ParseSqliteBlob(sqlite3_value *, Vector *, char **); -int vectorF64ParseSqliteBlob(sqlite3_value *, Vector *, char **); -void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t); +void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); + +void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); +void vectorConvert(const Vector *, Vector *); + /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); @@ -85410,10 +85412,10 @@ int nodeEdgesMetadataOffset(const DiskAnnIndex *pIndex); void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Vector *pVector); void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector); u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector); +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *distance, Vector *pVector); int 
nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u64 nRowid); void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPruned); -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector); +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector); void nodeBinDeleteEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iDelete); void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); @@ -85437,43 +85439,47 @@ typedef u8 MetricType; */ /* format version which can help to upgrade vector on-disk format without breaking older version of the db */ -#define VECTOR_FORMAT_PARAM_ID 1 +#define VECTOR_FORMAT_PARAM_ID 1 /* - * 1 - initial version + * 1 - v1 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u64 unused ] [u64 edge rowid] ] ... + * 2 - v2 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u32 unused] [f32 distance] [u64 edge rowid] ] ... 
*/ -#define VECTOR_FORMAT_DEFAULT 1 +#define VECTOR_FORMAT_V1 1 +#define VECTOR_FORMAT_DEFAULT 2 /* type of the vector index */ -#define VECTOR_INDEX_TYPE_PARAM_ID 2 -#define VECTOR_INDEX_TYPE_DISKANN 1 +#define VECTOR_INDEX_TYPE_PARAM_ID 2 +#define VECTOR_INDEX_TYPE_DISKANN 1 /* type of the underlying vector for the vector index */ -#define VECTOR_TYPE_PARAM_ID 3 +#define VECTOR_TYPE_PARAM_ID 3 /* dimension of the underlying vector for the vector index */ -#define VECTOR_DIM_PARAM_ID 4 +#define VECTOR_DIM_PARAM_ID 4 /* metric type used for comparing two vectors */ -#define VECTOR_METRIC_TYPE_PARAM_ID 5 -#define VECTOR_METRIC_TYPE_COS 1 -#define VECTOR_METRIC_TYPE_L2 2 +#define VECTOR_METRIC_TYPE_PARAM_ID 5 +#define VECTOR_METRIC_TYPE_COS 1 +#define VECTOR_METRIC_TYPE_L2 2 /* block size */ -#define VECTOR_BLOCK_SIZE_PARAM_ID 6 -#define VECTOR_BLOCK_SIZE_DEFAULT 128 +#define VECTOR_BLOCK_SIZE_PARAM_ID 6 +#define VECTOR_BLOCK_SIZE_DEFAULT 128 + +#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 +#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 -#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 -#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 +#define VECTOR_INSERT_L_PARAM_ID 8 +#define VECTOR_INSERT_L_DEFAULT 70 -#define VECTOR_INSERT_L_PARAM_ID 8 -#define VECTOR_INSERT_L_DEFAULT 70 +#define VECTOR_SEARCH_L_PARAM_ID 9 +#define VECTOR_SEARCH_L_DEFAULT 200 -#define VECTOR_SEARCH_L_PARAM_ID 9 -#define VECTOR_SEARCH_L_DEFAULT 200 +#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 -#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 +#define VECTOR_COMPRESS_NEIGHBORS_PARAM_ID 11 /* total amount of vector index parameters */ -#define VECTOR_PARAM_IDS_COUNT 9 +#define VECTOR_PARAM_IDS_COUNT 11 /* * Vector index parameters are stored in simple binary format (1 byte tag + 8 byte u64 integer / f64 float) @@ -85555,7 +85561,7 @@ int vectorOutRowsPut(VectorOutRows *, int, int, const u64 *, sqlite3_value *); void vectorOutRowsGet(sqlite3_context *, const VectorOutRows *, int, int); void vectorOutRowsFree(sqlite3 *, VectorOutRows *); 
-int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *); +int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *, const char **); int diskAnnClearIndex(sqlite3 *, const char *, const char *); int diskAnnDropIndex(sqlite3 *, const char *, const char *); int diskAnnOpenIndex(sqlite3 *, const char *, const char *, const VectorIdxParams *, DiskAnnIndex **); @@ -210981,6 +210987,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); + case VECTOR_TYPE_1BIT: + return (dims + 7) / 8; default: assert(0); } @@ -211012,10 +211020,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){ ** Note that the vector object points to the blob so if ** you free the blob, the vector becomes invalid. **/ -void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){ - pVector->type = type; +void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){ pVector->flags = VECTOR_FLAGS_STATIC; - vectorInitFromBlob(pVector, pBlob, nBlobSize); + pVector->type = type; + pVector->dims = dims; + pVector->data = pBlob; } /* @@ -211051,6 +211060,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceCos(pVector1, pVector2); + case VECTOR_TYPE_1BIT: + return vector1BitDistanceHamming(pVector1, pVector2); default: assert(0); } @@ -211192,11 +211203,29 @@ int vectorParseSqliteBlob( Vector *pVector, char **pzErrMsg ){ + const unsigned char *pBlob; + size_t nBlobSize; + + assert( sqlite3_value_type(arg) == SQLITE_BLOB ); + + pBlob = sqlite3_value_blob(arg); + nBlobSize = sqlite3_value_bytes(arg); + if( nBlobSize % 2 == 1 ){ + nBlobSize--; + } + + if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){ + *pzErrMsg = 
sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize); + return SQLITE_ERROR; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - return vectorF32ParseSqliteBlob(arg, pVector, pzErrMsg); + vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); + return 0; case VECTOR_TYPE_FLOAT64: - return vectorF64ParseSqliteBlob(arg, pVector, pzErrMsg); + vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + return 0; default: assert(0); } @@ -211303,6 +211332,9 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT64: vectorF64Dump(pVector); break; + case VECTOR_TYPE_1BIT: + vector1BitDump(pVector); + break; default: assert(0); } @@ -211324,20 +211356,47 @@ void vectorMarshalToText( } } -void vectorSerialize( +void vectorSerializeWithType( sqlite3_context *context, const Vector *pVector ){ + unsigned char *pBlob; + size_t nBlobSize, nDataSize; + + assert( pVector->dims <= MAX_VECTOR_SZ ); + + nDataSize = vectorDataSize(pVector->type, pVector->dims); + nBlobSize = nDataSize; + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + nBlobSize += (nBlobSize % 2 == 0 ? 
1 : 2); + } + + if( nBlobSize == 0 ){ + sqlite3_result_zeroblob(context, 0); + return; + } + + pBlob = sqlite3_malloc64(nBlobSize); + if( pBlob == NULL ){ + sqlite3_result_error_nomem(context); + return; + } + + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + pBlob[nBlobSize - 1] = pVector->type; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - vectorF32Serialize(context, pVector); + vectorF32SerializeToBlob(pVector, pBlob, nDataSize); break; case VECTOR_TYPE_FLOAT64: - vectorF64Serialize(context, pVector); + vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; default: assert(0); } + sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ @@ -211346,18 +211405,8 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); - default: - assert(0); - } - return 0; -} - -size_t vectorDeserializeFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - return vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); - case VECTOR_TYPE_FLOAT64: - return vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + case VECTOR_TYPE_1BIT: + return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); } @@ -211377,6 +211426,29 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo } } +void vectorConvert(const Vector *pFrom, Vector *pTo){ + int i; + u8 *bitData; + float *floatData; + + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ + floatData = pFrom->data; + bitData = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + bitData[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( floatData[i] > 0 ){ 
+ bitData[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert(0); + } +} + /************************************************************************** ** SQL function implementations ****************************************************************************/ @@ -211410,7 +211482,7 @@ static void vectorFuncHintedType( sqlite3_free(pzErrMsg); goto out_free_vec; } - vectorSerialize(context, pVector); + vectorSerializeWithType(context, pVector); out_free_vec: vectorFree(pVector); } @@ -211557,6 +211629,135 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vector.c **********************************************/ +/************** Begin file vector1bit.c **************************************/ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+** +****************************************************************************** +** +** 1-bit vector format utilities. +*/ +#ifndef SQLITE_OMIT_VECTOR +/* #include "sqliteInt.h" */ + +/* #include "vectorInt.h" */ + +/* #include */ + +/************************************************************************** +** Utility routines for debugging +**************************************************************************/ + +void vector1BitDump(const Vector *pVec){ + u8 *elems = pVec->data; + unsigned i; + + assert( pVec->type == VECTOR_TYPE_1BIT ); + + for(i = 0; i < pVec->dims; i++){ + printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + } + printf("\n"); +} + +/************************************************************************** +** Utility routines for vector serialization and deserialization +**************************************************************************/ + +size_t vector1BitSerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + u8 *pPtr = pBlob; + unsigned i; + + assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + for(i = 0; i < (pVector->dims + 7) / 8; i++){ + pPtr[i] = elems[i]; + } + return (pVector->dims + 7) / 8; +} + +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 
5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +static inline int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + +int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ + int diff = 0; + u8 *e1U8 = v1->data; + u32 *e1U32 = v1->data; + u8 *e2U8 = v2->data; + u32 *e2U32 = v2->data; + int i, len8, len32, offset8; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_1BIT ); + assert( v2->type == VECTOR_TYPE_1BIT ); + + len8 = (v1->dims + 7) / 8; + len32 = v1->dims / 32; + offset8 = len32 * 4; + + for(i = 0; i < len32; i++){ + diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); + } + for(i = offset8; i < len8; i++){ + diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); + } + return diff; +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ + +/************** End of vector1bit.c ******************************************/ /************** Begin file vectordiskann.c ***********************************/ /* ** 2024-03-23 @@ -211607,13 +211808,14 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ ** diskAnnInsert() Insert single new(!) 
vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ +/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "math.h" */ /* #include "sqliteInt.h" */ /* #include "vectorIndexInt.h" */ -#define SQLITE_VECTOR_TRACE +// #define SQLITE_VECTOR_TRACE #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) #define DiskAnnTrace(X) sqlite3DebugPrintf X; #else @@ -211639,9 +211841,18 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #define VECTOR_NODE_METADATA_SIZE (sizeof(u64) + sizeof(u16)) #define VECTOR_EDGE_METADATA_SIZE (sizeof(u64) + sizeof(u64)) +typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +struct VectorPair { + int nodeType; + int edgeType; + Vector *pNode; + Vector *pEdge; +}; + // DiskAnnNode represents single node in the DiskAnn graph struct DiskAnnNode { u64 nRowid; /* node id */ @@ -211657,14 +211868,18 @@ struct DiskAnnNode { * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this) */ struct DiskAnnSearchCtx { - const Vector *pQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ - DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */ - double *aDistances; /* array of distances to the query vector */ - unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ - unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ - DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */ - unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */ - int blobMode; /* DISKANN_BLOB_READONLY if 
we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */ + VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */ + DiskAnnNode **aCandidates; /* array of unvisited candidates ordered by distance (possibly approximate) to the query (ascending) */ + float *aDistances; /* array of distances (possible approximate) to the query vector */ + unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ + unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ + DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */ + float *aTopDistances; /* top candidates exact distances */ + int nTopCandidates; /* current size of aTopCandidates/aTopDistances arrays */ + int maxTopCandidates; /* max size of aTopCandidates/aTopDistances arrays */ + DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */ + unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */ + int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */ }; /************************************************************************** @@ -211675,6 +211890,10 @@ static inline u16 readLE16(const unsigned char *p){ return (u16)p[0] | (u16)p[1] << 8; } +static inline u32 readLE32(const unsigned char *p){ + return (u32)p[0] | (u32)p[1] << 8 | (u32)p[2] << 16 | (u32)p[3] << 24; +} + static inline u64 readLE64(const unsigned char *p){ return (u64)p[0] | (u64)p[1] << 8 @@ -211691,6 +211910,13 @@ static inline void writeLE16(unsigned char *p, u16 v){ p[1] = v >> 8; } +static inline void writeLE32(unsigned char *p, u32 v){ + p[0] = v; + p[1] = v >> 8; + p[2] = v >> 16; + p[3] = v >> 24; +} + static inline void writeLE64(unsigned char *p, u64 v){ p[0] = v; p[1] = v >> 8; @@ -211870,7 +212096,7 @@ void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot 
*pBlobSpot, u64 nRowid, Ve void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector) { assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize <= pBlobSpot->nBufferSize ); - vectorInitStatic(pVector, pIndex->nNodeVectorType, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE, pIndex->nNodeVectorSize); + vectorInitStatic(pVector, pIndex->nNodeVectorType, pIndex->nVectorDims, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE); } u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { @@ -211879,20 +212105,25 @@ u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { return readLE16(pBlobSpot->pBuffer + sizeof(u64)); } -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector) { +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *pDistance, Vector *pVector) { + u32 distance; int offset = nodeEdgesMetadataOffset(pIndex); if( pRowid != NULL ){ assert( offset + (iEdge + 1) * VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); *pRowid = readLE64(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u64)); } + if( pIndex->nFormatVersion != VECTOR_FORMAT_V1 && pDistance != NULL ){ + distance = readLE32(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u32)); + *pDistance = *((float*)&distance); + } if( pVector != NULL ){ assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize < offset ); vectorInitStatic( pVector, pIndex->nEdgeVectorType, - pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nNodeVectorSize, - pIndex->nEdgeVectorSize + pIndex->nVectorDims, + pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize ); } } @@ -211902,7 +212133,7 @@ int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u6 // todo: 
if edges will be sorted by identifiers we can use binary search here (although speed up will be visible only on pretty loaded nodes: >128 edges) for(i = 0; i < nEdges; i++){ u64 edgeId; - nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL); + nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL, NULL); if( edgeId == nRowid ){ return i; } @@ -211917,7 +212148,7 @@ void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPru } // replace edge at position iReplace or add new one if iReplace == nEdges -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector) { +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector) { int nMaxEdges = nodeEdgesMaxCount(pIndex); int nEdges = nodeBinEdges(pIndex, pBlobSpot); int edgeVectorOffset, edgeMetaOffset, itemsToMove; @@ -211936,6 +212167,7 @@ void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iRe assert( edgeMetaOffset + VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); vectorSerializeToBlob(pVector, pBlobSpot->pBuffer + edgeVectorOffset, pIndex->nEdgeVectorSize); + writeLE32(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u32), *((u32*)&distance)); writeLE64(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u64), nRowid); writeLE16(pBlobSpot->pBuffer + sizeof(u64), nEdges); @@ -211970,6 +212202,7 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) int nEdges, nMaxEdges, i; u64 nRowid; + float distance = 0; Vector vector; nEdges = nodeBinEdges(pIndex, pBlobSpot); @@ -211980,8 +212213,8 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { DiskAnnTrace((" nEdges=%d, nMaxEdges=%d, vector=", nEdges, nMaxEdges)); vectorDump(&vector); for(i = 0; i < nEdges; i++){ - nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &vector); - DiskAnnTrace((" to=%lld, vector=", nRowid, 
nRowid)); + nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &distance, &vector); + DiskAnnTrace((" to=%lld, distance=%f, vector=", nRowid, distance)); vectorDump(&vector); } #endif @@ -211996,10 +212229,11 @@ int diskAnnCreateIndex( const char *zDbSName, const char *zIdxName, const VectorIdxKey *pKey, - VectorIdxParams *pParams + VectorIdxParams *pParams, + const char **pzErrMsg ){ int rc; - int type, dims; + int type, dims, metric, neighbours; u64 maxNeighborsParam, blockSizeBytes; char *zSql; char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) @@ -212023,24 +212257,36 @@ int diskAnnCreateIndex( } assert( 0 < dims && dims <= MAX_VECTOR_SZ ); + metric = vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID); + if( metric == 0 ){ + metric = VECTOR_METRIC_TYPE_COS; + if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, metric) != 0 ){ + return SQLITE_ERROR; + } + } + neighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){ + *pzErrMsg = "1-bit compression available only for cosine metric"; + return SQLITE_ERROR; + } + if( neighbours == 0 ){ + neighbours = type; + } + maxNeighborsParam = vectorIdxParamsGetU64(pParams, VECTOR_MAX_NEIGHBORS_PARAM_ID); if( maxNeighborsParam == 0 ){ // 3 D**(1/2) gives good recall values (90%+) // we also want to keep disk overhead at moderate level - 50x of the disk size increase is the current upper bound - maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(type, dims)) + 1); + maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(neighbours, dims)) + 1); } - blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(type, dims)); + 
blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(neighbours, dims)); if( blockSizeBytes > DISKANN_MAX_BLOCK_SZ ){ return SQLITE_ERROR; } if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){ return SQLITE_ERROR; } - if( vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID) == 0 ){ - if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, VECTOR_METRIC_TYPE_COS) != 0 ){ - return SQLITE_ERROR; - } - } + if( vectorIdxParamsGetF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID) == 0 ){ if( vectorIdxParamsPutF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID, VECTOR_PRUNING_ALPHA_DEFAULT) != 0 ){ return SQLITE_ERROR; @@ -212355,6 +212601,83 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){ return rc; } +/************************************************************************** +** Generic utilities +**************************************************************************/ + +int initVectorPair(int nodeType, int edgeType, int dims, VectorPair *pPair){ + pPair->nodeType = nodeType; + pPair->edgeType = edgeType; + pPair->pNode = NULL; + pPair->pEdge = NULL; + if( pPair->nodeType == pPair->edgeType ){ + return 0; + } + pPair->pEdge = vectorAlloc(edgeType, dims); + if( pPair->pEdge == NULL ){ + return SQLITE_NOMEM_BKPT; + } + return 0; +} + +void loadVectorPair(VectorPair *pPair, const Vector *pVector){ + pPair->pNode = (Vector*)pVector; + if( pPair->edgeType != pPair->nodeType ){ + vectorConvert(pPair->pNode, pPair->pEdge); + }else{ + pPair->pEdge = pPair->pNode; + } +} + +void deinitVectorPair(VectorPair *pPair) { + if( pPair->pEdge != NULL && pPair->pNode != pPair->pEdge ){ + vectorFree(pPair->pEdge); + } +} + +int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){ + int i; +#ifdef SQLITE_DEBUG + for(i = 0; i < nSize - 1; i++){ + assert(aDistances[i] <= aDistances[i + 1]); + } +#endif + for(i = 
0; i < nSize; i++){ + if( distance < aDistances[i] ){ + return i; + } + } + return nSize < nMaxSize ? nSize : -1; +} + +void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { + int itemsToMove; + + assert( nMaxSize > 0 && nItemSize > 0 ); + assert( nSize <= nMaxSize ); + assert( 0 <= iInsert && iInsert <= nSize && iInsert < nMaxSize ); + + if( nSize == nMaxSize ){ + if( pLast != NULL ){ + memcpy(pLast, aBuffer + (nSize - 1) * nItemSize, nItemSize); + } + nSize--; + } + itemsToMove = nSize - iInsert; + memmove(aBuffer + (iInsert + 1) * nItemSize, aBuffer + iInsert * nItemSize, itemsToMove * nItemSize); + memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); +} + +void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { + int itemsToMove; + + assert( nItemSize > 0 ); + assert( 0 <= iDelete && iDelete < nSize ); + + itemsToMove = nSize - iDelete - 1; + memmove(aBuffer + iDelete * nItemSize, aBuffer + (iDelete + 1) * nItemSize, itemsToMove * nItemSize); +} + /************************************************************************** ** DiskANN internals **************************************************************************/ @@ -212391,16 +212714,24 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ sqlite3_free(pNode); } -static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, unsigned int maxCandidates, int blobMode){ - pCtx->pQuery = pQuery; +static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; pCtx->maxCandidates = maxCandidates; + pCtx->aTopDistances = sqlite3_malloc(topCandidates * sizeof(double)); + pCtx->aTopCandidates = sqlite3_malloc(topCandidates * sizeof(DiskAnnNode*)); + 
pCtx->nTopCandidates = 0; + pCtx->maxTopCandidates = topCandidates; pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL ){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + goto out_oom; + } + loadVectorPair(&pCtx->query, pQuery); + + if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ goto out_oom; } return SQLITE_OK; @@ -212411,6 +212742,12 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, un if( pCtx->aCandidates != NULL ){ sqlite3_free(pCtx->aCandidates); } + if( pCtx->aTopDistances != NULL ){ + sqlite3_free(pCtx->aTopDistances); + } + if( pCtx->aTopCandidates != NULL ){ + sqlite3_free(pCtx->aTopCandidates); + } return SQLITE_NOMEM_BKPT; } @@ -212434,6 +212771,9 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){ } sqlite3_free(pCtx->aCandidates); sqlite3_free(pCtx->aDistances); + sqlite3_free(pCtx->aTopCandidates); + sqlite3_free(pCtx->aTopDistances); + deinitVectorPair(&pCtx->query); } // check if we visited this node earlier @@ -212475,7 +212815,9 @@ static int diskAnnSearchCtxShouldAddCandidate(const DiskAnnIndex *pIndex, const } // mark node as visited and put it in the head of visitedList -static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode){ +static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode, float distance){ + int iInsert; + assert( pCtx->nUnvisited > 0 ); assert( pNode->visited == 0 ); @@ -212484,56 +212826,51 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo pNode->pNext = pCtx->visitedList; pCtx->visitedList = pNode; + + iInsert = distanceBufferInsertIdx(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, distance); + if( iInsert < 0 ){ + return; + } + 
bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); + bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } static int diskAnnSearchCtxHasUnvisited(const DiskAnnSearchCtx *pCtx){ return pCtx->nUnvisited > 0; } -static DiskAnnNode* diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i){ +static void diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i, DiskAnnNode **ppNode, float *pDistance){ assert( 0 <= i && i < pCtx->nCandidates ); - return pCtx->aCandidates[i]; + *ppNode = pCtx->aCandidates[i]; + *pDistance = pCtx->aDistances[i]; } static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete){ int i; - assert( 0 <= iDelete && iDelete < pCtx->nCandidates ); assert( pCtx->nUnvisited > 0 ); assert( !pCtx->aCandidates[iDelete]->visited ); assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); + bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); - for(i = iDelete + 1; i < pCtx->nCandidates; i++){ - pCtx->aCandidates[i - 1] = pCtx->aCandidates[i]; - pCtx->aDistances[i - 1] = pCtx->aDistances[i]; - } pCtx->nCandidates--; pCtx->nUnvisited--; } -static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float candidateDist){ - int i; - assert( 0 <= iInsert && iInsert <= pCtx->nCandidates && iInsert < pCtx->maxCandidates ); - if( pCtx->nCandidates < pCtx->maxCandidates ){ - pCtx->nCandidates++; - } else { - DiskAnnNode *pLast = pCtx->aCandidates[pCtx->nCandidates - 1]; - if( !pLast->visited ){ - // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node - assert( 
pLast->pBlobSpot == NULL ); - pCtx->nUnvisited--; - diskAnnNodeFree(pLast); - } - } - // Shift the candidates to the right to make space for the new one. - for(i = pCtx->nCandidates - 1; i > iInsert; i--){ - pCtx->aCandidates[i] = pCtx->aCandidates[i - 1]; - pCtx->aDistances[i] = pCtx->aDistances[i - 1]; - } - // Insert the new candidate. - pCtx->aCandidates[iInsert] = pCandidate; - pCtx->aDistances[iInsert] = candidateDist; +static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ + DiskAnnNode *pLast = NULL; + bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); + bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); + if( pLast != NULL && !pLast->visited ){ + // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node + assert( pLast->pBlobSpot == NULL ); + pCtx->nUnvisited--; + diskAnnNodeFree(pLast); + } pCtx->nUnvisited++; } @@ -212563,7 +212900,14 @@ static int diskAnnSearchCtxFindClosestCandidateIdx(const DiskAnnSearchCtx *pCtx) // return position for new edge(C) which will replace previous edge on that position or -1 if we should ignore it // we also check that no current edge(B) will "prune" new vertex: i.e. 
dist(B, C) >= (means worse than) alpha * dist(node, C) for all current edges // if any edge(B) will "prune" new edge(C) we will ignore it (return -1) -static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, u64 newRowid, const Vector *pNewVector) { +static int diskAnnReplaceEdgeIdx( + const DiskAnnIndex *pIndex, + BlobSpot *pNodeBlob, + u64 newRowid, + VectorPair *pNewVector, + VectorPair *pPlaceholder, + float *pNodeToNew +) { int i, nEdges, nMaxEdges, iReplace = -1; Vector nodeVector, edgeVector; float nodeToNew, nodeToReplace; @@ -212571,20 +212915,27 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob nEdges = nodeBinEdges(pIndex, pNodeBlob); nMaxEdges = nodeEdgesMaxCount(pIndex); nodeBinVector(pIndex, pNodeBlob, &nodeVector); - nodeToNew = diskAnnVectorDistance(pIndex, &nodeVector, pNewVector); + loadVectorPair(pPlaceholder, &nodeVector); + + // we need to evaluate potentially approximate distance here in order to correctly compare it with edge distances + nodeToNew = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, pNewVector->pEdge); + *pNodeToNew = nodeToNew; for(i = nEdges - 1; i >= 0; i--){ u64 edgeRowid; float edgeToNew, nodeToEdge; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( edgeRowid == newRowid ){ // deletes can leave "zombie" edges in the graph and we must override them and not store duplicate edges in the node return i; } - edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector); - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + + edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector->pEdge); if( nodeToNew > pIndex->pruningAlpha * edgeToNew ){ return -1; } @@ -212602,12 +212953,14 @@ static int 
diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob // prune edges after we inserted new edge at position iInserted // we only need to check for edges which will be pruned by new vertex // no need to check for other pairs as we checked them on previous insertions -static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted) { +static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted, VectorPair *pPlaceholder) { int i, s, nEdges; - Vector nodeVector, hintVector; + Vector nodeVector, hintEdgeVector; u64 hintRowid; nodeBinVector(pIndex, pNodeBlob, &nodeVector); + loadVectorPair(pPlaceholder, &nodeVector); + nEdges = nodeBinEdges(pIndex, pNodeBlob); assert( 0 <= iInserted && iInserted < nEdges ); @@ -212617,7 +212970,7 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i nodeBinDebug(pIndex, pNodeBlob); #endif - nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintVector); + nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, NULL, &hintEdgeVector); // remove edges which is no longer interesting due to the addition of iInserted i = 0; @@ -212625,14 +212978,17 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i Vector edgeVector; float nodeToEdge, hintToEdge; u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( hintRowid == edgeRowid ){ i++; continue; } - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); - hintToEdge = diskAnnVectorDistance(pIndex, &hintVector, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + + hintToEdge = diskAnnVectorDistance(pIndex, &hintEdgeVector, &edgeVector); if( nodeToEdge > pIndex->pruningAlpha * hintToEdge ){ nodeBinDeleteEdge(pIndex, pNodeBlob, i); nEdges--; @@ 
-212681,7 +213037,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u } nodeBinVector(pIndex, start->pBlobSpot, &startVector); - startDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &startVector); + startDistance = diskAnnVectorDistance(pIndex, pCtx->query.pNode, &startVector); if( pCtx->blobMode == DISKANN_BLOB_READONLY ){ assert( start->pBlobSpot != NULL ); @@ -212698,8 +213054,9 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u Vector vCandidate; DiskAnnNode *pCandidate; BlobSpot *pCandidateBlob; + float distance; int iCandidate = diskAnnSearchCtxFindClosestCandidateIdx(pCtx); - pCandidate = diskAnnSearchCtxGetCandidate(pCtx, iCandidate); + diskAnnSearchCtxGetCandidate(pCtx, iCandidate, &pCandidate, &distance); rc = SQLITE_OK; if( pReusableBlobSpot != NULL ){ @@ -212727,25 +213084,30 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u goto out; } - diskAnnSearchCtxMarkVisited(pCtx, pCandidate); - nVisited += 1; DiskAnnTrace(("visiting candidate(%d): id=%lld\n", nVisited, pCandidate->nRowid)); nodeBinVector(pIndex, pCandidateBlob, &vCandidate); nEdges = nodeBinEdges(pIndex, pCandidateBlob); + // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it + if( pCtx->query.pNode != pCtx->query.pEdge ){ + distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->query.pNode); + } + + diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance); + for(i = 0; i < nEdges; i++){ u64 edgeRowid; Vector edgeVector; float edgeDistance; int iInsert; DiskAnnNode *pNewCandidate; - nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, NULL, &edgeVector); if( diskAnnSearchCtxIsVisited(pCtx, edgeRowid) || diskAnnSearchCtxHasCandidate(pCtx, edgeRowid) ){ continue; } - edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &edgeVector); + edgeDistance = 
diskAnnVectorDistance(pIndex, pCtx->query.pEdge, &edgeVector); iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance); if( iInsert < 0 ){ continue; @@ -212822,7 +213184,7 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search"); return rc; } - rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, DISKANN_BLOB_READONLY); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context"); goto out; @@ -212831,7 +213193,7 @@ int diskAnnSearch( if( rc != SQLITE_OK ){ goto out; } - nOutRows = MIN(k, ctx.nCandidates); + nOutRows = MIN(k, ctx.nTopCandidates); rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, vectorIdxKeyRowidLike(pKey)); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to allocate output rows"); @@ -212839,9 +213201,9 @@ int diskAnnSearch( } for(i = 0; i < nOutRows; i++){ if( pRows->aIntValues != NULL ){ - rc = vectorOutRowsPut(pRows, i, 0, &ctx.aCandidates[i]->nRowid, NULL); + rc = vectorOutRowsPut(pRows, i, 0, &ctx.aTopCandidates[i]->nRowid, NULL); }else{ - rc = diskAnnGetShadowRowKeys(pIndex, ctx.aCandidates[i]->nRowid, pKey, pRows, i); + rc = diskAnnGetShadowRowKeys(pIndex, ctx.aTopCandidates[i]->nRowid, pKey, pRows, i); } if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to put result in the output row"); @@ -212865,6 +213227,9 @@ int diskAnnInsert( BlobSpot *pBlobSpot = NULL; DiskAnnNode *pVisited; DiskAnnSearchCtx ctx; + VectorPair vInsert, vCandidate; + vInsert.pNode = NULL; vInsert.pEdge = NULL; + vCandidate.pNode = NULL; vCandidate.pEdge = NULL; if( pVectorInRow->pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); @@ 
-212877,12 +213242,24 @@ int diskAnnInsert( DiskAnnTrace(("diskAnnInset started\n")); - rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, DISKANN_BLOB_WRITABLE); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context"); return rc; } + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vInsert) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for node VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vCandidate) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for candidate VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + // note: we must select random row before we will insert new row in the shadow table rc = diskAnnSelectRandomShadowRow(pIndex, &nStartRowid); if( rc == SQLITE_DONE ){ @@ -212920,28 +213297,33 @@ int diskAnnInsert( } // first pass - add all visited nodes as a potential neighbours of new node for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ - Vector vector; + Vector nodeVector; int iReplace; + float nodeToNew; - nodeBinVector(pIndex, pVisited->pBlobSpot, &vector); - iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vector); + nodeBinVector(pIndex, pVisited->pBlobSpot, &nodeVector); + loadVectorPair(&vCandidate, &nodeVector); + + iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, &vector); - diskAnnPruneEdges(pIndex, pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, nodeToNew, vCandidate.pEdge); + 
diskAnnPruneEdges(pIndex, pBlobSpot, iReplace, &vInsert); } // second pass - add new node as a potential neighbour of all visited nodes + loadVectorPair(&vInsert, pVectorInRow->pVector); for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ int iReplace; + float nodeToNew; - iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, pVectorInRow->pVector); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, pVectorInRow->pVector); - diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, nodeToNew, vInsert.pEdge); + diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace, &vCandidate); rc = blobSpotFlush(pIndex, pVisited->pBlobSpot); if( rc != SQLITE_OK ){ @@ -212952,6 +213334,8 @@ int diskAnnInsert( rc = SQLITE_OK; out: + deinitVectorPair(&vInsert); + deinitVectorPair(&vCandidate); if( rc == SQLITE_OK ){ rc = blobSpotFlush(pIndex, pBlobSpot); if( rc != SQLITE_OK ){ @@ -213003,7 +213387,7 @@ int diskAnnDelete( nNeighbours = nodeBinEdges(pIndex, pNodeBlob); for(i = 0; i < nNeighbours; i++){ u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL, NULL); rc = blobSpotReload(pIndex, pEdgeBlob, edgeRowid, pIndex->nBlockSize); if( rc == DISKANN_ROW_NOT_FOUND ){ continue; @@ -213050,6 +213434,7 @@ int diskAnnOpenIndex( ){ DiskAnnIndex *pIndex; u64 nBlockSize; + int compressNeighbours; pIndex = sqlite3DbMallocRaw(db, sizeof(DiskAnnIndex)); if( pIndex == NULL ){ return SQLITE_NOMEM; @@ -213096,11 +213481,20 @@ int diskAnnOpenIndex( pIndex->searchL = VECTOR_SEARCH_L_DEFAULT; } pIndex->nNodeVectorSize = vectorDataSize(pIndex->nNodeVectorType, pIndex->nVectorDims); - // will change in future when we will support compression of edges 
vectors - pIndex->nEdgeVectorType = pIndex->nNodeVectorType; - pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + + compressNeighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( compressNeighbours == 0 ){ + pIndex->nEdgeVectorType = pIndex->nNodeVectorType; + pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + pIndex->nEdgeVectorType = compressNeighbours; + pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); + }else{ + return SQLITE_ERROR; + } *ppIndex = pIndex; + DiskAnnTrace(("opened index %s: max edges %d\n", zIdxName, nodeEdgesMaxCount(pIndex))); return SQLITE_OK; } @@ -213216,26 +213610,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -size_t vectorF32DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - float *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT32; - pVector->dims = nBlobSize / sizeof(float); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 0 || pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT32 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF32(pBlob); - pBlob += sizeof(float); - } - return vectorDataSize(pVector->type, pVector->dims); -} - void vectorF32Serialize( sqlite3_context *context, const Vector *pVector @@ -213342,32 +213716,22 @@ void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF32ParseSqliteBlob( - sqlite3_value *arg, +void vectorF32DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; float *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( 
sqlite3_value_bytes(arg) < sizeof(float) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f32 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(float) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); pBlob += sizeof(float); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ @@ -213474,57 +213838,6 @@ size_t vectorF64SerializeToBlob( return sizeof(double) * pVector->dims; } -size_t vectorF64DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - double *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT64; - pVector->dims = nBlobSize / sizeof(double); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 1 && pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT64 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF64(pBlob); - pBlob += sizeof(double); - } - return vectorDataSize(pVector->type, pVector->dims); -} - -void vectorF64Serialize( - sqlite3_context *context, - const Vector *pVector -){ - double *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT64 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - // allocate one extra trailing byte with vector blob type metadata - nBlobSize = vectorDataSize(pVector->type, pVector->dims) + 1; - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF64SerializeToBlob(pVector, pBlob, nBlobSize - 1); - pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_DOUBLE_CHAR_LIMIT 32 void vectorF64MarshalToText( sqlite3_context *context, @@ -213603,32 +213916,22 @@ void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n 
pVector->data = (void*)pBlob; } -int vectorF64ParseSqliteBlob( - sqlite3_value *arg, +void vectorF64DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; double *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(double) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f64 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(double) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); pBlob += sizeof(double); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ @@ -214033,13 +214336,14 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, - { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, - { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, - { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, + { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, + { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, + { "max_neighbors", 
VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, }; static int parseVectorIdxParam(const char *zParam, VectorIdxParams *pParams, const char **pErrMsg) { @@ -214439,7 +214743,7 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int i, rc = SQLITE_OK; int dims, type; int hasLibsqlVectorIdxFn = 0, hasCollation = 0; - const char *pzErrMsg; + const char *pzErrMsg = NULL; assert( zDbSName != NULL ); @@ -214551,9 +214855,13 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: unsupported for tables without ROWID and composite primary key"); return CREATE_FAIL; } - rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams); + rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams, &pzErrMsg); if( rc != SQLITE_OK ){ - sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + if( pzErrMsg != NULL ){ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann: %s", pzErrMsg); + }else{ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + } return CREATE_FAIL; } rc = insertIndexParameters(db, zDbSName, pIdx->zName, &idxParams); From 8441108e713d2aa67606cf91385fe82c066045d7 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 18:31:38 +0400 Subject: [PATCH 048/121] allow vector index to be partial --- libsql-sqlite3/src/vectorIndex.c | 5 --- libsql-sqlite3/test/libsql_vector_index.test | 33 ++++++++++++++++++-- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index d8b3497781..001a1aae10 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -862,11 +862,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: must contain exactly one column wrapped into the " VECTOR_INDEX_MARKER_FUNCTION " function"); return 
CREATE_FAIL; } - // we are able to support this but I doubt this works for now - more polishing required to make this work - if( pIdx->pPartIdxWhere != NULL ) { - sqlite3ErrorMsg(pParse, "vector index: where condition is forbidden"); - return CREATE_FAIL; - } pArgsList = pIdx->aColExpr->a[0].pExpr->x.pList; pListItem = pArgsList->a; diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index c1a270e4da..a173c773d3 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -275,6 +275,36 @@ do_execsql_test vector-all-params { SELECT * FROM vector_top_k('t_all_params_idx', vector('[1,2]'), 2); } {1 2} +do_execsql_test vector-partial { + CREATE TABLE t_partial( name TEXT, type INT, v FLOAT32(3)); + INSERT INTO t_partial VALUES ( 'a', 0, vector('[1,2,3]') ); + INSERT INTO t_partial VALUES ( 'b', 1, vector('[3,4,5]') ); + INSERT INTO t_partial VALUES ( 'c', 2, vector('[4,5,6]') ); + INSERT INTO t_partial VALUES ( 'd', 0, vector('[5,6,7]') ); + INSERT INTO t_partial VALUES ( 'e', 1, vector('[6,7,8]') ); + INSERT INTO t_partial VALUES ( 'f', 2, vector('[7,8,9]') ); + CREATE INDEX t_partial_idx_0 ON t_partial( libsql_vector_idx(v) ) WHERE type = 0; + CREATE INDEX t_partial_idx_1 ON t_partial( libsql_vector_idx(v) ) WHERE type = 1; + CREATE INDEX t_partial_idx_not_0 ON t_partial( libsql_vector_idx(v) ) WHERE type != 0; + SELECT id FROM vector_top_k('t_partial_idx_0', vector('[1,2,3]'), 10); + SELECT id FROM vector_top_k('t_partial_idx_1', vector('[1,2,3]'), 10); + SELECT id FROM vector_top_k('t_partial_idx_not_0', vector('[1,2,3]'), 10); + INSERT INTO t_partial VALUES ( 'g', 0, vector('[8,9,10]') ); + INSERT INTO t_partial VALUES ( 'h', 1, vector('[9,10,11]') ); + INSERT INTO t_partial VALUES ( 'i', 2, vector('[10,11,12]') ); + SELECT id FROM vector_top_k('t_partial_idx_0', vector('[1,2,3]'), 10); + SELECT id FROM vector_top_k('t_partial_idx_1', vector('[1,2,3]'), 10); + 
SELECT id FROM vector_top_k('t_partial_idx_not_0', vector('[1,2,3]'), 10); +} { + 1 4 + 2 5 + 2 3 5 6 + + 1 4 7 + 2 5 8 + 2 3 5 6 8 9 +} + proc error_messages {sql} { set ret "" catch { @@ -309,8 +339,6 @@ do_test vector-errors { sqlite3_exec db { CREATE TABLE t_mixed_t( v FLOAT32(3)); } sqlite3_exec db { INSERT INTO t_mixed_t VALUES('[1]'); } lappend ret [error_messages {CREATE INDEX t_mixed_t_idx ON t_mixed_t( libsql_vector_idx(v) )}] - sqlite3_exec db { CREATE TABLE t_partial( name TEXT, type INT, v FLOAT32(3)); } - lappend ret [error_messages {CREATE INDEX t_partial_idx ON t_partial( libsql_vector_idx(v) ) WHERE type = 0}] } [list {*}{ {no such table: main.t_no} {no such column: v} @@ -328,5 +356,4 @@ do_test vector-errors { {vector index(insert): only f32 vectors are supported} {vector index(search): dimensions are different: 2 != 4} {vector index(insert): dimensions are different: 1 != 3} - {vector index: where condition is forbidden} }] From ec996fabf462d841b9256f032c102db7d2a95e79 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 19:40:55 +0400 Subject: [PATCH 049/121] build bundles --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 5 ----- libsql-ffi/bundled/src/sqlite3.c | 5 ----- 2 files changed, 10 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 15d09606fb..3a76f9cff3 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -214523,11 +214523,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: must contain exactly one column wrapped into the " VECTOR_INDEX_MARKER_FUNCTION " function"); return CREATE_FAIL; } - // we are able to support this but I doubt this works for now - more polishing required to make this work - if( pIdx->pPartIdxWhere != NULL ) { - sqlite3ErrorMsg(pParse, 
"vector index: where condition is forbidden"); - return CREATE_FAIL; - } pArgsList = pIdx->aColExpr->a[0].pExpr->x.pList; pListItem = pArgsList->a; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 15d09606fb..3a76f9cff3 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -214523,11 +214523,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: must contain exactly one column wrapped into the " VECTOR_INDEX_MARKER_FUNCTION " function"); return CREATE_FAIL; } - // we are able to support this but I doubt this works for now - more polishing required to make this work - if( pIdx->pPartIdxWhere != NULL ) { - sqlite3ErrorMsg(pParse, "vector index: where condition is forbidden"); - return CREATE_FAIL; - } pArgsList = pIdx->aColExpr->a[0].pExpr->x.pList; pListItem = pArgsList->a; From 85f1182c0e8b9677e6364594ee5d4ca63d3c9bed Mon Sep 17 00:00:00 2001 From: Avinash Sajjanshetty Date: Fri, 9 Aug 2024 12:47:44 +0530 Subject: [PATCH 050/121] Document path based routing usage in multi tenant databases --- docs/USER_GUIDE.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/USER_GUIDE.md b/docs/USER_GUIDE.md index f9d03fa508..96e2089e98 100644 --- a/docs/USER_GUIDE.md +++ b/docs/USER_GUIDE.md @@ -237,6 +237,12 @@ For example, if you have the following entries in your `/etc/hosts` file: You can access `db1` with the `http://db1.local:8080`URL and `db2` with `http://db2.local:8080`. 
The database files for the databases are stored in `/dbs/db1` and ` Date: Fri, 9 Aug 2024 13:55:56 +0400 Subject: [PATCH 051/121] small fixes --- libsql-sqlite3/src/vector.c | 15 +++++++------- libsql-sqlite3/src/vector1bit.c | 5 +++-- libsql-sqlite3/src/vectorIndex.c | 4 ++-- libsql-sqlite3/src/vectorInt.h | 4 ++-- libsql-sqlite3/src/vectordiskann.c | 13 ++++++------ libsql-sqlite3/src/vectorfloat32.c | 33 +++--------------------------- libsql-sqlite3/src/vectorfloat64.c | 8 ++++++-- 7 files changed, 30 insertions(+), 52 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 73b84b047e..c622d977e1 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -42,6 +42,7 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: + assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -252,7 +253,7 @@ static int vectorParseSqliteText( return -1; } -int vectorParseSqliteBlob( +int vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg @@ -362,14 +363,14 @@ int detectVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pD } } -int vectorParse( +int vectorParseWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ switch( sqlite3_value_type(arg) ){ case SQLITE_BLOB: - return vectorParseSqliteBlob(arg, pVector, pzErrMsg); + return vectorParseSqliteBlobWithType(arg, pVector, pzErrMsg); case SQLITE_TEXT: return vectorParseSqliteText(arg, pVector, pzErrMsg); default: @@ -531,7 +532,7 @@ static void vectorFuncHintedType( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free_vec; @@ -581,7 +582,7 @@ static void vectorExtractFunc( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, 
&pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -636,12 +637,12 @@ static void vectorDistanceCosFunc( if( pVector2==NULL ){ goto out_free; } - if( vectorParse(argv[0], pVector1, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; } - if( vectorParse(argv[1], pVector2, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index 66da59f76a..f4fd5f9100 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -41,10 +41,11 @@ void vector1BitDump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_1BIT ); + printf("f1bit: ["); for(i = 0; i < pVec->dims; i++){ - printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); } - printf("\n"); + printf("]\n"); } /************************************************************************** diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 5a13b4ea60..983e94cb9f 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -261,7 +261,7 @@ int vectorInRowAlloc(sqlite3 *db, const UnpackedRecord *pRecord, VectorInRow *pV vectorInitFromBlob(pVectorInRow->pVector, sqlite3_value_blob(pVectorValue), sqlite3_value_bytes(pVectorValue)); } else if( sqlite3_value_type(pVectorValue) == SQLITE_TEXT ){ // users can put strings (e.g. 
'[1,2,3]') in the table and we should process them correctly - if( vectorParse(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } @@ -982,7 +982,7 @@ int vectorIndexSearch( rc = SQLITE_NOMEM_BKPT; goto out; } - if( vectorParse(argv[1], pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[1], pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index 84cf9c0d1f..efe8f3cf38 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -42,7 +42,7 @@ struct Vector { size_t vectorDataSize(VectorType, VectorDims); Vector *vectorAlloc(VectorType, VectorDims); void vectorFree(Vector *v); -int vectorParse(sqlite3_value *, Vector *, char **); +int vectorParseWithType(sqlite3_value *, Vector *, char **); void vectorInit(Vector *, VectorType, VectorDims, void *); /* @@ -97,7 +97,7 @@ void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ -int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); +int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 88151bf1df..ae832f4400 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -47,7 +47,6 @@ ** diskAnnInsert() Insert single new(!) 
vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ -#include "vectorInt.h" #ifndef SQLITE_OMIT_VECTOR #include "math.h" @@ -84,7 +83,8 @@ typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; -// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation +// (pEdge pointer always equals to pNode if pNodeType == pEdgeType) struct VectorPair { int nodeType; int edgeType; @@ -966,15 +966,13 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - goto out_oom; + return SQLITE_NOMEM_BKPT; } loadVectorPair(&pCtx->query, pQuery); - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ - goto out_oom; + if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ + return SQLITE_OK; } - return SQLITE_OK; -out_oom: if( pCtx->aDistances != NULL ){ sqlite3_free(pCtx->aDistances); } @@ -987,6 +985,7 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC if( pCtx->aTopCandidates != NULL ){ sqlite3_free(pCtx->aTopCandidates); } + deinitVectorPair(&pCtx->query); return SQLITE_NOMEM_BKPT; } diff --git a/libsql-sqlite3/src/vectorfloat32.c b/libsql-sqlite3/src/vectorfloat32.c index 5d6641991c..d53d10d593 100644 --- a/libsql-sqlite3/src/vectorfloat32.c +++ b/libsql-sqlite3/src/vectorfloat32.c @@ -41,10 +41,11 @@ void vectorF32Dump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_FLOAT32 ); + printf("f32: ["); for(i = 
0; i < pVec->dims; i++){ - printf("%f ", elems[i]); + printf("%s%f", i == 0 ? "" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -94,34 +95,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -void vectorF32Serialize( - sqlite3_context *context, - const Vector *pVector -){ - float *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT32 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - nBlobSize = vectorDataSize(pVector->type, pVector->dims); - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_FLOAT_CHAR_LIMIT 32 void vectorF32MarshalToText( sqlite3_context *context, diff --git a/libsql-sqlite3/src/vectorfloat64.c b/libsql-sqlite3/src/vectorfloat64.c index 1d29c9c3d6..885306c8c6 100644 --- a/libsql-sqlite3/src/vectorfloat64.c +++ b/libsql-sqlite3/src/vectorfloat64.c @@ -38,10 +38,14 @@ void vectorF64Dump(const Vector *pVec){ double *elems = pVec->data; unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT64 ); + + printf("f64: ["); for(i = 0; i < pVec->dims; i++){ - printf("%lf ", elems[i]); + printf("%s%lf", i == 0 ? 
"" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** From 2aca12b6e13c0d9da3e10afede0842cf38917dc3 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 9 Aug 2024 13:57:36 +0400 Subject: [PATCH 052/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 82 +++++++------------ libsql-ffi/bundled/src/sqlite3.c | 82 +++++++------------ 2 files changed, 60 insertions(+), 104 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index eff02cf87a..6c60bccaab 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -85264,7 +85264,7 @@ struct Vector { size_t vectorDataSize(VectorType, VectorDims); Vector *vectorAlloc(VectorType, VectorDims); void vectorFree(Vector *v); -int vectorParse(sqlite3_value *, Vector *, char **); +int vectorParseWithType(sqlite3_value *, Vector *, char **); void vectorInit(Vector *, VectorType, VectorDims, void *); /* @@ -85319,7 +85319,7 @@ void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ -int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); +int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); @@ -210988,6 +210988,7 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: + assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -211198,7 +211199,7 @@ static int vectorParseSqliteText( return -1; } -int vectorParseSqliteBlob( +int vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg 
@@ -211308,14 +211309,14 @@ int detectVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pD } } -int vectorParse( +int vectorParseWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ switch( sqlite3_value_type(arg) ){ case SQLITE_BLOB: - return vectorParseSqliteBlob(arg, pVector, pzErrMsg); + return vectorParseSqliteBlobWithType(arg, pVector, pzErrMsg); case SQLITE_TEXT: return vectorParseSqliteText(arg, pVector, pzErrMsg); default: @@ -211477,7 +211478,7 @@ static void vectorFuncHintedType( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free_vec; @@ -211527,7 +211528,7 @@ static void vectorExtractFunc( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211582,12 +211583,12 @@ static void vectorDistanceCosFunc( if( pVector2==NULL ){ goto out_free; } - if( vectorParse(argv[0], pVector1, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; } - if( vectorParse(argv[1], pVector2, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211673,10 +211674,11 @@ void vector1BitDump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_1BIT ); + printf("f1bit: ["); for(i = 0; i < pVec->dims; i++){ - printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? 
+1 : -1); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -211808,7 +211810,6 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ ** diskAnnInsert() Insert single new(!) vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ -/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "math.h" */ @@ -211845,7 +211846,8 @@ typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; -// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation +// (pEdge pointer always equals to pNode if pNodeType == pEdgeType) struct VectorPair { int nodeType; int edgeType; @@ -212727,15 +212729,13 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - goto out_oom; + return SQLITE_NOMEM_BKPT; } loadVectorPair(&pCtx->query, pQuery); - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ - goto out_oom; + if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ + return SQLITE_OK; } - return SQLITE_OK; -out_oom: if( pCtx->aDistances != NULL ){ sqlite3_free(pCtx->aDistances); } @@ -212748,6 +212748,7 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC if( pCtx->aTopCandidates != NULL ){ sqlite3_free(pCtx->aTopCandidates); } + deinitVectorPair(&pCtx->query); return SQLITE_NOMEM_BKPT; } @@ -213557,10 +213558,11 @@ void 
vectorF32Dump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_FLOAT32 ); + printf("f32: ["); for(i = 0; i < pVec->dims; i++){ - printf("%f ", elems[i]); + printf("%s%f", i == 0 ? "" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -213610,34 +213612,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -void vectorF32Serialize( - sqlite3_context *context, - const Vector *pVector -){ - float *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT32 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - nBlobSize = vectorDataSize(pVector->type, pVector->dims); - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_FLOAT_CHAR_LIMIT 32 void vectorF32MarshalToText( sqlite3_context *context, @@ -213778,10 +213752,14 @@ void vectorF32DeserializeFromBlob( void vectorF64Dump(const Vector *pVec){ double *elems = pVec->data; unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT64 ); + + printf("f64: ["); for(i = 0; i < pVec->dims; i++){ - printf("%lf ", elems[i]); + printf("%s%lf", i == 0 ? "" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -214201,7 +214179,7 @@ int vectorInRowAlloc(sqlite3 *db, const UnpackedRecord *pRecord, VectorInRow *pV vectorInitFromBlob(pVectorInRow->pVector, sqlite3_value_blob(pVectorValue), sqlite3_value_bytes(pVectorValue)); } else if( sqlite3_value_type(pVectorValue) == SQLITE_TEXT ){ // users can put strings (e.g. 
'[1,2,3]') in the table and we should process them correctly - if( vectorParse(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } @@ -214922,7 +214900,7 @@ int vectorIndexSearch( rc = SQLITE_NOMEM_BKPT; goto out; } - if( vectorParse(argv[1], pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[1], pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index eff02cf87a..6c60bccaab 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -85264,7 +85264,7 @@ struct Vector { size_t vectorDataSize(VectorType, VectorDims); Vector *vectorAlloc(VectorType, VectorDims); void vectorFree(Vector *v); -int vectorParse(sqlite3_value *, Vector *, char **); +int vectorParseWithType(sqlite3_value *, Vector *, char **); void vectorInit(Vector *, VectorType, VectorDims, void *); /* @@ -85319,7 +85319,7 @@ void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ -int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); +int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); @@ -210988,6 +210988,7 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: + assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -211198,7 +211199,7 @@ static int vectorParseSqliteText( return -1; } -int vectorParseSqliteBlob( +int vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg @@ -211308,14 +211309,14 @@ int detectVectorParameters(sqlite3_value *arg, int typeHint, int 
*pType, int *pD } } -int vectorParse( +int vectorParseWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ switch( sqlite3_value_type(arg) ){ case SQLITE_BLOB: - return vectorParseSqliteBlob(arg, pVector, pzErrMsg); + return vectorParseSqliteBlobWithType(arg, pVector, pzErrMsg); case SQLITE_TEXT: return vectorParseSqliteText(arg, pVector, pzErrMsg); default: @@ -211477,7 +211478,7 @@ static void vectorFuncHintedType( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free_vec; @@ -211527,7 +211528,7 @@ static void vectorExtractFunc( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211582,12 +211583,12 @@ static void vectorDistanceCosFunc( if( pVector2==NULL ){ goto out_free; } - if( vectorParse(argv[0], pVector1, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; } - if( vectorParse(argv[1], pVector2, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211673,10 +211674,11 @@ void vector1BitDump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_1BIT ); + printf("f1bit: ["); for(i = 0; i < pVec->dims; i++){ - printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? 
+1 : -1); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -211808,7 +211810,6 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ ** diskAnnInsert() Insert single new(!) vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ -/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "math.h" */ @@ -211845,7 +211846,8 @@ typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; -// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation +// (pEdge pointer always equals to pNode if pNodeType == pEdgeType) struct VectorPair { int nodeType; int edgeType; @@ -212727,15 +212729,13 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - goto out_oom; + return SQLITE_NOMEM_BKPT; } loadVectorPair(&pCtx->query, pQuery); - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ - goto out_oom; + if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ + return SQLITE_OK; } - return SQLITE_OK; -out_oom: if( pCtx->aDistances != NULL ){ sqlite3_free(pCtx->aDistances); } @@ -212748,6 +212748,7 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC if( pCtx->aTopCandidates != NULL ){ sqlite3_free(pCtx->aTopCandidates); } + deinitVectorPair(&pCtx->query); return SQLITE_NOMEM_BKPT; } @@ -213557,10 +213558,11 @@ void 
vectorF32Dump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_FLOAT32 ); + printf("f32: ["); for(i = 0; i < pVec->dims; i++){ - printf("%f ", elems[i]); + printf("%s%f", i == 0 ? "" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -213610,34 +213612,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -void vectorF32Serialize( - sqlite3_context *context, - const Vector *pVector -){ - float *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT32 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - nBlobSize = vectorDataSize(pVector->type, pVector->dims); - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_FLOAT_CHAR_LIMIT 32 void vectorF32MarshalToText( sqlite3_context *context, @@ -213778,10 +213752,14 @@ void vectorF32DeserializeFromBlob( void vectorF64Dump(const Vector *pVec){ double *elems = pVec->data; unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT64 ); + + printf("f64: ["); for(i = 0; i < pVec->dims; i++){ - printf("%lf ", elems[i]); + printf("%s%lf", i == 0 ? "" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -214201,7 +214179,7 @@ int vectorInRowAlloc(sqlite3 *db, const UnpackedRecord *pRecord, VectorInRow *pV vectorInitFromBlob(pVectorInRow->pVector, sqlite3_value_blob(pVectorValue), sqlite3_value_bytes(pVectorValue)); } else if( sqlite3_value_type(pVectorValue) == SQLITE_TEXT ){ // users can put strings (e.g. 
'[1,2,3]') in the table and we should process them correctly - if( vectorParse(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } @@ -214922,7 +214900,7 @@ int vectorIndexSearch( rc = SQLITE_NOMEM_BKPT; goto out; } - if( vectorParse(argv[1], pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[1], pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } From 7fe8f96964711dec9c724c32abb5c6665a3c3822 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 2 Aug 2024 13:23:10 -0700 Subject: [PATCH 053/121] c: add replicated data for sync This adds a new replicated struct that is output during sync for the C bindings. --- bindings/c/include/libsql.h | 7 ++++++- bindings/c/src/lib.rs | 9 +++++++-- bindings/c/src/types.rs | 6 ++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/bindings/c/include/libsql.h b/bindings/c/include/libsql.h index 8178980466..c1285d4fff 100644 --- a/bindings/c/include/libsql.h +++ b/bindings/c/include/libsql.h @@ -27,6 +27,11 @@ typedef struct libsql_stmt libsql_stmt; typedef const libsql_database *libsql_database_t; +typedef struct { + uintptr_t frame_no; + uintptr_t frames_synced; +} replicated; + typedef struct { const char *db_path; const char *primary_url; @@ -56,7 +61,7 @@ typedef struct { extern "C" { #endif // __cplusplus -int libsql_sync(libsql_database_t db, const char **out_err_msg); +int libsql_sync(libsql_database_t db, replicated *out_replicated, const char **out_err_msg); int libsql_open_sync(const char *db_path, const char *primary_url, diff --git a/bindings/c/src/lib.rs b/bindings/c/src/lib.rs index 96e4effd2a..3376c4e78b 100644 --- a/bindings/c/src/lib.rs +++ b/bindings/c/src/lib.rs @@ -11,7 +11,7 @@ use tokio::runtime::Runtime; use types::{ blob, libsql_connection, libsql_connection_t, libsql_database, libsql_database_t, libsql_row, libsql_row_t, libsql_rows, libsql_rows_future_t, 
libsql_rows_t, libsql_stmt, libsql_stmt_t, - stmt, + replicated, stmt, }; lazy_static! { @@ -34,11 +34,16 @@ unsafe fn set_err_msg(msg: String, output: *mut *const std::ffi::c_char) { #[no_mangle] pub unsafe extern "C" fn libsql_sync( db: libsql_database_t, + out_replicated: *mut replicated, out_err_msg: *mut *const std::ffi::c_char, ) -> std::ffi::c_int { let db = db.get_ref(); match RT.block_on(db.sync()) { - Ok(_) => 0, + Ok(replicated) => { + (*out_replicated).frame_no = replicated.frame_no().unwrap_or(0) as usize; + (*out_replicated).frames_synced = replicated.frames_synced() as usize; + 0 + } Err(e) => { set_err_msg(format!("Error syncing database: {e}"), out_err_msg); 1 diff --git a/bindings/c/src/types.rs b/bindings/c/src/types.rs index 9f818e28d4..2ec399973a 100644 --- a/bindings/c/src/types.rs +++ b/bindings/c/src/types.rs @@ -115,6 +115,12 @@ impl From<&mut libsql_connection> for libsql_connection_t { } } +#[repr(C)] +pub struct replicated { + pub frame_no: usize, + pub frames_synced: usize, +} + pub struct stmt { pub stmt: libsql::Statement, pub params: Vec, From fb852624188cd78babfc1463f631dd870e9b826c Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Tue, 6 Aug 2024 15:02:25 -0400 Subject: [PATCH 054/121] c: rename sync method to to avoid breaking change --- bindings/c/include/libsql.h | 8 +++++--- bindings/c/src/lib.rs | 22 ++++++++++++++++++++-- bindings/c/src/types.rs | 4 ++-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/bindings/c/include/libsql.h b/bindings/c/include/libsql.h index c1285d4fff..7fdfb4b3c0 100644 --- a/bindings/c/include/libsql.h +++ b/bindings/c/include/libsql.h @@ -28,8 +28,8 @@ typedef struct libsql_stmt libsql_stmt; typedef const libsql_database *libsql_database_t; typedef struct { - uintptr_t frame_no; - uintptr_t frames_synced; + int frame_no; + int frames_synced; } replicated; typedef struct { @@ -61,7 +61,9 @@ typedef struct { extern "C" { #endif // __cplusplus -int libsql_sync(libsql_database_t db, 
replicated *out_replicated, const char **out_err_msg); +int libsql_sync(libsql_database_t db, const char **out_err_msg); + +int libsql_sync2(libsql_database_t db, replicated *out_replicated, const char **out_err_msg); int libsql_open_sync(const char *db_path, const char *primary_url, diff --git a/bindings/c/src/lib.rs b/bindings/c/src/lib.rs index 3376c4e78b..6cb1dc096e 100644 --- a/bindings/c/src/lib.rs +++ b/bindings/c/src/lib.rs @@ -33,6 +33,21 @@ unsafe fn set_err_msg(msg: String, output: *mut *const std::ffi::c_char) { #[no_mangle] pub unsafe extern "C" fn libsql_sync( + db: libsql_database_t, + out_err_msg: *mut *const std::ffi::c_char, +) -> std::ffi::c_int { + let db = db.get_ref(); + match RT.block_on(db.sync()) { + Ok(_) => 0, + Err(e) => { + set_err_msg(format!("Error syncing database: {e}"), out_err_msg); + 1 + } + } +} + +#[no_mangle] +pub unsafe extern "C" fn libsql_sync2( db: libsql_database_t, out_replicated: *mut replicated, out_err_msg: *mut *const std::ffi::c_char, @@ -40,8 +55,11 @@ pub unsafe extern "C" fn libsql_sync( let db = db.get_ref(); match RT.block_on(db.sync()) { Ok(replicated) => { - (*out_replicated).frame_no = replicated.frame_no().unwrap_or(0) as usize; - (*out_replicated).frames_synced = replicated.frames_synced() as usize; + if !out_replicated.is_null() { + (*out_replicated).frame_no = replicated.frame_no().unwrap_or(0) as i32; + (*out_replicated).frames_synced = replicated.frames_synced() as i32; + } + 0 } Err(e) => { diff --git a/bindings/c/src/types.rs b/bindings/c/src/types.rs index 2ec399973a..5d9f0b517f 100644 --- a/bindings/c/src/types.rs +++ b/bindings/c/src/types.rs @@ -117,8 +117,8 @@ impl From<&mut libsql_connection> for libsql_connection_t { #[repr(C)] pub struct replicated { - pub frame_no: usize, - pub frames_synced: usize, + pub frame_no: std::ffi::c_int, + pub frames_synced: std::ffi::c_int, } pub struct stmt { From 7c4ea18c75598e7cd2b84d2940d9c6a0f380d571 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 
2024 10:22:19 +0200 Subject: [PATCH 055/121] abstract replicator injector and introduce SqliteInjector --- bottomless/src/replicator.rs | 9 +- libsql-replication/src/injector/error.rs | 2 + libsql-replication/src/injector/mod.rs | 298 +-------------- .../injector/{ => sqlite_injector}/headers.rs | 0 .../{ => sqlite_injector}/injector_wal.rs | 0 .../src/injector/sqlite_injector/mod.rs | 345 ++++++++++++++++++ libsql-replication/src/replicator.rs | 95 ++--- libsql/src/replication/mod.rs | 7 +- 8 files changed, 418 insertions(+), 338 deletions(-) rename libsql-replication/src/injector/{ => sqlite_injector}/headers.rs (100%) rename libsql-replication/src/injector/{ => sqlite_injector}/injector_wal.rs (100%) create mode 100644 libsql-replication/src/injector/sqlite_injector/mod.rs diff --git a/bottomless/src/replicator.rs b/bottomless/src/replicator.rs index f2ef812f75..26e190df66 100644 --- a/bottomless/src/replicator.rs +++ b/bottomless/src/replicator.rs @@ -17,6 +17,7 @@ use aws_sdk_s3::primitives::ByteStream; use aws_sdk_s3::{Client, Config}; use bytes::{Buf, Bytes}; use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; +use libsql_replication::injector::Injector as _; use libsql_sys::{Cipher, EncryptionConfig}; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -1449,12 +1450,12 @@ impl Replicator { db_path: &Path, ) -> Result { let encryption_config = self.encryption_config.clone(); - let mut injector = libsql_replication::injector::Injector::new( - db_path, + let mut injector = libsql_replication::injector::SqliteInjector::new( + db_path.to_path_buf(), 4096, libsql_sys::connection::NO_AUTOCHECKPOINT, encryption_config, - )?; + ).await?; let prefix = format!("{}-{}/", self.db_name, generation); let mut page_buf = { let mut v = Vec::with_capacity(page_size); @@ -1552,7 +1553,7 @@ impl Replicator { }, page_buf.as_slice(), ); - injector.inject_frame(frame_to_inject)?; + injector.inject_frame(frame_to_inject).await?; applied_wal_frame = true; } } diff --git 
a/libsql-replication/src/injector/error.rs b/libsql-replication/src/injector/error.rs index 14899089ea..b1cebfe28b 100644 --- a/libsql-replication/src/injector/error.rs +++ b/libsql-replication/src/injector/error.rs @@ -1,3 +1,5 @@ +pub type Result = std::result::Result; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("IO error: {0}")] diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 80443964fe..1d69ae0aab 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,299 +1,27 @@ -use std::path::Path; -use std::sync::Arc; -use std::{collections::VecDeque, path::PathBuf}; +use std::future::Future; -use parking_lot::Mutex; -use rusqlite::OpenFlags; +pub use sqlite_injector::SqliteInjector; use crate::frame::{Frame, FrameNo}; +use error::Result; pub use error::Error; -use self::injector_wal::{ - InjectorWal, InjectorWalManager, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN, -}; - mod error; -mod headers; -mod injector_wal; - -#[derive(Debug)] -pub enum InjectError {} - -pub type FrameBuffer = Arc>>; - -pub struct Injector { - /// The injector is in a transaction state - is_txn: bool, - /// Buffer for holding current transaction frames - buffer: FrameBuffer, - /// Maximum capacity of the frame buffer - capacity: usize, - /// Injector connection - // connection must be dropped before the hook context - connection: Arc>>, - biggest_uncommitted_seen: FrameNo, - - // Connection config items used to recreate the injection connection - path: PathBuf, - encryption_config: Option, - auto_checkpoint: u32, -} - -/// Methods from this trait are called before and after performing a frame injection. -/// This trait trait is used to record the last committed frame_no to the log. 
-/// The implementer can persist the pre and post commit frame no, and compare them in the event of -/// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted. -impl Injector { - pub fn new( - path: impl AsRef, - capacity: usize, - auto_checkpoint: u32, - encryption_config: Option, - ) -> Result { - let path = path.as_ref().to_path_buf(); - - let buffer = FrameBuffer::default(); - let wal_manager = InjectorWalManager::new(buffer.clone()); - let connection = libsql_sys::Connection::open( - &path, - OpenFlags::SQLITE_OPEN_READ_WRITE - | OpenFlags::SQLITE_OPEN_CREATE - | OpenFlags::SQLITE_OPEN_URI - | OpenFlags::SQLITE_OPEN_NO_MUTEX, - wal_manager, - auto_checkpoint, - encryption_config.clone(), - )?; - - Ok(Self { - is_txn: false, - buffer, - capacity, - connection: Arc::new(Mutex::new(connection)), - biggest_uncommitted_seen: 0, - - path, - encryption_config, - auto_checkpoint, - }) - } - - /// Inject a frame into the log. If this was a commit frame, returns Ok(Some(FrameNo)). - pub fn inject_frame(&mut self, frame: Frame) -> Result, Error> { - let frame_close_txn = frame.header().size_after.get() != 0; - self.buffer.lock().push_back(frame); - if frame_close_txn || self.buffer.lock().len() >= self.capacity { - return self.flush(); - } +mod sqlite_injector; - Ok(None) - } +pub trait Injector { + /// Inject a singular frame. + fn inject_frame( + &mut self, + frame: Frame, + ) -> impl Future>> + Send; - pub fn rollback(&mut self) { - let conn = self.connection.lock(); - let mut rollback = conn.prepare_cached("ROLLBACK").unwrap(); - let _ = rollback.execute(()); - self.is_txn = false; - } + /// Discard any uncommintted frames. + fn rollback(&mut self) -> impl Future + Send; /// Flush the buffer to libsql WAL. /// Trigger a dummy write, and flush the cache to trigger a call to xFrame. The buffer's frame /// are then injected into the wal. 
- pub fn flush(&mut self) -> Result, Error> { - match self.try_flush() { - Err(e) => { - // something went wrong, rollback the connection to make sure we can retry in a - // clean state - self.biggest_uncommitted_seen = 0; - self.rollback(); - Err(e) - } - Ok(ret) => Ok(ret), - } - } - - fn try_flush(&mut self) -> Result, Error> { - if !self.is_txn { - self.begin_txn()?; - } - - let lock = self.buffer.lock(); - // the frames in the buffer are either monotonically increasing (log) or decreasing - // (snapshot). Either way, we want to find the biggest frameno we're about to commit, and - // that is either the front or the back of the buffer - let last_frame_no = match lock.back().zip(lock.front()) { - Some((b, f)) => f.header().frame_no.get().max(b.header().frame_no.get()), - None => { - tracing::trace!("nothing to inject"); - return Ok(None); - } - }; - - self.biggest_uncommitted_seen = self.biggest_uncommitted_seen.max(last_frame_no); - - drop(lock); - - let connection = self.connection.lock(); - // use prepare cached to avoid parsing the same statement over and over again. - let mut stmt = - connection.prepare_cached("INSERT INTO libsql_temp_injection VALUES (42)")?; - - // We execute the statement, and then force a call to xframe if necesacary. If the execute - // succeeds, then xframe wasn't called, in this case, we call cache_flush, and then process - // the error. - // It is unexpected that execute flushes, but it is possible, so we handle that case. 
- match stmt.execute(()).and_then(|_| connection.cache_flush()) { - Ok(_) => panic!("replication hook was not called"), - Err(e) => { - if let Some(e) = e.sqlite_error() { - if e.extended_code == LIBSQL_INJECT_OK { - // refresh schema - connection.pragma_update(None, "writable_schema", "reset")?; - let mut rollback = connection.prepare_cached("ROLLBACK")?; - let _ = rollback.execute(()); - self.is_txn = false; - assert!(self.buffer.lock().is_empty()); - let commit_frame_no = self.biggest_uncommitted_seen; - self.biggest_uncommitted_seen = 0; - return Ok(Some(commit_frame_no)); - } else if e.extended_code == LIBSQL_INJECT_OK_TXN { - self.is_txn = true; - assert!(self.buffer.lock().is_empty()); - return Ok(None); - } else if e.extended_code == LIBSQL_INJECT_FATAL { - return Err(Error::FatalInjectError); - } - } - - Err(Error::FatalInjectError) - } - } - } - - fn begin_txn(&mut self) -> Result<(), Error> { - let mut conn = self.connection.lock(); - - { - let wal_manager = InjectorWalManager::new(self.buffer.clone()); - let new_conn = libsql_sys::Connection::open( - &self.path, - OpenFlags::SQLITE_OPEN_READ_WRITE - | OpenFlags::SQLITE_OPEN_CREATE - | OpenFlags::SQLITE_OPEN_URI - | OpenFlags::SQLITE_OPEN_NO_MUTEX, - wal_manager, - self.auto_checkpoint, - self.encryption_config.clone(), - )?; - - let _ = std::mem::replace(&mut *conn, new_conn); - } - - conn.pragma_update(None, "writable_schema", "true")?; - - let mut stmt = conn.prepare_cached("BEGIN IMMEDIATE")?; - stmt.execute(())?; - // we create a dummy table. This table MUST not be persisted, otherwise the replica schema - // would differ with the primary's. 
- let mut stmt = - conn.prepare_cached("CREATE TABLE IF NOT EXISTS libsql_temp_injection (x)")?; - stmt.execute(())?; - - Ok(()) - } - - pub fn clear_buffer(&mut self) { - self.buffer.lock().clear() - } - - #[cfg(test)] - pub fn is_txn(&self) -> bool { - self.is_txn - } -} - -#[cfg(test)] -mod test { - use crate::frame::FrameBorrowed; - use std::mem::size_of; - - use super::*; - /// this this is generated by creating a table test, inserting 5 rows into it, and then - /// truncating the wal file of it's header. - const WAL: &[u8] = include_bytes!("../../assets/test/test_wallog"); - - fn wal_log() -> impl Iterator { - WAL.chunks(size_of::()) - .map(|b| Frame::try_from(b).unwrap()) - } - - #[test] - fn test_simple_inject_frames() { - let temp = tempfile::tempdir().unwrap(); - - let mut injector = Injector::new(temp.path().join("data"), 10, 10000, None).unwrap(); - let log = wal_log(); - for frame in log { - injector.inject_frame(frame).unwrap(); - } - - let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); - - conn.query_row("SELECT COUNT(*) FROM test", (), |row| { - assert_eq!(row.get::<_, usize>(0).unwrap(), 5); - Ok(()) - }) - .unwrap(); - } - - #[test] - fn test_inject_frames_split_txn() { - let temp = tempfile::tempdir().unwrap(); - - // inject one frame at a time - let mut injector = Injector::new(temp.path().join("data"), 1, 10000, None).unwrap(); - let log = wal_log(); - for frame in log { - injector.inject_frame(frame).unwrap(); - } - - let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); - - conn.query_row("SELECT COUNT(*) FROM test", (), |row| { - assert_eq!(row.get::<_, usize>(0).unwrap(), 5); - Ok(()) - }) - .unwrap(); - } - - #[test] - fn test_inject_partial_txn_isolated() { - let temp = tempfile::tempdir().unwrap(); - - // inject one frame at a time - let mut injector = Injector::new(temp.path().join("data"), 10, 1000, None).unwrap(); - let mut frames = wal_log(); - - assert!(injector - 
.inject_frame(frames.next().unwrap()) - .unwrap() - .is_none()); - let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); - assert!(conn - .query_row("SELECT COUNT(*) FROM test", (), |_| Ok(())) - .is_err()); - - while injector - .inject_frame(frames.next().unwrap()) - .unwrap() - .is_none() - {} - - // reset schema - conn.pragma_update(None, "writable_schema", "reset") - .unwrap(); - conn.query_row("SELECT COUNT(*) FROM test", (), |_| Ok(())) - .unwrap(); - } + fn flush(&mut self) -> impl Future>> + Send; } diff --git a/libsql-replication/src/injector/headers.rs b/libsql-replication/src/injector/sqlite_injector/headers.rs similarity index 100% rename from libsql-replication/src/injector/headers.rs rename to libsql-replication/src/injector/sqlite_injector/headers.rs diff --git a/libsql-replication/src/injector/injector_wal.rs b/libsql-replication/src/injector/sqlite_injector/injector_wal.rs similarity index 100% rename from libsql-replication/src/injector/injector_wal.rs rename to libsql-replication/src/injector/sqlite_injector/injector_wal.rs diff --git a/libsql-replication/src/injector/sqlite_injector/mod.rs b/libsql-replication/src/injector/sqlite_injector/mod.rs new file mode 100644 index 0000000000..dea78ce4b5 --- /dev/null +++ b/libsql-replication/src/injector/sqlite_injector/mod.rs @@ -0,0 +1,345 @@ +use std::path::Path; +use std::sync::Arc; +use std::{collections::VecDeque, path::PathBuf}; + +use parking_lot::Mutex; +use rusqlite::OpenFlags; +use tokio::task::spawn_blocking; + +use crate::frame::{Frame, FrameNo}; + +use self::injector_wal::{ + InjectorWal, InjectorWalManager, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN, +}; + +use super::error::Result; +use super::{Error, Injector}; + +mod headers; +mod injector_wal; + +pub type FrameBuffer = Arc>>; + +pub struct SqliteInjector { + pub(in super::super) inner: Arc>, +} + +impl Injector for SqliteInjector { + async fn inject_frame( + &mut self, + frame: Frame, + ) -> Result> 
{ + let inner = self.inner.clone(); + spawn_blocking(move || { + inner.lock().inject_frame(frame) + }).await.unwrap() + } + + async fn rollback(&mut self) { + let inner = self.inner.clone(); + spawn_blocking(move || { + inner.lock().rollback() + }).await.unwrap(); + } + + async fn flush(&mut self) -> Result> { + let inner = self.inner.clone(); + spawn_blocking(move || { + inner.lock().flush() + }).await.unwrap() + } +} + +impl SqliteInjector { + pub async fn new( + path: PathBuf, + capacity: usize, + auto_checkpoint: u32, + encryption_config: Option, + ) ->super::Result { + let inner = spawn_blocking(move || { + SqliteInjectorInner::new(path, capacity, auto_checkpoint, encryption_config) + }).await.unwrap()?; + + Ok(Self { + inner: Arc::new(Mutex::new(inner)) + }) + } +} + +pub(in super::super) struct SqliteInjectorInner { + /// The injector is in a transaction state + is_txn: bool, + /// Buffer for holding current transaction frames + buffer: FrameBuffer, + /// Maximum capacity of the frame buffer + capacity: usize, + /// Injector connection + // connection must be dropped before the hook context + connection: Arc>>, + biggest_uncommitted_seen: FrameNo, + + // Connection config items used to recreate the injection connection + path: PathBuf, + encryption_config: Option, + auto_checkpoint: u32, +} + +/// Methods from this trait are called before and after performing a frame injection. +/// This trait trait is used to record the last committed frame_no to the log. +/// The implementer can persist the pre and post commit frame no, and compare them in the event of +/// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted. 
+impl SqliteInjectorInner { + fn new( + path: impl AsRef, + capacity: usize, + auto_checkpoint: u32, + encryption_config: Option, + ) -> Result { + let path = path.as_ref().to_path_buf(); + + let buffer = FrameBuffer::default(); + let wal_manager = InjectorWalManager::new(buffer.clone()); + let connection = libsql_sys::Connection::open( + &path, + OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_URI + | OpenFlags::SQLITE_OPEN_NO_MUTEX, + wal_manager, + auto_checkpoint, + encryption_config.clone(), + )?; + + Ok(Self { + is_txn: false, + buffer, + capacity, + connection: Arc::new(Mutex::new(connection)), + biggest_uncommitted_seen: 0, + + path, + encryption_config, + auto_checkpoint, + }) + } + + /// Inject a frame into the log. If this was a commit frame, returns Ok(Some(FrameNo)). + pub fn inject_frame(&mut self, frame: Frame) -> Result, Error> { + let frame_close_txn = frame.header().size_after.get() != 0; + self.buffer.lock().push_back(frame); + if frame_close_txn || self.buffer.lock().len() >= self.capacity { + return self.flush(); + } + + Ok(None) + } + + pub fn rollback(&mut self) { + self.clear_buffer(); + let conn = self.connection.lock(); + let mut rollback = conn.prepare_cached("ROLLBACK").unwrap(); + let _ = rollback.execute(()); + self.is_txn = false; + } + + /// Flush the buffer to libsql WAL. + /// Trigger a dummy write, and flush the cache to trigger a call to xFrame. The buffer's frame + /// are then injected into the wal. 
+ pub fn flush(&mut self) -> Result, Error> { + match self.try_flush() { + Err(e) => { + // something went wrong, rollback the connection to make sure we can retry in a + // clean state + self.biggest_uncommitted_seen = 0; + self.rollback(); + Err(e) + } + Ok(ret) => Ok(ret), + } + } + + fn try_flush(&mut self) -> Result, Error> { + if !self.is_txn { + self.begin_txn()?; + } + + let lock = self.buffer.lock(); + // the frames in the buffer are either monotonically increasing (log) or decreasing + // (snapshot). Either way, we want to find the biggest frameno we're about to commit, and + // that is either the front or the back of the buffer + let last_frame_no = match lock.back().zip(lock.front()) { + Some((b, f)) => f.header().frame_no.get().max(b.header().frame_no.get()), + None => { + tracing::trace!("nothing to inject"); + return Ok(None); + } + }; + + self.biggest_uncommitted_seen = self.biggest_uncommitted_seen.max(last_frame_no); + + drop(lock); + + let connection = self.connection.lock(); + // use prepare cached to avoid parsing the same statement over and over again. + let mut stmt = + connection.prepare_cached("INSERT INTO libsql_temp_injection VALUES (42)")?; + + // We execute the statement, and then force a call to xframe if necesacary. If the execute + // succeeds, then xframe wasn't called, in this case, we call cache_flush, and then process + // the error. + // It is unexpected that execute flushes, but it is possible, so we handle that case. 
+ match stmt.execute(()).and_then(|_| connection.cache_flush()) { + Ok(_) => panic!("replication hook was not called"), + Err(e) => { + if let Some(e) = e.sqlite_error() { + if e.extended_code == LIBSQL_INJECT_OK { + // refresh schema + connection.pragma_update(None, "writable_schema", "reset")?; + let mut rollback = connection.prepare_cached("ROLLBACK")?; + let _ = rollback.execute(()); + self.is_txn = false; + assert!(self.buffer.lock().is_empty()); + let commit_frame_no = self.biggest_uncommitted_seen; + self.biggest_uncommitted_seen = 0; + return Ok(Some(commit_frame_no)); + } else if e.extended_code == LIBSQL_INJECT_OK_TXN { + self.is_txn = true; + assert!(self.buffer.lock().is_empty()); + return Ok(None); + } else if e.extended_code == LIBSQL_INJECT_FATAL { + return Err(Error::FatalInjectError); + } + } + + Err(Error::FatalInjectError) + } + } + } + + fn begin_txn(&mut self) -> Result<(), Error> { + let mut conn = self.connection.lock(); + + { + let wal_manager = InjectorWalManager::new(self.buffer.clone()); + let new_conn = libsql_sys::Connection::open( + &self.path, + OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_URI + | OpenFlags::SQLITE_OPEN_NO_MUTEX, + wal_manager, + self.auto_checkpoint, + self.encryption_config.clone(), + )?; + + let _ = std::mem::replace(&mut *conn, new_conn); + } + + conn.pragma_update(None, "writable_schema", "true")?; + + let mut stmt = conn.prepare_cached("BEGIN IMMEDIATE")?; + stmt.execute(())?; + // we create a dummy table. This table MUST not be persisted, otherwise the replica schema + // would differ with the primary's. 
+ let mut stmt = + conn.prepare_cached("CREATE TABLE IF NOT EXISTS libsql_temp_injection (x)")?; + stmt.execute(())?; + + Ok(()) + } + + pub fn clear_buffer(&mut self) { + self.buffer.lock().clear() + } + + #[cfg(test)] + pub fn is_txn(&self) -> bool { + self.is_txn + } +} + +#[cfg(test)] +mod test { + use crate::frame::FrameBorrowed; + use std::mem::size_of; + + use super::*; + /// this this is generated by creating a table test, inserting 5 rows into it, and then + /// truncating the wal file of it's header. + const WAL: &[u8] = include_bytes!("../../../assets/test/test_wallog"); + + fn wal_log() -> impl Iterator { + WAL.chunks(size_of::()) + .map(|b| Frame::try_from(b).unwrap()) + } + + #[test] + fn test_simple_inject_frames() { + let temp = tempfile::tempdir().unwrap(); + + let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 10, 10000, None).unwrap(); + let log = wal_log(); + for frame in log { + injector.inject_frame(frame).unwrap(); + } + + let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); + + conn.query_row("SELECT COUNT(*) FROM test", (), |row| { + assert_eq!(row.get::<_, usize>(0).unwrap(), 5); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_inject_frames_split_txn() { + let temp = tempfile::tempdir().unwrap(); + + // inject one frame at a time + let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 1, 10000, None).unwrap(); + let log = wal_log(); + for frame in log { + injector.inject_frame(frame).unwrap(); + } + + let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); + + conn.query_row("SELECT COUNT(*) FROM test", (), |row| { + assert_eq!(row.get::<_, usize>(0).unwrap(), 5); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_inject_partial_txn_isolated() { + let temp = tempfile::tempdir().unwrap(); + + // inject one frame at a time + let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 10, 1000, None).unwrap(); + let mut frames = wal_log(); + + 
assert!(injector + .inject_frame(frames.next().unwrap()) + .unwrap() + .is_none()); + let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); + assert!(conn + .query_row("SELECT COUNT(*) FROM test", (), |_| Ok(())) + .is_err()); + + while injector + .inject_frame(frames.next().unwrap()) + .unwrap() + .is_none() + {} + + // reset schema + conn.pragma_update(None, "writable_schema", "reset") + .unwrap(); + conn.query_row("SELECT COUNT(*) FROM test", (), |_| Ok(())) + .unwrap(); + } +} diff --git a/libsql-replication/src/replicator.rs b/libsql-replication/src/replicator.rs index bc1eada7f8..31c766faad 100644 --- a/libsql-replication/src/replicator.rs +++ b/libsql-replication/src/replicator.rs @@ -1,14 +1,11 @@ use std::path::PathBuf; -use std::sync::Arc; -use parking_lot::Mutex; -use tokio::task::spawn_blocking; use tokio::time::Duration; use tokio_stream::{Stream, StreamExt}; use tonic::{Code, Status}; use crate::frame::{Frame, FrameNo}; -use crate::injector::Injector; +use crate::injector::{Injector, SqliteInjector}; use crate::rpc::replication::{ Frame as RpcFrame, NAMESPACE_DOESNT_EXIST, NEED_SNAPSHOT_ERROR_MSG, NO_HELLO_ERROR_MSG, }; @@ -137,9 +134,9 @@ where /// The `Replicator`'s duty is to download frames from the primary, and pass them to the injector at /// transaction boundaries. -pub struct Replicator { +pub struct Replicator { client: C, - injector: Arc>, + injector: I, state: ReplicatorState, frames_synced: usize, } @@ -154,33 +151,42 @@ enum ReplicatorState { Exit, } -impl Replicator { +impl Replicator +where + C: ReplicatorClient, +{ /// Creates a replicator for the db file pointed at by `db_path` - pub async fn new( + pub async fn new_sqlite( client: C, db_path: PathBuf, auto_checkpoint: u32, encryption_config: Option, ) -> Result { - let injector = { - let db_path = db_path.clone(); - spawn_blocking(move || { - Injector::new( - db_path, - INJECTOR_BUFFER_CAPACITY, - auto_checkpoint, - encryption_config, - ) - }) - .await?? 
- }; + let injector = SqliteInjector::new( + db_path.clone(), + INJECTOR_BUFFER_CAPACITY, + auto_checkpoint, + encryption_config, + ) + .await?; + + Ok(Self::new(client, injector)) + } +} - Ok(Self { +impl Replicator +where + C: ReplicatorClient, + I: Injector, +{ + + pub fn new(client: C, injector: I) -> Self { + Self { client, - injector: Arc::new(Mutex::new(injector)), + injector, state: ReplicatorState::NeedHandshake, frames_synced: 0, - }) + } } /// for a handshake on next call to replicate. @@ -250,7 +256,7 @@ impl Replicator { // in case of error we rollback the current injector transaction, and start over. if ret.is_err() { self.client.rollback(); - self.injector.lock().rollback(); + self.injector.rollback().await; } self.state = match ret { @@ -293,7 +299,8 @@ impl Replicator { } async fn load_snapshot(&mut self) -> Result<(), Error> { - self.injector.lock().clear_buffer(); + self.client.rollback(); + self.injector.rollback().await; loop { match self.client.snapshot().await { Ok(mut stream) => { @@ -315,26 +322,22 @@ impl Replicator { async fn inject_frame(&mut self, frame: Frame) -> Result<(), Error> { self.frames_synced += 1; - let injector = self.injector.clone(); - match spawn_blocking(move || injector.lock().inject_frame(frame)).await? { - Ok(Some(commit_fno)) => { + match self.injector.inject_frame(frame).await? { + Some(commit_fno) => { self.client.commit_frame_no(commit_fno).await?; } - Ok(None) => (), - Err(e) => Err(e)?, + None => (), } Ok(()) } pub async fn flush(&mut self) -> Result<(), Error> { - let injector = self.injector.clone(); - match spawn_blocking(move || injector.lock().flush()).await? { - Ok(Some(commit_fno)) => { + match self.injector.flush().await? 
{ + Some(commit_fno) => { self.client.commit_frame_no(commit_fno).await?; } - Ok(None) => (), - Err(e) => Err(e)?, + None => (), } Ok(()) @@ -395,7 +398,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); @@ -438,7 +441,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -482,7 +485,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -526,7 +529,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -568,7 +571,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -610,7 +613,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, 
tmp.path().to_path_buf(), 10000, None) .await .unwrap(); replicator.state = ReplicatorState::NeedSnapshot; @@ -653,7 +656,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -696,7 +699,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); replicator.state = ReplicatorState::NeedHandshake; @@ -784,7 +787,7 @@ mod test { committed_frame_no: None, }; - let mut replicator = Replicator::new(client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); @@ -795,7 +798,7 @@ mod test { replicator.try_replicate_step().await.unwrap_err(), Error::Client(_) )); - assert!(!replicator.injector.lock().is_txn()); + assert!(!replicator.injector.inner.lock().is_txn()); assert!(replicator.client_mut().committed_frame_no.is_none()); assert_eq!(replicator.state, ReplicatorState::NeedHandshake); @@ -805,7 +808,7 @@ mod test { replicator.client_mut().should_error = false; replicator.try_replicate_step().await.unwrap(); - assert!(!replicator.injector.lock().is_txn()); + assert!(!replicator.injector.inner.lock().is_txn()); assert_eq!(replicator.state, ReplicatorState::Exit); assert_eq!(replicator.client_mut().committed_frame_no, Some(6)); } diff --git a/libsql/src/replication/mod.rs b/libsql/src/replication/mod.rs index 69cc0b5db2..2f4e9b49c0 100644 --- a/libsql/src/replication/mod.rs +++ b/libsql/src/replication/mod.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use std::time::Duration; pub use libsql_replication::frame::{Frame, FrameNo}; 
+use libsql_replication::injector::SqliteInjector; use libsql_replication::replicator::{Either, Replicator}; pub use libsql_replication::snapshot::SnapshotFile; @@ -129,7 +130,7 @@ impl Writer { #[derive(Clone)] pub(crate) struct EmbeddedReplicator { - replicator: Arc>>>, + replicator: Arc, SqliteInjector>>>, bg_abort: Option>, last_frames_synced: Arc, } @@ -149,7 +150,7 @@ impl EmbeddedReplicator { perodic_sync: Option, ) -> Result { let replicator = Arc::new(Mutex::new( - Replicator::new( + Replicator::new_sqlite( Either::Left(client), db_path, auto_checkpoint, @@ -193,7 +194,7 @@ impl EmbeddedReplicator { encryption_config: Option, ) -> Result { let replicator = Arc::new(Mutex::new( - Replicator::new( + Replicator::new_sqlite( Either::Right(client), db_path, auto_checkpoint, From 4b5baacad57965b639dcab274a4a97f8b0ed3abd Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 11:57:15 +0200 Subject: [PATCH 056/121] introduce libsql injector --- Cargo.lock | 2 + Cargo.toml | 1 + libsql-replication/Cargo.toml | 1 + .../proto/replication_log.proto | 6 +++ libsql-replication/src/frame.rs | 3 +- libsql-replication/src/generated/wal_log.rs | 42 +++++++++++++++ libsql-replication/src/injector/error.rs | 5 +- .../src/injector/libsql_injector.rs | 44 +++++++++++++++ libsql-replication/src/injector/mod.rs | 2 + .../src/injector/sqlite_injector/mod.rs | 12 ++--- libsql-replication/src/rpc.rs | 5 +- libsql-server/Cargo.toml | 2 +- .../src/replication/replicator_client.rs | 6 ++- libsql-server/src/rpc/replication_log.rs | 5 ++ libsql-wal/Cargo.toml | 3 +- libsql-wal/src/replication/injector.rs | 23 ++++---- libsql-wal/src/segment/current.rs | 4 +- libsql-wal/src/shared_wal.rs | 8 +-- libsql-wal/src/transaction.rs | 53 +++++++++++++++++-- libsql/src/replication/remote_client.rs | 3 +- 20 files changed, 196 insertions(+), 34 deletions(-) create mode 100644 libsql-replication/src/injector/libsql_injector.rs diff --git a/Cargo.lock b/Cargo.lock index 17cfc0e090..7e19e03e9a 
100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3702,6 +3702,7 @@ name = "libsql-wal" version = "0.1.0" dependencies = [ "arc-swap", + "async-lock 3.4.0", "async-stream", "aws-config 1.5.4", "aws-credential-types 1.2.0", @@ -3779,6 +3780,7 @@ dependencies = [ "cbc", "libsql-rusqlite", "libsql-sys", + "libsql-wal", "parking_lot", "prost", "prost-build", diff --git a/Cargo.toml b/Cargo.toml index 9381fb83f3..685f14964f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ rusqlite = { package = "libsql-rusqlite", path = "vendored/rusqlite", version = ] } hyper = { version = "0.14" } tower = { version = "0.4.13" } +zerocopy = { version = "0.7.32", features = ["derive", "alloc"] } # Config for 'cargo dist' [workspace.metadata.dist] diff --git a/libsql-replication/Cargo.toml b/libsql-replication/Cargo.toml index d2a9431cba..2a03d362bf 100644 --- a/libsql-replication/Cargo.toml +++ b/libsql-replication/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT" tonic = { version = "0.11", features = ["tls"] } prost = "0.12" libsql-sys = { version = "0.7", path = "../libsql-sys", default-features = false, features = ["wal", "rusqlite", "api"] } +libsql-wal = { path = "../libsql-wal/" } rusqlite = { workspace = true } parking_lot = "0.12.1" bytes = { version = "1.5.0", features = ["serde"] } diff --git a/libsql-replication/proto/replication_log.proto b/libsql-replication/proto/replication_log.proto index 6208874609..b3be419319 100644 --- a/libsql-replication/proto/replication_log.proto +++ b/libsql-replication/proto/replication_log.proto @@ -9,6 +9,12 @@ message LogOffset { message HelloRequest { optional uint64 handshake_version = 1; + enum WalFlavor { + Sqlite = 0; + Libsql = 1; + } + // the type of wal that the client is expecting + optional WalFlavor wal_flavor = 2; } message HelloResponse { diff --git a/libsql-replication/src/frame.rs b/libsql-replication/src/frame.rs index a6a2854e52..55b5b778b5 100644 --- a/libsql-replication/src/frame.rs +++ b/libsql-replication/src/frame.rs @@ 
-13,7 +13,6 @@ use crate::LIBSQL_PAGE_SIZE; pub type FrameNo = u64; /// The file header for the WAL log. All fields are represented in little-endian ordering. -/// See `encode` and `decode` for actual layout. // repr C for stable sizing #[repr(C)] #[derive(Debug, Clone, Copy, zerocopy::FromZeroes, zerocopy::FromBytes, zerocopy::AsBytes)] @@ -22,7 +21,7 @@ pub struct FrameHeader { pub frame_no: lu64, /// Rolling checksum of all the previous frames, including this one. pub checksum: lu64, - /// page number, if frame_type is FrameType::Page + /// page number pub page_no: lu32, /// Size of the database (in page) after committing the transaction. This is passed from sqlite, /// and serves as commit transaction boundary diff --git a/libsql-replication/src/generated/wal_log.rs b/libsql-replication/src/generated/wal_log.rs index 2d7330e732..441881c4a7 100644 --- a/libsql-replication/src/generated/wal_log.rs +++ b/libsql-replication/src/generated/wal_log.rs @@ -10,6 +10,48 @@ pub struct LogOffset { pub struct HelloRequest { #[prost(uint64, optional, tag = "1")] pub handshake_version: ::core::option::Option, + /// the type of wal that the client is expecting + #[prost(enumeration = "hello_request::WalFlavor", optional, tag = "2")] + pub wal_flavor: ::core::option::Option, +} +/// Nested message and enum types in `HelloRequest`. +pub mod hello_request { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum WalFlavor { + Sqlite = 0, + Libsql = 1, + } + impl WalFlavor { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
+ pub fn as_str_name(&self) -> &'static str { + match self { + WalFlavor::Sqlite => "Sqlite", + WalFlavor::Libsql => "Libsql", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "Sqlite" => Some(Self::Sqlite), + "Libsql" => Some(Self::Libsql), + _ => None, + } + } + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/libsql-replication/src/injector/error.rs b/libsql-replication/src/injector/error.rs index b1cebfe28b..225960c4d1 100644 --- a/libsql-replication/src/injector/error.rs +++ b/libsql-replication/src/injector/error.rs @@ -1,4 +1,5 @@ pub type Result = std::result::Result; +pub type BoxError = Box; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -6,6 +7,6 @@ pub enum Error { Io(#[from] std::io::Error), #[error("SQLite error: {0}")] Sqlite(#[from] rusqlite::Error), - #[error("A fatal error occured injecting frames")] - FatalInjectError, + #[error("A fatal error occured injecting frames: {0}")] + FatalInjectError(BoxError), } diff --git a/libsql-replication/src/injector/libsql_injector.rs b/libsql-replication/src/injector/libsql_injector.rs new file mode 100644 index 0000000000..946d35e547 --- /dev/null +++ b/libsql-replication/src/injector/libsql_injector.rs @@ -0,0 +1,44 @@ +use std::mem::size_of; + +use libsql_wal::io::StdIO; +use libsql_wal::replication::injector::Injector; +use libsql_wal::segment::Frame as WalFrame; +use zerocopy::{AsBytes, FromZeroes}; + +use crate::frame::{Frame, FrameNo}; + +use super::error::{Error, Result}; + +pub struct LibsqlInjector { + injector: Injector, +} + +impl super::Injector for LibsqlInjector { + async fn inject_frame(&mut self, frame: Frame) -> Result> { + // this is a bit annoying be we want to read the frame, and it has to be aligned, so we + // must copy it... + // FIXME: optimize this. 
+ let mut wal_frame = WalFrame::new_box_zeroed(); + if frame.bytes().len() != size_of::() { + todo!("invalid frame"); + } + wal_frame.as_bytes_mut().copy_from_slice(&frame.bytes()[..]); + Ok(self + .injector + .insert_frame(wal_frame) + .await + .map_err(|e| Error::FatalInjectError(e.into()))?) + } + + async fn rollback(&mut self) { + self.injector.rollback(); + } + + async fn flush(&mut self) -> Result> { + self.injector + .flush(None) + .await + .map_err(|e| Error::FatalInjectError(e.into()))?; + Ok(None) + } +} diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 1d69ae0aab..20a81cfa01 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,6 +1,7 @@ use std::future::Future; pub use sqlite_injector::SqliteInjector; +pub use libsql_injector::LibsqlInjector; use crate::frame::{Frame, FrameNo}; @@ -9,6 +10,7 @@ pub use error::Error; mod error; mod sqlite_injector; +mod libsql_injector; pub trait Injector { /// Inject a singular frame. 
diff --git a/libsql-replication/src/injector/sqlite_injector/mod.rs b/libsql-replication/src/injector/sqlite_injector/mod.rs index dea78ce4b5..545fbe810d 100644 --- a/libsql-replication/src/injector/sqlite_injector/mod.rs +++ b/libsql-replication/src/injector/sqlite_injector/mod.rs @@ -192,8 +192,8 @@ impl SqliteInjectorInner { match stmt.execute(()).and_then(|_| connection.cache_flush()) { Ok(_) => panic!("replication hook was not called"), Err(e) => { - if let Some(e) = e.sqlite_error() { - if e.extended_code == LIBSQL_INJECT_OK { + if let Some(err) = e.sqlite_error() { + if err.extended_code == LIBSQL_INJECT_OK { // refresh schema connection.pragma_update(None, "writable_schema", "reset")?; let mut rollback = connection.prepare_cached("ROLLBACK")?; @@ -203,16 +203,16 @@ impl SqliteInjectorInner { let commit_frame_no = self.biggest_uncommitted_seen; self.biggest_uncommitted_seen = 0; return Ok(Some(commit_frame_no)); - } else if e.extended_code == LIBSQL_INJECT_OK_TXN { + } else if err.extended_code == LIBSQL_INJECT_OK_TXN { self.is_txn = true; assert!(self.buffer.lock().is_empty()); return Ok(None); - } else if e.extended_code == LIBSQL_INJECT_FATAL { - return Err(Error::FatalInjectError); + } else if err.extended_code == LIBSQL_INJECT_FATAL { + return Err(Error::FatalInjectError(e.into())); } } - Err(Error::FatalInjectError) + Err(Error::FatalInjectError(e.into())) } } } diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index ebc92cf10c..a9b172db20 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -25,6 +25,8 @@ pub mod replication { #![allow(clippy::all)] use uuid::Uuid; + + use self::hello_request::WalFlavor; include!("generated/wal_log.rs"); pub const NO_HELLO_ERROR_MSG: &str = "NO_HELLO"; @@ -46,9 +48,10 @@ pub mod replication { } impl HelloRequest { - pub fn new() -> Self { + pub fn new(wal_flavor: WalFlavor) -> Self { Self { handshake_version: Some(1), + wal_flavor: Some(wal_flavor.into()) } } } 
diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 6763c02dfb..934a400786 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -83,7 +83,7 @@ url = { version = "2.3", features = ["serde"] } uuid = { version = "1.3", features = ["v4", "serde", "v7"] } aes = { version = "0.8.3", optional = true } cbc = { version = "0.1.2", optional = true } -zerocopy = { version = "0.7.28", features = ["derive", "alloc"] } +zerocopy = { workspace = true } hashbrown = { version = "0.14.3", features = ["serde"] } hdrhistogram = "7.5.4" crossbeam = "0.8.4" diff --git a/libsql-server/src/replication/replicator_client.rs b/libsql-server/src/replication/replicator_client.rs index 4d12ff7f83..d68c259dc9 100644 --- a/libsql-server/src/replication/replicator_client.rs +++ b/libsql-server/src/replication/replicator_client.rs @@ -7,6 +7,7 @@ use futures::TryStreamExt; use libsql_replication::frame::Frame; use libsql_replication::meta::WalIndexMeta; use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient}; +use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use libsql_replication::rpc::replication::{ verify_session_token, HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, SESSION_TOKEN_KEY, @@ -35,6 +36,7 @@ pub struct Client { // the primary current replication index, as reported by the last handshake pub primary_replication_index: Option, store: NamespaceStore, + wal_flavor: WalFlavor, } impl Client { @@ -44,6 +46,7 @@ impl Client { path: &Path, meta_store_handle: MetaStoreHandle, store: NamespaceStore, + wal_flavor: WalFlavor, ) -> crate::Result { let (current_frame_no_notifier, _) = watch::channel(None); let meta = WalIndexMeta::open(path).await?; @@ -57,6 +60,7 @@ impl Client { meta_store_handle, primary_replication_index: None, store, + wal_flavor, }) } @@ -96,7 +100,7 @@ impl ReplicatorClient for Client { 
#[tracing::instrument(skip(self))] async fn handshake(&mut self) -> Result<(), Error> { tracing::debug!("Attempting to perform handshake with primary."); - let req = self.make_request(HelloRequest::new()); + let req = self.make_request(HelloRequest::new(self.wal_flavor)); let resp = self.client.hello(req).await?; let hello = resp.into_inner(); verify_session_token(&hello.session_token).map_err(Error::Client)?; diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index c0b216739e..628cb4a01d 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -7,6 +7,7 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::stream::BoxStream; use futures_core::Future; +use libsql_replication::rpc::replication::hello_request::WalFlavor; pub use libsql_replication::rpc::replication as rpc; use libsql_replication::rpc::replication::replication_log_server::ReplicationLog; use libsql_replication::rpc::replication::{ @@ -355,6 +356,10 @@ impl ReplicationLog for ReplicationLogService { } } + if let WalFlavor::Libsql = req.get_ref().wal_flavor() { + return Err(Status::invalid_argument("libsql wal not supported")) + } + let (logger, config, version, _, _) = self.logger_from_namespace(namespace, &req, false).await?; diff --git a/libsql-wal/Cargo.toml b/libsql-wal/Cargo.toml index 9624596c28..f24f2e4c59 100644 --- a/libsql-wal/Cargo.toml +++ b/libsql-wal/Cargo.toml @@ -9,6 +9,7 @@ publish = false [dependencies] arc-swap = "1.7.1" async-stream = "0.3.5" +async-lock = "3.4.0" bitflags = "2.5.0" bytes = "1.6.0" chrono = "0.4.38" @@ -29,7 +30,7 @@ tokio-stream = "0.1.15" tracing = "0.1.40" uuid = { version = "1.8.0", features = ["v4"] } walkdir = "2.5.0" -zerocopy = { version = "0.7.32", features = ["derive", "alloc"] } +zerocopy = { workspace = true } aws-config = { version = "1", optional = true, features = ["behavior-version-latest"] } aws-sdk-s3 = { version = "1", optional = true } diff --git 
a/libsql-wal/src/replication/injector.rs b/libsql-wal/src/replication/injector.rs index 66710bbb22..c3642e196e 100644 --- a/libsql-wal/src/replication/injector.rs +++ b/libsql-wal/src/replication/injector.rs @@ -6,23 +6,23 @@ use crate::error::Result; use crate::io::Io; use crate::segment::Frame; use crate::shared_wal::SharedWal; -use crate::transaction::TxGuard; +use crate::transaction::TxGuardOwned; /// The injector takes frames and injects them in the wal. -pub struct Injector<'a, IO: Io> { +pub struct Injector { // The wal to which we are injecting wal: Arc>, buffer: Vec>, /// capacity of the frame buffer capacity: usize, - tx: TxGuard<'a, IO::File>, + tx: TxGuardOwned, max_tx_frame_no: u64, } -impl<'a, IO: Io> Injector<'a, IO> { +impl Injector { pub fn new( wal: Arc>, - tx: TxGuard<'a, IO::File>, + tx: TxGuardOwned, buffer_capacity: usize, ) -> Result { Ok(Self { @@ -34,7 +34,7 @@ impl<'a, IO: Io> Injector<'a, IO> { }) } - pub async fn insert_frame(&mut self, frame: Box) -> Result<()> { + pub async fn insert_frame(&mut self, frame: Box) -> Result> { let size_after = frame.size_after(); self.max_tx_frame_no = self.max_tx_frame_no.max(frame.header().frame_no()); self.buffer.push(frame); @@ -43,10 +43,10 @@ impl<'a, IO: Io> Injector<'a, IO> { self.flush(size_after).await?; } - Ok(()) + Ok(size_after.map(|_| self.max_tx_frame_no)) } - async fn flush(&mut self, size_after: Option) -> Result<()> { + pub async fn flush(&mut self, size_after: Option) -> Result<()> { let buffer = std::mem::take(&mut self.buffer); let current = self.wal.current.load(); let commit_data = size_after.map(|size| (size, self.max_tx_frame_no)); @@ -60,6 +60,11 @@ impl<'a, IO: Io> Injector<'a, IO> { Ok(()) } + + pub fn rollback(&mut self) { + self.buffer.clear(); + self.tx.reset(0); + } } #[cfg(test)] @@ -89,7 +94,7 @@ mod test { let mut tx = crate::transaction::Transaction::Read(replica_shared.begin_read(42)); replica_shared.upgrade(&mut tx).unwrap(); - let guard = 
tx.as_write_mut().unwrap().lock(); + let guard = tx.into_write().unwrap_or_else(|_| panic!()).into_lock_owned(); let mut injector = Injector::new(replica_shared.clone(), guard, 10).unwrap(); primary_conn.execute("create table test (x)", ()).unwrap(); diff --git a/libsql-wal/src/segment/current.rs b/libsql-wal/src/segment/current.rs index d8d720a145..bda6d5742a 100644 --- a/libsql-wal/src/segment/current.rs +++ b/libsql-wal/src/segment/current.rs @@ -22,7 +22,7 @@ use crate::io::file::FileExt; use crate::io::Inspect; use crate::segment::{checked_frame_offset, SegmentFlags}; use crate::segment::{frame_offset, page_offset, sealed::SealedSegment}; -use crate::transaction::{Transaction, TxGuard}; +use crate::transaction::{Transaction, TxGuard, TxGuardOwned}; use crate::{LIBSQL_MAGIC, LIBSQL_PAGE_SIZE, LIBSQL_WAL_VERSION}; use super::list::SegmentList; @@ -125,7 +125,7 @@ impl CurrentSegment { frames: Vec>, // (size_after, last_frame_no) commit_data: Option<(u32, u64)>, - tx: &mut TxGuard<'_, F>, + tx: &mut TxGuardOwned, ) -> Result>> where F: FileExt, diff --git a/libsql-wal/src/shared_wal.rs b/libsql-wal/src/shared_wal.rs index 09a2747c5a..461ad13e03 100644 --- a/libsql-wal/src/shared_wal.rs +++ b/libsql-wal/src/shared_wal.rs @@ -20,7 +20,7 @@ use libsql_sys::name::NamespaceName; #[derive(Default)] pub struct WalLock { - pub(crate) tx_id: Arc>>, + pub(crate) tx_id: Arc>>, /// When a writer is popped from the write queue, its write transaction may not be reading from the most recent /// snapshot. In this case, we return `SQLITE_BUSY_SNAPHSOT` to the caller. 
If no reads were performed /// with that transaction before upgrading, then the caller will call us back immediately after re-acquiring @@ -108,7 +108,7 @@ impl SharedWal { Some(id) if id == read_tx.conn_id => { tracing::trace!("taking reserved slot"); reserved.take(); - let lock = self.wal_lock.tx_id.lock(); + let lock = self.wal_lock.tx_id.lock_blocking(); let write_tx = self.acquire_write(read_tx, lock, reserved)?; *tx = Transaction::Write(write_tx); return Ok(()); @@ -117,7 +117,7 @@ impl SharedWal { } } - let lock = self.wal_lock.tx_id.lock(); + let lock = self.wal_lock.tx_id.lock_blocking(); match *lock { None if self.wal_lock.waiters.is_empty() => { let write_tx = @@ -144,7 +144,7 @@ impl SharedWal { fn acquire_write( &self, read_tx: &ReadTransaction, - mut tx_id_lock: MutexGuard>, + mut tx_id_lock: async_lock::MutexGuard>, mut reserved: MutexGuard>, ) -> Result> { // we read two fields in the header. There is no risk that a transaction commit in diff --git a/libsql-wal/src/transaction.rs b/libsql-wal/src/transaction.rs index 723cffeae1..f2cdd5be70 100644 --- a/libsql-wal/src/transaction.rs +++ b/libsql-wal/src/transaction.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use std::time::Instant; use libsql_sys::name::NamespaceName; -use parking_lot::{ArcMutexGuard, RawMutex}; use tokio::sync::mpsc; use crate::checkpointer::CheckpointMessage; @@ -31,6 +30,14 @@ impl Transaction { } } + pub fn into_write(self) -> Result, Self> { + if let Self::Write(v) = self { + Ok(v) + } else { + Err(self) + } + } + pub fn max_frame_no(&self) -> u64 { match self { Transaction::Write(w) => w.next_frame_no - 1, @@ -147,8 +154,27 @@ pub struct WriteTransaction { pub recompute_checksum: Option, } +pub struct TxGuardOwned { + _lock: async_lock::MutexGuardArc>, + inner: WriteTransaction, +} + +impl Deref for TxGuardOwned { + type Target = WriteTransaction; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for TxGuardOwned { + fn deref_mut(&mut self) -> &mut 
Self::Target { + &mut self.inner + } +} + pub struct TxGuard<'a, F> { - _lock: ArcMutexGuard>, + _lock: async_lock::MutexGuardArc>, inner: &'a mut WriteTransaction, } @@ -189,7 +215,7 @@ impl WriteTransaction { todo!("txn has already been commited"); } - let g = self.wal_lock.tx_id.lock_arc(); + let g = self.wal_lock.tx_id.lock_arc_blocking(); match *g { // we still hold the lock, we can proceed Some(id) if self.id == id => TxGuard { @@ -202,6 +228,25 @@ impl WriteTransaction { } } + pub fn into_lock_owned(self) -> TxGuardOwned { + if self.is_commited { + tracing::error!("transaction already commited"); + todo!("txn has already been commited"); + } + + let g = self.wal_lock.tx_id.lock_arc_blocking(); + match *g { + // we still hold the lock, we can proceed + Some(id) if self.id == id => TxGuardOwned { + _lock: g, + inner: self, + }, + // Somebody took the lock from us + Some(_) => todo!("lock stolen"), + None => todo!("not a transaction"), + } + } + pub fn reset(&mut self, savepoint_id: usize) { if savepoint_id >= self.savepoints.len() { unreachable!("savepoint doesn't exist"); @@ -231,7 +276,7 @@ impl WriteTransaction { let Self { wal_lock, read_tx, .. 
} = self; - let mut lock = wal_lock.tx_id.lock(); + let mut lock = wal_lock.tx_id.lock_blocking(); match *lock { Some(lock_id) if lock_id == read_tx.id => { lock.take(); diff --git a/libsql/src/replication/remote_client.rs b/libsql/src/replication/remote_client.rs index d0052f50d9..79cffb1c38 100644 --- a/libsql/src/replication/remote_client.rs +++ b/libsql/src/replication/remote_client.rs @@ -8,6 +8,7 @@ use futures::StreamExt as _; use libsql_replication::frame::{Frame, FrameHeader, FrameNo}; use libsql_replication::meta::WalIndexMeta; use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient}; +use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::{ verify_session_token, Frames, HelloRequest, HelloResponse, LogOffset, SESSION_TOKEN_KEY, }; @@ -116,7 +117,7 @@ impl RemoteClient { self.dirty = false; } let prefetch = self.session_token.is_some(); - let hello_req = self.make_request(HelloRequest::new()); + let hello_req = self.make_request(HelloRequest::new(WalFlavor::Sqlite)); let log_offset_req = self.make_request(LogOffset { next_offset: self.next_offset(), }); From 566664e9a0acc7bbd9544c49f5f573eaec6a0a7e Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 15:14:01 +0200 Subject: [PATCH 057/121] fmt --- bottomless/src/replicator.rs | 3 +- libsql-replication/src/injector/error.rs | 2 +- libsql-replication/src/injector/mod.rs | 6 +-- .../src/injector/sqlite_injector/mod.rs | 38 +++++++++---------- libsql-replication/src/replicator.rs | 3 +- libsql-replication/src/rpc.rs | 2 +- libsql-server/src/rpc/replication_log.rs | 4 +- libsql-wal/src/replication/injector.rs | 5 ++- 8 files changed, 33 insertions(+), 30 deletions(-) diff --git a/bottomless/src/replicator.rs b/bottomless/src/replicator.rs index 26e190df66..4e92824778 100644 --- a/bottomless/src/replicator.rs +++ b/bottomless/src/replicator.rs @@ -1455,7 +1455,8 @@ impl Replicator { 4096, libsql_sys::connection::NO_AUTOCHECKPOINT, 
encryption_config, - ).await?; + ) + .await?; let prefix = format!("{}-{}/", self.db_name, generation); let mut page_buf = { let mut v = Vec::with_capacity(page_size); diff --git a/libsql-replication/src/injector/error.rs b/libsql-replication/src/injector/error.rs index 225960c4d1..ac8f1be711 100644 --- a/libsql-replication/src/injector/error.rs +++ b/libsql-replication/src/injector/error.rs @@ -1,4 +1,4 @@ -pub type Result = std::result::Result; +pub type Result = std::result::Result; pub type BoxError = Box; #[derive(Debug, thiserror::Error)] diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 20a81cfa01..39df68c777 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,16 +1,16 @@ use std::future::Future; -pub use sqlite_injector::SqliteInjector; pub use libsql_injector::LibsqlInjector; +pub use sqlite_injector::SqliteInjector; use crate::frame::{Frame, FrameNo}; -use error::Result; pub use error::Error; +use error::Result; mod error; -mod sqlite_injector; mod libsql_injector; +mod sqlite_injector; pub trait Injector { /// Inject a singular frame. 
diff --git a/libsql-replication/src/injector/sqlite_injector/mod.rs b/libsql-replication/src/injector/sqlite_injector/mod.rs index 545fbe810d..2f4193e469 100644 --- a/libsql-replication/src/injector/sqlite_injector/mod.rs +++ b/libsql-replication/src/injector/sqlite_injector/mod.rs @@ -25,28 +25,23 @@ pub struct SqliteInjector { } impl Injector for SqliteInjector { - async fn inject_frame( - &mut self, - frame: Frame, - ) -> Result> { + async fn inject_frame(&mut self, frame: Frame) -> Result> { let inner = self.inner.clone(); - spawn_blocking(move || { - inner.lock().inject_frame(frame) - }).await.unwrap() + spawn_blocking(move || inner.lock().inject_frame(frame)) + .await + .unwrap() } async fn rollback(&mut self) { let inner = self.inner.clone(); - spawn_blocking(move || { - inner.lock().rollback() - }).await.unwrap(); + spawn_blocking(move || inner.lock().rollback()) + .await + .unwrap(); } async fn flush(&mut self) -> Result> { let inner = self.inner.clone(); - spawn_blocking(move || { - inner.lock().flush() - }).await.unwrap() + spawn_blocking(move || inner.lock().flush()).await.unwrap() } } @@ -56,13 +51,15 @@ impl SqliteInjector { capacity: usize, auto_checkpoint: u32, encryption_config: Option, - ) ->super::Result { + ) -> super::Result { let inner = spawn_blocking(move || { SqliteInjectorInner::new(path, capacity, auto_checkpoint, encryption_config) - }).await.unwrap()?; + }) + .await + .unwrap()?; Ok(Self { - inner: Arc::new(Mutex::new(inner)) + inner: Arc::new(Mutex::new(inner)), }) } } @@ -278,7 +275,8 @@ mod test { fn test_simple_inject_frames() { let temp = tempfile::tempdir().unwrap(); - let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 10, 10000, None).unwrap(); + let mut injector = + SqliteInjectorInner::new(temp.path().join("data"), 10, 10000, None).unwrap(); let log = wal_log(); for frame in log { injector.inject_frame(frame).unwrap(); @@ -298,7 +296,8 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame 
at a time - let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 1, 10000, None).unwrap(); + let mut injector = + SqliteInjectorInner::new(temp.path().join("data"), 1, 10000, None).unwrap(); let log = wal_log(); for frame in log { injector.inject_frame(frame).unwrap(); @@ -318,7 +317,8 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 10, 1000, None).unwrap(); + let mut injector = + SqliteInjectorInner::new(temp.path().join("data"), 10, 1000, None).unwrap(); let mut frames = wal_log(); assert!(injector diff --git a/libsql-replication/src/replicator.rs b/libsql-replication/src/replicator.rs index 31c766faad..ee75822676 100644 --- a/libsql-replication/src/replicator.rs +++ b/libsql-replication/src/replicator.rs @@ -168,7 +168,7 @@ where auto_checkpoint, encryption_config, ) - .await?; + .await?; Ok(Self::new(client, injector)) } @@ -179,7 +179,6 @@ where C: ReplicatorClient, I: Injector, { - pub fn new(client: C, injector: I) -> Self { Self { client, diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index a9b172db20..8e1165af65 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -51,7 +51,7 @@ pub mod replication { pub fn new(wal_flavor: WalFlavor) -> Self { Self { handshake_version: Some(1), - wal_flavor: Some(wal_flavor.into()) + wal_flavor: Some(wal_flavor.into()), } } } diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 628cb4a01d..bf1840a5c6 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -7,8 +7,8 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::stream::BoxStream; use futures_core::Future; -use libsql_replication::rpc::replication::hello_request::WalFlavor; pub use libsql_replication::rpc::replication as rpc; +use 
libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::replication_log_server::ReplicationLog; use libsql_replication::rpc::replication::{ Frame, Frames, HelloRequest, HelloResponse, LogOffset, NAMESPACE_DOESNT_EXIST, @@ -357,7 +357,7 @@ impl ReplicationLog for ReplicationLogService { } if let WalFlavor::Libsql = req.get_ref().wal_flavor() { - return Err(Status::invalid_argument("libsql wal not supported")) + return Err(Status::invalid_argument("libsql wal not supported")); } let (logger, config, version, _, _) = diff --git a/libsql-wal/src/replication/injector.rs b/libsql-wal/src/replication/injector.rs index c3642e196e..a922330102 100644 --- a/libsql-wal/src/replication/injector.rs +++ b/libsql-wal/src/replication/injector.rs @@ -94,7 +94,10 @@ mod test { let mut tx = crate::transaction::Transaction::Read(replica_shared.begin_read(42)); replica_shared.upgrade(&mut tx).unwrap(); - let guard = tx.into_write().unwrap_or_else(|_| panic!()).into_lock_owned(); + let guard = tx + .into_write() + .unwrap_or_else(|_| panic!()) + .into_lock_owned(); let mut injector = Injector::new(replica_shared.clone(), guard, 10).unwrap(); primary_conn.execute("create table test (x)", ()).unwrap(); From a2bdc805743f87f5c8cf2852c548578dbe496401 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 16:07:42 +0200 Subject: [PATCH 058/121] pass RpcFrame to client methods necessary to pass different underlying frames --- bottomless/src/replicator.rs | 7 ++- .../src/injector/libsql_injector.rs | 9 ++-- libsql-replication/src/injector/mod.rs | 5 +- .../src/injector/sqlite_injector/mod.rs | 5 +- libsql-replication/src/replicator.rs | 49 +++++++++++++------ .../src/replication/replicator_client.rs | 15 +++--- libsql/src/replication/local_client.rs | 14 ++++-- libsql/src/replication/remote_client.rs | 17 ++++--- 8 files changed, 79 insertions(+), 42 deletions(-) diff --git a/bottomless/src/replicator.rs b/bottomless/src/replicator.rs index 
4e92824778..cd37a70165 100644 --- a/bottomless/src/replicator.rs +++ b/bottomless/src/replicator.rs @@ -18,6 +18,7 @@ use aws_sdk_s3::{Client, Config}; use bytes::{Buf, Bytes}; use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; use libsql_replication::injector::Injector as _; +use libsql_replication::rpc::replication::Frame as RpcFrame; use libsql_sys::{Cipher, EncryptionConfig}; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -1554,7 +1555,11 @@ impl Replicator { }, page_buf.as_slice(), ); - injector.inject_frame(frame_to_inject).await?; + let frame = RpcFrame { + data: frame_to_inject.bytes(), + timestamp: None, + }; + injector.inject_frame(frame).await?; applied_wal_frame = true; } } diff --git a/libsql-replication/src/injector/libsql_injector.rs b/libsql-replication/src/injector/libsql_injector.rs index 946d35e547..f867a29245 100644 --- a/libsql-replication/src/injector/libsql_injector.rs +++ b/libsql-replication/src/injector/libsql_injector.rs @@ -5,7 +5,8 @@ use libsql_wal::replication::injector::Injector; use libsql_wal::segment::Frame as WalFrame; use zerocopy::{AsBytes, FromZeroes}; -use crate::frame::{Frame, FrameNo}; +use crate::frame::FrameNo; +use crate::rpc::replication::Frame as RpcFrame; use super::error::{Error, Result}; @@ -14,15 +15,15 @@ pub struct LibsqlInjector { } impl super::Injector for LibsqlInjector { - async fn inject_frame(&mut self, frame: Frame) -> Result> { + async fn inject_frame(&mut self, frame: RpcFrame) -> Result> { // this is a bit annoying be we want to read the frame, and it has to be aligned, so we // must copy it... // FIXME: optimize this. 
let mut wal_frame = WalFrame::new_box_zeroed(); - if frame.bytes().len() != size_of::() { + if frame.data.len() != size_of::() { todo!("invalid frame"); } - wal_frame.as_bytes_mut().copy_from_slice(&frame.bytes()[..]); + wal_frame.as_bytes_mut().copy_from_slice(&frame.data[..]); Ok(self .injector .insert_frame(wal_frame) diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 39df68c777..3712458d2f 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,9 +1,10 @@ use std::future::Future; +use super::rpc::replication::Frame as RpcFrame; pub use libsql_injector::LibsqlInjector; pub use sqlite_injector::SqliteInjector; -use crate::frame::{Frame, FrameNo}; +use crate::frame::FrameNo; pub use error::Error; use error::Result; @@ -16,7 +17,7 @@ pub trait Injector { /// Inject a singular frame. fn inject_frame( &mut self, - frame: Frame, + frame: RpcFrame, ) -> impl Future>> + Send; /// Discard any uncommintted frames. 
diff --git a/libsql-replication/src/injector/sqlite_injector/mod.rs b/libsql-replication/src/injector/sqlite_injector/mod.rs index 2f4193e469..f6ce2aa89f 100644 --- a/libsql-replication/src/injector/sqlite_injector/mod.rs +++ b/libsql-replication/src/injector/sqlite_injector/mod.rs @@ -7,6 +7,7 @@ use rusqlite::OpenFlags; use tokio::task::spawn_blocking; use crate::frame::{Frame, FrameNo}; +use crate::rpc::replication::Frame as RpcFrame; use self::injector_wal::{ InjectorWal, InjectorWalManager, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN, @@ -25,8 +26,10 @@ pub struct SqliteInjector { } impl Injector for SqliteInjector { - async fn inject_frame(&mut self, frame: Frame) -> Result> { + async fn inject_frame(&mut self, frame: RpcFrame) -> Result> { let inner = self.inner.clone(); + let frame = + Frame::try_from(&frame.data[..]).map_err(|e| Error::FatalInjectError(e.into()))?; spawn_blocking(move || inner.lock().inject_frame(frame)) .await .unwrap() diff --git a/libsql-replication/src/replicator.rs b/libsql-replication/src/replicator.rs index ee75822676..38cdbf6e7c 100644 --- a/libsql-replication/src/replicator.rs +++ b/libsql-replication/src/replicator.rs @@ -63,7 +63,7 @@ impl From for Error { #[async_trait::async_trait] pub trait ReplicatorClient { - type FrameStream: Stream> + Unpin + Send; + type FrameStream: Stream> + Unpin + Send; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error>; @@ -318,7 +318,7 @@ where } } - async fn inject_frame(&mut self, frame: Frame) -> Result<(), Error> { + async fn inject_frame(&mut self, frame: RpcFrame) -> Result<(), Error> { self.frames_synced += 1; match self.injector.inject_frame(frame).await? 
{ @@ -360,6 +360,7 @@ mod test { use async_stream::stream; use crate::frame::{FrameBorrowed, FrameMut}; + use crate::rpc::replication::Frame as RpcFrame; use super::*; @@ -370,7 +371,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -414,7 +416,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -456,7 +459,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -500,7 +504,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -544,7 +549,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -586,7 +592,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -627,7 +634,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 
'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -672,7 +680,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -740,7 +749,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -752,15 +762,26 @@ mod test { let frames = self .frames .iter() + .map(|f| RpcFrame { + data: f.bytes(), + timestamp: None, + }) .take(2) - .cloned() .map(Ok) .chain(Some(Err(Error::Client("some client error".into())))) .collect::>(); Ok(Box::pin(tokio_stream::iter(frames))) } else { - let stream = tokio_stream::iter(self.frames.clone().into_iter().map(Ok)); - Ok(Box::pin(stream)) + let iter = self + .frames + .iter() + .map(|f| RpcFrame { + data: f.bytes(), + timestamp: None, + }) + .map(Ok) + .collect::>(); + Ok(Box::pin(tokio_stream::iter(iter))) } } /// Return a snapshot for the current replication index. 
Called after next_frame has returned a diff --git a/libsql-server/src/replication/replicator_client.rs b/libsql-server/src/replication/replicator_client.rs index d68c259dc9..89e465053b 100644 --- a/libsql-server/src/replication/replicator_client.rs +++ b/libsql-server/src/replication/replicator_client.rs @@ -4,16 +4,17 @@ use std::pin::Pin; use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::TryStreamExt; -use libsql_replication::frame::Frame; use libsql_replication::meta::WalIndexMeta; -use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient}; +use libsql_replication::replicator::{Error, ReplicatorClient}; use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use libsql_replication::rpc::replication::{ - verify_session_token, HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, SESSION_TOKEN_KEY, + verify_session_token, Frame as RpcFrame, HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, + SESSION_TOKEN_KEY, }; use tokio::sync::watch; -use tokio_stream::{Stream, StreamExt}; +use tokio_stream::Stream; + use tonic::metadata::{AsciiMetadataValue, BinaryMetadataValue}; use tonic::transport::Channel; use tonic::{Code, Request, Status}; @@ -95,7 +96,7 @@ impl Client { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = Pin> + Send + 'static>>; #[tracing::instrument(skip(self))] async fn handshake(&mut self) -> Result<(), Error> { @@ -169,7 +170,7 @@ impl ReplicatorClient for Client { None => REPLICATION_LATENCY_CACHE_MISS.increment(1), } }) - .map(map_frame_err); + .map_err(Into::into); Ok(Box::pin(stream)) } @@ -181,7 +182,7 @@ impl ReplicatorClient for Client { let req = self.make_request(offset); match self.client.snapshot(req).await { Ok(resp) => { - let stream = resp.into_inner().map(map_frame_err); + let stream = resp.into_inner().map_err(Into::into); 
Ok(Box::pin(stream)) } Err(e) if e.code() == Code::Unavailable => Err(Error::SnapshotPending), diff --git a/libsql/src/replication/local_client.rs b/libsql/src/replication/local_client.rs index 2d7b940c92..d3c713f530 100644 --- a/libsql/src/replication/local_client.rs +++ b/libsql/src/replication/local_client.rs @@ -3,6 +3,7 @@ use std::pin::Pin; use futures::{StreamExt, TryStreamExt}; use libsql_replication::{ + rpc::replication::Frame as RpcFrame, frame::{Frame, FrameNo}, meta::WalIndexMeta, replicator::{Error, ReplicatorClient}, @@ -35,7 +36,7 @@ impl LocalClient { #[async_trait::async_trait] impl ReplicatorClient for LocalClient { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -46,7 +47,7 @@ impl ReplicatorClient for LocalClient { async fn next_frames(&mut self) -> Result { match self.frames.take() { Some(Frames::Vec(f)) => { - let iter = f.into_iter().map(Ok); + let iter = f.into_iter().map(|f| RpcFrame { data: f.bytes(), timestamp: None }).map(Ok); Ok(Box::pin(tokio_stream::iter(iter))) } Some(f @ Frames::Snapshot(_)) => { @@ -70,7 +71,8 @@ impl ReplicatorClient for LocalClient { if s.as_mut().peek().await.is_none() { next.header_mut().size_after = size_after.into(); } - yield Frame::from(next); + let frame = Frame::from(next); + yield RpcFrame { data: frame.bytes(), timestamp: None }; } }; @@ -95,8 +97,9 @@ impl ReplicatorClient for LocalClient { #[cfg(test)] mod test { - use libsql_replication::snapshot::SnapshotFile; + use libsql_replication::{frame::FrameHeader, snapshot::SnapshotFile}; use tempfile::tempdir; + use zerocopy::FromBytes; use super::*; @@ -111,7 +114,8 @@ mod test { let mut s = client.snapshot().await.unwrap(); assert!(matches!(s.next().await, Some(Ok(_)))); let last = s.next().await.unwrap().unwrap(); - assert_eq!(last.header().size_after.get(), 2); + let header: FrameHeader = 
FrameHeader::read_from_prefix(&last.data[..]).unwrap(); + assert_eq!(header.size_after.get(), 2); assert!(s.next().await.is_none()); } } diff --git a/libsql/src/replication/remote_client.rs b/libsql/src/replication/remote_client.rs index 79cffb1c38..26e537d18a 100644 --- a/libsql/src/replication/remote_client.rs +++ b/libsql/src/replication/remote_client.rs @@ -4,13 +4,13 @@ use std::pin::Pin; use std::time::{Duration, Instant}; use bytes::Bytes; -use futures::StreamExt as _; -use libsql_replication::frame::{Frame, FrameHeader, FrameNo}; +use futures::{StreamExt as _, TryStreamExt}; +use libsql_replication::frame::{FrameHeader, FrameNo}; use libsql_replication::meta::WalIndexMeta; -use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient}; +use libsql_replication::replicator::{Error, ReplicatorClient}; use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::{ - verify_session_token, Frames, HelloRequest, HelloResponse, LogOffset, SESSION_TOKEN_KEY, + Frame as RpcFrame, verify_session_token, Frames, HelloRequest, HelloResponse, LogOffset, SESSION_TOKEN_KEY, }; use tokio_stream::Stream; use tonic::metadata::AsciiMetadataValue; @@ -161,7 +161,7 @@ impl RemoteClient { let frames_iter = frames .into_iter() - .map(|f| Frame::try_from(&*f.data).map_err(|e| Error::Client(e.into()))); + .map(Ok); let stream = tokio_stream::iter(frames_iter); @@ -197,7 +197,7 @@ impl RemoteClient { .snapshot(req) .await? 
.into_inner() - .map(map_frame_err) + .map_err(|e| e.into()) .peekable(); { @@ -205,7 +205,8 @@ impl RemoteClient { // the first frame is the one with the highest frame_no in the snapshot if let Some(Ok(f)) = frames.peek().await { - self.last_received = Some(f.header().frame_no.get()); + let header: FrameHeader = FrameHeader::read_from_prefix(&f.data[..]).unwrap(); + self.last_received = Some(header.frame_no.get()); } } @@ -240,7 +241,7 @@ fn maybe_log( #[async_trait::async_trait] impl ReplicatorClient for RemoteClient { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { From e97026c82d4bfc13f2298d98c705a19cea6159be Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 16:33:10 +0200 Subject: [PATCH 059/121] feature gate libsql injector --- libsql-replication/Cargo.toml | 3 ++- libsql-replication/src/injector/mod.rs | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/libsql-replication/Cargo.toml b/libsql-replication/Cargo.toml index 2a03d362bf..068e23a652 100644 --- a/libsql-replication/Cargo.toml +++ b/libsql-replication/Cargo.toml @@ -12,7 +12,7 @@ license = "MIT" tonic = { version = "0.11", features = ["tls"] } prost = "0.12" libsql-sys = { version = "0.7", path = "../libsql-sys", default-features = false, features = ["wal", "rusqlite", "api"] } -libsql-wal = { path = "../libsql-wal/" } +libsql-wal = { path = "../libsql-wal/", optional = true } rusqlite = { workspace = true } parking_lot = "0.12.1" bytes = { version = "1.5.0", features = ["serde"] } @@ -38,3 +38,4 @@ tonic-build = "0.11" [features] encryption = ["libsql-sys/encryption"] +libsql_wal = ["dep:libsql-wal"] diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 3712458d2f..b139f07cc9 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,6 +1,7 @@ use 
std::future::Future; use super::rpc::replication::Frame as RpcFrame; +#[cfg(feature = "libsql_wal")] pub use libsql_injector::LibsqlInjector; pub use sqlite_injector::SqliteInjector; @@ -10,6 +11,7 @@ pub use error::Error; use error::Result; mod error; +#[cfg(feature = "libsql_wal")] mod libsql_injector; mod sqlite_injector; From d29ca7fabed4440e0eb89eab2b10765b85a38e84 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 10 Aug 2024 11:16:24 +0200 Subject: [PATCH 060/121] fix conflicts --- .../proto/replication_log.proto | 10 +++++----- libsql-replication/src/generated/wal_log.rs | 20 +++++++++---------- libsql-replication/src/rpc.rs | 4 +--- .../src/namespace/configurator/replica.rs | 4 +++- .../src/replication/replicator_client.rs | 6 ++++-- libsql-server/src/rpc/replication_log.rs | 18 +++++++++++------ libsql/src/replication/remote_client.rs | 6 ++++-- 7 files changed, 39 insertions(+), 29 deletions(-) diff --git a/libsql-replication/proto/replication_log.proto b/libsql-replication/proto/replication_log.proto index b3be419319..b358232705 100644 --- a/libsql-replication/proto/replication_log.proto +++ b/libsql-replication/proto/replication_log.proto @@ -5,18 +5,18 @@ import "metadata.proto"; message LogOffset { uint64 next_offset = 1; -} - -message HelloRequest { - optional uint64 handshake_version = 1; enum WalFlavor { Sqlite = 0; Libsql = 1; } - // the type of wal that the client is expecting + // the type of wal frames that the client is expecting optional WalFlavor wal_flavor = 2; } +message HelloRequest { + optional uint64 handshake_version = 1; +} + message HelloResponse { /// Uuid of the current generation string generation_id = 1; diff --git a/libsql-replication/src/generated/wal_log.rs b/libsql-replication/src/generated/wal_log.rs index 441881c4a7..a34d5e59dd 100644 --- a/libsql-replication/src/generated/wal_log.rs +++ b/libsql-replication/src/generated/wal_log.rs @@ -4,18 +4,12 @@ pub struct LogOffset { #[prost(uint64, tag = "1")] pub next_offset: u64, -} 
-#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct HelloRequest { - #[prost(uint64, optional, tag = "1")] - pub handshake_version: ::core::option::Option, - /// the type of wal that the client is expecting - #[prost(enumeration = "hello_request::WalFlavor", optional, tag = "2")] + /// the type of wal frames that the client is expecting + #[prost(enumeration = "log_offset::WalFlavor", optional, tag = "2")] pub wal_flavor: ::core::option::Option, } -/// Nested message and enum types in `HelloRequest`. -pub mod hello_request { +/// Nested message and enum types in `LogOffset`. +pub mod log_offset { #[derive( Clone, Copy, @@ -55,6 +49,12 @@ pub mod hello_request { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct HelloRequest { + #[prost(uint64, optional, tag = "1")] + pub handshake_version: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct HelloResponse { /// / Uuid of the current generation #[prost(string, tag = "1")] diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index 8e1165af65..a538bc4c28 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -26,7 +26,6 @@ pub mod replication { use uuid::Uuid; - use self::hello_request::WalFlavor; include!("generated/wal_log.rs"); pub const NO_HELLO_ERROR_MSG: &str = "NO_HELLO"; @@ -48,10 +47,9 @@ pub mod replication { } impl HelloRequest { - pub fn new(wal_flavor: WalFlavor) -> Self { + pub fn new() -> Self { Self { handshake_version: Some(1), - wal_flavor: Some(wal_flavor.into()), } } } diff --git a/libsql-server/src/namespace/configurator/replica.rs b/libsql-server/src/namespace/configurator/replica.rs index 84ebadb897..7832d30ef8 100644 --- a/libsql-server/src/namespace/configurator/replica.rs +++ b/libsql-server/src/namespace/configurator/replica.rs @@ -3,6 +3,7 @@ use 
std::sync::Arc; use futures::Future; use hyper::Uri; +use libsql_replication::rpc::replication::log_offset::WalFlavor; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use tokio::task::JoinSet; use tonic::transport::Channel; @@ -68,10 +69,11 @@ impl ConfigureNamespace for ReplicaConfigurator { &db_path, meta_store_handle.clone(), store.clone(), + WalFlavor::Sqlite, ) .await?; let applied_frame_no_receiver = client.current_frame_no_notifier.subscribe(); - let mut replicator = libsql_replication::replicator::Replicator::new( + let mut replicator = libsql_replication::replicator::Replicator::new_sqlite( client, db_path.join("data"), DEFAULT_AUTO_CHECKPOINT, diff --git a/libsql-server/src/replication/replicator_client.rs b/libsql-server/src/replication/replicator_client.rs index 89e465053b..753baac996 100644 --- a/libsql-server/src/replication/replicator_client.rs +++ b/libsql-server/src/replication/replicator_client.rs @@ -6,7 +6,7 @@ use chrono::{DateTime, Utc}; use futures::TryStreamExt; use libsql_replication::meta::WalIndexMeta; use libsql_replication::replicator::{Error, ReplicatorClient}; -use libsql_replication::rpc::replication::hello_request::WalFlavor; +use libsql_replication::rpc::replication::log_offset::WalFlavor; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use libsql_replication::rpc::replication::{ verify_session_token, Frame as RpcFrame, HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, @@ -101,7 +101,7 @@ impl ReplicatorClient for Client { #[tracing::instrument(skip(self))] async fn handshake(&mut self) -> Result<(), Error> { tracing::debug!("Attempting to perform handshake with primary."); - let req = self.make_request(HelloRequest::new(self.wal_flavor)); + let req = self.make_request(HelloRequest::new()); let resp = self.client.hello(req).await?; let hello = resp.into_inner(); verify_session_token(&hello.session_token).map_err(Error::Client)?; @@ -143,6 +143,7 @@ impl 
ReplicatorClient for Client { async fn next_frames(&mut self) -> Result { let offset = LogOffset { next_offset: self.next_frame_no(), + wal_flavor: Some(self.wal_flavor.into()), }; let req = self.make_request(offset); let stream = self @@ -178,6 +179,7 @@ impl ReplicatorClient for Client { async fn snapshot(&mut self) -> Result { let offset = LogOffset { next_offset: self.next_frame_no(), + wal_flavor: Some(self.wal_flavor.into()), }; let req = self.make_request(offset); match self.client.snapshot(req).await { diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index bf1840a5c6..1ef306daf1 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -8,7 +8,7 @@ use chrono::{DateTime, Utc}; use futures::stream::BoxStream; use futures_core::Future; pub use libsql_replication::rpc::replication as rpc; -use libsql_replication::rpc::replication::hello_request::WalFlavor; +use libsql_replication::rpc::replication::log_offset::WalFlavor; use libsql_replication::rpc::replication::replication_log_server::ReplicationLog; use libsql_replication::rpc::replication::{ Frame, Frames, HelloRequest, HelloResponse, LogOffset, NAMESPACE_DOESNT_EXIST, @@ -260,6 +260,9 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { + if let WalFlavor::Libsql = req.get_ref().wal_flavor() { + return Err(Status::invalid_argument("libsql wal not supported")); + } let namespace = super::extract_namespace(self.disable_namespaces, &req)?; self.authenticate(&req, namespace.clone()).await?; @@ -305,6 +308,9 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { + if let WalFlavor::Libsql = req.get_ref().wal_flavor() { + return Err(Status::invalid_argument("libsql wal not supported")); + } let namespace = super::extract_namespace(self.disable_namespaces, &req)?; self.authenticate(&req, namespace.clone()).await?; @@ -355,11 
+361,6 @@ impl ReplicationLog for ReplicationLogService { guard.insert((replica_addr, namespace.clone())); } } - - if let WalFlavor::Libsql = req.get_ref().wal_flavor() { - return Err(Status::invalid_argument("libsql wal not supported")); - } - let (logger, config, version, _, _) = self.logger_from_namespace(namespace, &req, false).await?; @@ -381,7 +382,12 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { + if let WalFlavor::Libsql = req.get_ref().wal_flavor() { + return Err(Status::invalid_argument("libsql wal not supported")); + } + let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + self.authenticate(&req, namespace.clone()).await?; let (logger, _, _, stats, _) = self.logger_from_namespace(namespace, &req, true).await?; diff --git a/libsql/src/replication/remote_client.rs b/libsql/src/replication/remote_client.rs index 26e537d18a..864392ddb5 100644 --- a/libsql/src/replication/remote_client.rs +++ b/libsql/src/replication/remote_client.rs @@ -8,7 +8,6 @@ use futures::{StreamExt as _, TryStreamExt}; use libsql_replication::frame::{FrameHeader, FrameNo}; use libsql_replication::meta::WalIndexMeta; use libsql_replication::replicator::{Error, ReplicatorClient}; -use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::{ Frame as RpcFrame, verify_session_token, Frames, HelloRequest, HelloResponse, LogOffset, SESSION_TOKEN_KEY, }; @@ -117,9 +116,10 @@ impl RemoteClient { self.dirty = false; } let prefetch = self.session_token.is_some(); - let hello_req = self.make_request(HelloRequest::new(WalFlavor::Sqlite)); + let hello_req = self.make_request(HelloRequest::new()); let log_offset_req = self.make_request(LogOffset { next_offset: self.next_offset(), + wal_flavor: None, }); let mut client_clone = self.remote.clone(); let hello_fut = time(async { @@ -179,6 +179,7 @@ impl RemoteClient { None => { let req = self.make_request(LogOffset { 
next_offset: self.next_offset(), + wal_flavor: None, }); time(self.remote.replication.batch_log_entries(req)).await } @@ -190,6 +191,7 @@ impl RemoteClient { async fn do_snapshot(&mut self) -> Result<::FrameStream, Error> { let req = self.make_request(LogOffset { next_offset: self.next_offset(), + wal_flavor: None, }); let mut frames = self .remote From e7a24304d9378806561b9cfa2093c184a0d6fc4d Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 11 Aug 2024 12:53:11 +0400 Subject: [PATCH 061/121] fix potential memory leak --- libsql-sqlite3/src/vectordiskann.c | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index ae832f4400..caaeee64b2 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -954,6 +954,11 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ } static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + return SQLITE_NOMEM_BKPT; + } + loadVectorPair(&pCtx->query, pQuery); + pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; @@ -965,10 +970,6 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - return SQLITE_NOMEM_BKPT; - } - loadVectorPair(&pCtx->query, pQuery); if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ return SQLITE_OK; From 4c38e5f11e75c7debc8caf809daab28db0f5a1eb Mon Sep 17 00:00:00 2001 
From: Nikita Sivukhin Date: Sun, 11 Aug 2024 12:58:50 +0400 Subject: [PATCH 062/121] build bundles --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 9 +++++---- libsql-ffi/bundled/src/sqlite3.c | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 6c60bccaab..b68e262935 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -212717,6 +212717,11 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ } static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + return SQLITE_NOMEM_BKPT; + } + loadVectorPair(&pCtx->query, pQuery); + pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; @@ -212728,10 +212733,6 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - return SQLITE_NOMEM_BKPT; - } - loadVectorPair(&pCtx->query, pQuery); if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ return SQLITE_OK; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 6c60bccaab..b68e262935 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -212717,6 +212717,11 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ } static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, 
DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + return SQLITE_NOMEM_BKPT; + } + loadVectorPair(&pCtx->query, pQuery); + pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; @@ -212728,10 +212733,6 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - return SQLITE_NOMEM_BKPT; - } - loadVectorPair(&pCtx->query, pQuery); if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ return SQLITE_OK; From 90007425ba039965f2cc7cfd7d01f25b35bc1869 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 11 Aug 2024 18:54:39 +0400 Subject: [PATCH 063/121] add simple pragma test --- libsql-sqlite3/test/libsql_vector_index.test | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 7819ba5076..281943e7c2 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -30,12 +30,17 @@ set testprefix vector sqlite3_db_config_lookaside db 0 0 0 -do_execsql_test vector-integrity { - CREATE TABLE t_integrity( v FLOAT32(3) ); - CREATE INDEX t_integrity_idx ON t_integrity( libsql_vector_idx(v) ); - INSERT INTO t_integrity VALUES (vector('[1,2,3]')); +do_execsql_test vector-pragmas { + CREATE TABLE t_pragmas( v FLOAT32(3) ); + CREATE INDEX t_pragmas_idx ON t_pragmas( libsql_vector_idx(v) ); + INSERT INTO t_pragmas VALUES (vector('[1,2,3]')); PRAGMA 
integrity_check; -} {{row 1 missing from index t_integrity_idx} {wrong # of entries in index t_integrity_idx}} + PRAGMA index_list='t_pragmas'; +} { + {row 1 missing from index t_pragmas_idx} + {wrong # of entries in index t_pragmas_idx} + 0 t_pragmas_idx 0 c 0 +} do_execsql_test vector-typename { CREATE TABLE t_type_spaces( v FLOAT32 ( 3 ) ); From e8e5870dfa54fd6d64993928333dfa95ab604213 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 11 Aug 2024 18:54:12 +0400 Subject: [PATCH 064/121] don't change idxType as sqlite rely on it pretty much --- libsql-sqlite3/src/build.c | 9 ++++----- libsql-sqlite3/src/parse.y | 3 --- libsql-sqlite3/src/sqliteInt.h | 9 +++------ 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/libsql-sqlite3/src/build.c b/libsql-sqlite3/src/build.c index 4396f2fae7..d6faa6b071 100644 --- a/libsql-sqlite3/src/build.c +++ b/libsql-sqlite3/src/build.c @@ -833,7 +833,7 @@ static void SQLITE_NOINLINE deleteTable(sqlite3 *db, Table *pTable){ for(pIndex = pTable->pIndex; pIndex; pIndex=pNext){ pNext = pIndex->pNext; assert( pIndex->pSchema==pTable->pSchema - || (IsVirtual(pTable) && !IsAppDefIndex(pIndex)) ); + || (IsVirtual(pTable) && pIndex->idxType!=SQLITE_IDXTYPE_APPDEF) ); if( db->pnBytesFreed==0 && !IsVirtual(pTable) ){ char *zName = pIndex->zName; TESTONLY ( Index *pOld = ) sqlite3HashInsert( @@ -4345,13 +4345,12 @@ void sqlite3CreateIndex( goto exit_create_index; } if( vectorIdxRc >= 1 ){ - idxType = SQLITE_IDXTYPE_VECTOR; /* * SQLite can use B-Tree indices in some optimizations (like SELECT COUNT(*) can use any full B-Tree index instead of PK index) * But, SQLite pretty conservative about usage of unordered indices - that's what we need here */ pIndex->bUnordered = 1; - pIndex->idxType = idxType; + pIndex->idxIsVector = 1; } if( vectorIdxRc == 1 ){ skipRefill = 1; @@ -4399,7 +4398,7 @@ void sqlite3CreateIndex( for(pIdx=pTab->pIndex; pIdx; pIdx=pIdx->pNext){ int k; assert( IsUniqueIndex(pIdx) ); - assert( 
!IsAppDefIndex(pIdx) ); + assert( pIdx->idxType!=SQLITE_IDXTYPE_APPDEF ); assert( IsUniqueIndex(pIndex) ); if( pIdx->nKeyCol!=pIndex->nKeyCol ) continue; @@ -4680,7 +4679,7 @@ void sqlite3DropIndex(Parse *pParse, SrcList *pName, int ifExists){ pParse->checkSchema = 1; goto exit_drop_index; } - if( !IsAppDefIndex(pIndex) ){ + if( pIndex->idxType!=SQLITE_IDXTYPE_APPDEF ){ sqlite3ErrorMsg(pParse, "index associated with UNIQUE " "or PRIMARY KEY constraint cannot be dropped", 0); goto exit_drop_index; diff --git a/libsql-sqlite3/src/parse.y b/libsql-sqlite3/src/parse.y index 41e08ad6d6..f866ec5d2c 100644 --- a/libsql-sqlite3/src/parse.y +++ b/libsql-sqlite3/src/parse.y @@ -1451,9 +1451,6 @@ paren_exprlist(A) ::= LP exprlist(X) RP. {A = X;} cmd ::= createkw(S) uniqueflag(U) INDEX ifnotexists(NE) nm(X) dbnm(D) indextype(T) ON nm(Y) LP sortlist(Z) RP where_opt(W). { u8 idxType = SQLITE_IDXTYPE_APPDEF; - if( T.pUsing!=0 ){ - idxType = SQLITE_IDXTYPE_VECTOR; - } sqlite3CreateIndex(pParse, &X, &D, sqlite3SrcListAppend(pParse,0,&Y,0), Z, U, &S, W, SQLITE_SO_ASC, NE, idxType, T.pUsing); diff --git a/libsql-sqlite3/src/sqliteInt.h b/libsql-sqlite3/src/sqliteInt.h index e2fd32d3c4..0a9dd98d66 100644 --- a/libsql-sqlite3/src/sqliteInt.h +++ b/libsql-sqlite3/src/sqliteInt.h @@ -2799,7 +2799,8 @@ struct Index { u16 nKeyCol; /* Number of columns forming the key */ u16 nColumn; /* Number of columns stored in the index */ u8 onError; /* OE_Abort, OE_Ignore, OE_Replace, or OE_None */ - unsigned idxType:3; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK, 4:VECTOR INDEX */ + unsigned idxType:2; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK */ + unsigned idxIsVector:1; /* 0:Normal 1:VECTOR INDEX */ unsigned bUnordered:1; /* Use this index for == or IN queries only */ unsigned uniqNotNull:1; /* True if UNIQUE and NOT NULL for all columns */ unsigned isResized:1; /* True if resizeIndexObject() has been called */ @@ -2831,7 +2832,6 @@ struct Index { #define SQLITE_IDXTYPE_UNIQUE 1 /* Implements a 
UNIQUE constraint */ #define SQLITE_IDXTYPE_PRIMARYKEY 2 /* Is the PRIMARY KEY for the table */ #define SQLITE_IDXTYPE_IPK 3 /* INTEGER PRIMARY KEY index */ -#define SQLITE_IDXTYPE_VECTOR 4 /* libSQL vector index */ /* Return true if index X is a PRIMARY KEY index */ #define IsPrimaryKeyIndex(X) ((X)->idxType==SQLITE_IDXTYPE_PRIMARYKEY) @@ -2840,10 +2840,7 @@ struct Index { #define IsUniqueIndex(X) ((X)->onError!=OE_None) /* Return true if index X is a vector index */ -#define IsVectorIndex(X) ((X)->idxType==SQLITE_IDXTYPE_VECTOR) - -/* Return true if index X is an user defined index (APPDEF or VECTOR) */ -#define IsAppDefIndex(X) ((X)->idxType==SQLITE_IDXTYPE_APPDEF||(X)->idxType==SQLITE_IDXTYPE_VECTOR) +#define IsVectorIndex(X) ((X)->idxIsVector==1) /* The Index.aiColumn[] values are normally positive integer. But ** there are some negative values that have special meaning: From b1dbaa13bb66bf1aac0f078489de811507a1f630 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 11 Aug 2024 19:22:44 +0400 Subject: [PATCH 065/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 21 +++++++------------ libsql-ffi/bundled/src/sqlite3.c | 21 +++++++------------ 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 1b670120d5..8ceabfc713 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -19266,7 +19266,8 @@ struct Index { u16 nKeyCol; /* Number of columns forming the key */ u16 nColumn; /* Number of columns stored in the index */ u8 onError; /* OE_Abort, OE_Ignore, OE_Replace, or OE_None */ - unsigned idxType:3; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK, 4:VECTOR INDEX */ + unsigned idxType:2; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK */ + unsigned idxIsVector:1; /* 0:Normal 1:VECTOR INDEX */ unsigned bUnordered:1; /* Use this index for == or IN 
queries only */ unsigned uniqNotNull:1; /* True if UNIQUE and NOT NULL for all columns */ unsigned isResized:1; /* True if resizeIndexObject() has been called */ @@ -19298,7 +19299,6 @@ struct Index { #define SQLITE_IDXTYPE_UNIQUE 1 /* Implements a UNIQUE constraint */ #define SQLITE_IDXTYPE_PRIMARYKEY 2 /* Is the PRIMARY KEY for the table */ #define SQLITE_IDXTYPE_IPK 3 /* INTEGER PRIMARY KEY index */ -#define SQLITE_IDXTYPE_VECTOR 4 /* libSQL vector index */ /* Return true if index X is a PRIMARY KEY index */ #define IsPrimaryKeyIndex(X) ((X)->idxType==SQLITE_IDXTYPE_PRIMARYKEY) @@ -19307,10 +19307,7 @@ struct Index { #define IsUniqueIndex(X) ((X)->onError!=OE_None) /* Return true if index X is a vector index */ -#define IsVectorIndex(X) ((X)->idxType==SQLITE_IDXTYPE_VECTOR) - -/* Return true if index X is an user defined index (APPDEF or VECTOR) */ -#define IsAppDefIndex(X) ((X)->idxType==SQLITE_IDXTYPE_APPDEF||(X)->idxType==SQLITE_IDXTYPE_VECTOR) +#define IsVectorIndex(X) ((X)->idxIsVector==1) /* The Index.aiColumn[] values are normally positive integer. 
But ** there are some negative values that have special meaning: @@ -123188,7 +123185,7 @@ static void SQLITE_NOINLINE deleteTable(sqlite3 *db, Table *pTable){ for(pIndex = pTable->pIndex; pIndex; pIndex=pNext){ pNext = pIndex->pNext; assert( pIndex->pSchema==pTable->pSchema - || (IsVirtual(pTable) && !IsAppDefIndex(pIndex)) ); + || (IsVirtual(pTable) && pIndex->idxType!=SQLITE_IDXTYPE_APPDEF) ); if( db->pnBytesFreed==0 && !IsVirtual(pTable) ){ char *zName = pIndex->zName; TESTONLY ( Index *pOld = ) sqlite3HashInsert( @@ -126700,13 +126697,12 @@ SQLITE_PRIVATE void sqlite3CreateIndex( goto exit_create_index; } if( vectorIdxRc >= 1 ){ - idxType = SQLITE_IDXTYPE_VECTOR; /* * SQLite can use B-Tree indices in some optimizations (like SELECT COUNT(*) can use any full B-Tree index instead of PK index) * But, SQLite pretty conservative about usage of unordered indices - that's what we need here */ pIndex->bUnordered = 1; - pIndex->idxType = idxType; + pIndex->idxIsVector = 1; } if( vectorIdxRc == 1 ){ skipRefill = 1; @@ -126754,7 +126750,7 @@ SQLITE_PRIVATE void sqlite3CreateIndex( for(pIdx=pTab->pIndex; pIdx; pIdx=pIdx->pNext){ int k; assert( IsUniqueIndex(pIdx) ); - assert( !IsAppDefIndex(pIdx) ); + assert( pIdx->idxType!=SQLITE_IDXTYPE_APPDEF ); assert( IsUniqueIndex(pIndex) ); if( pIdx->nKeyCol!=pIndex->nKeyCol ) continue; @@ -127035,7 +127031,7 @@ SQLITE_PRIVATE void sqlite3DropIndex(Parse *pParse, SrcList *pName, int ifExists pParse->checkSchema = 1; goto exit_drop_index; } - if( !IsAppDefIndex(pIndex) ){ + if( pIndex->idxType!=SQLITE_IDXTYPE_APPDEF ){ sqlite3ErrorMsg(pParse, "index associated with UNIQUE " "or PRIMARY KEY constraint cannot be dropped", 0); goto exit_drop_index; @@ -177910,9 +177906,6 @@ static YYACTIONTYPE yy_reduce( case 242: /* cmd ::= createkw uniqueflag INDEX ifnotexists nm dbnm indextype ON nm LP sortlist RP where_opt */ { u8 idxType = SQLITE_IDXTYPE_APPDEF; - if( yymsp[-6].minor.yy421.pUsing!=0 ){ - idxType = SQLITE_IDXTYPE_VECTOR; - } 
sqlite3CreateIndex(pParse, &yymsp[-8].minor.yy0, &yymsp[-7].minor.yy0, sqlite3SrcListAppend(pParse,0,&yymsp[-4].minor.yy0,0), yymsp[-2].minor.yy402, yymsp[-11].minor.yy502, &yymsp[-12].minor.yy0, yymsp[0].minor.yy590, SQLITE_SO_ASC, yymsp[-9].minor.yy502, idxType, yymsp[-6].minor.yy421.pUsing); diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 1b670120d5..8ceabfc713 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -19266,7 +19266,8 @@ struct Index { u16 nKeyCol; /* Number of columns forming the key */ u16 nColumn; /* Number of columns stored in the index */ u8 onError; /* OE_Abort, OE_Ignore, OE_Replace, or OE_None */ - unsigned idxType:3; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK, 4:VECTOR INDEX */ + unsigned idxType:2; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK */ + unsigned idxIsVector:1; /* 0:Normal 1:VECTOR INDEX */ unsigned bUnordered:1; /* Use this index for == or IN queries only */ unsigned uniqNotNull:1; /* True if UNIQUE and NOT NULL for all columns */ unsigned isResized:1; /* True if resizeIndexObject() has been called */ @@ -19298,7 +19299,6 @@ struct Index { #define SQLITE_IDXTYPE_UNIQUE 1 /* Implements a UNIQUE constraint */ #define SQLITE_IDXTYPE_PRIMARYKEY 2 /* Is the PRIMARY KEY for the table */ #define SQLITE_IDXTYPE_IPK 3 /* INTEGER PRIMARY KEY index */ -#define SQLITE_IDXTYPE_VECTOR 4 /* libSQL vector index */ /* Return true if index X is a PRIMARY KEY index */ #define IsPrimaryKeyIndex(X) ((X)->idxType==SQLITE_IDXTYPE_PRIMARYKEY) @@ -19307,10 +19307,7 @@ struct Index { #define IsUniqueIndex(X) ((X)->onError!=OE_None) /* Return true if index X is a vector index */ -#define IsVectorIndex(X) ((X)->idxType==SQLITE_IDXTYPE_VECTOR) - -/* Return true if index X is an user defined index (APPDEF or VECTOR) */ -#define IsAppDefIndex(X) ((X)->idxType==SQLITE_IDXTYPE_APPDEF||(X)->idxType==SQLITE_IDXTYPE_VECTOR) +#define IsVectorIndex(X) ((X)->idxIsVector==1) /* The 
Index.aiColumn[] values are normally positive integer. But ** there are some negative values that have special meaning: @@ -123188,7 +123185,7 @@ static void SQLITE_NOINLINE deleteTable(sqlite3 *db, Table *pTable){ for(pIndex = pTable->pIndex; pIndex; pIndex=pNext){ pNext = pIndex->pNext; assert( pIndex->pSchema==pTable->pSchema - || (IsVirtual(pTable) && !IsAppDefIndex(pIndex)) ); + || (IsVirtual(pTable) && pIndex->idxType!=SQLITE_IDXTYPE_APPDEF) ); if( db->pnBytesFreed==0 && !IsVirtual(pTable) ){ char *zName = pIndex->zName; TESTONLY ( Index *pOld = ) sqlite3HashInsert( @@ -126700,13 +126697,12 @@ SQLITE_PRIVATE void sqlite3CreateIndex( goto exit_create_index; } if( vectorIdxRc >= 1 ){ - idxType = SQLITE_IDXTYPE_VECTOR; /* * SQLite can use B-Tree indices in some optimizations (like SELECT COUNT(*) can use any full B-Tree index instead of PK index) * But, SQLite pretty conservative about usage of unordered indices - that's what we need here */ pIndex->bUnordered = 1; - pIndex->idxType = idxType; + pIndex->idxIsVector = 1; } if( vectorIdxRc == 1 ){ skipRefill = 1; @@ -126754,7 +126750,7 @@ SQLITE_PRIVATE void sqlite3CreateIndex( for(pIdx=pTab->pIndex; pIdx; pIdx=pIdx->pNext){ int k; assert( IsUniqueIndex(pIdx) ); - assert( !IsAppDefIndex(pIdx) ); + assert( pIdx->idxType!=SQLITE_IDXTYPE_APPDEF ); assert( IsUniqueIndex(pIndex) ); if( pIdx->nKeyCol!=pIndex->nKeyCol ) continue; @@ -127035,7 +127031,7 @@ SQLITE_PRIVATE void sqlite3DropIndex(Parse *pParse, SrcList *pName, int ifExists pParse->checkSchema = 1; goto exit_drop_index; } - if( !IsAppDefIndex(pIndex) ){ + if( pIndex->idxType!=SQLITE_IDXTYPE_APPDEF ){ sqlite3ErrorMsg(pParse, "index associated with UNIQUE " "or PRIMARY KEY constraint cannot be dropped", 0); goto exit_drop_index; @@ -177910,9 +177906,6 @@ static YYACTIONTYPE yy_reduce( case 242: /* cmd ::= createkw uniqueflag INDEX ifnotexists nm dbnm indextype ON nm LP sortlist RP where_opt */ { u8 idxType = SQLITE_IDXTYPE_APPDEF; - if( 
yymsp[-6].minor.yy421.pUsing!=0 ){ - idxType = SQLITE_IDXTYPE_VECTOR; - } sqlite3CreateIndex(pParse, &yymsp[-8].minor.yy0, &yymsp[-7].minor.yy0, sqlite3SrcListAppend(pParse,0,&yymsp[-4].minor.yy0,0), yymsp[-2].minor.yy402, yymsp[-11].minor.yy502, &yymsp[-12].minor.yy0, yymsp[0].minor.yy590, SQLITE_SO_ASC, yymsp[-9].minor.yy502, idxType, yymsp[-6].minor.yy421.pUsing); From cea87253fbfdd5e6cdf5cb245384b5ebbb4e1044 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 11:43:51 +0400 Subject: [PATCH 066/121] fix DELETE from vector index as there can be no row due to the NULL value of the vector --- libsql-sqlite3/src/vectordiskann.c | 6 +++++- libsql-sqlite3/test/libsql_vector_index.test | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 1f9973d38c..95276f8887 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -1633,7 +1633,11 @@ int diskAnnDelete( DiskAnnTrace(("diskAnnDelete started: rowid=%lld\n", nodeRowid)); rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); - if( rc != SQLITE_OK ){ + if( rc == DISKANN_ROW_NOT_FOUND ){ + // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + rc = SQLITE_OK; + goto out; + }else if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(delete): failed to create blob for node row"); goto out; } diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 281943e7c2..dd1f53d0df 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -135,9 +135,19 @@ do_execsql_test vector-null { CREATE INDEX t_null_idx ON t_null( libsql_vector_idx(v) ); INSERT INTO t_null VALUES(vector('[1,2,3]')); INSERT INTO t_null VALUES(NULL); - INSERT INTO t_null 
VALUES(vector('[2,3,4]')); + INSERT INTO t_null VALUES(vector('[3,4,5]')); SELECT * FROM vector_top_k('t_null_idx', '[1,2,3]', 2); -} {1 3} + UPDATE t_null SET v = vector('[2,3,4]') WHERE rowid = 2; + SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + UPDATE t_null SET v = NULL WHERE rowid = 3; + SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + UPDATE t_null SET v = NULL; + SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); +} { + 1 3 + 2 3 1 + 2 1 +} do_execsql_test vector-sql { CREATE TABLE t_sql( v FLOAT32(3)); From ac60c100d25a7e80ef8aa8a5df775c126217e0f8 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 11:46:05 +0400 Subject: [PATCH 067/121] build bundles --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 6 +++++- libsql-ffi/bundled/src/sqlite3.c | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 8ceabfc713..887d3322b9 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -213389,7 +213389,11 @@ int diskAnnDelete( DiskAnnTrace(("diskAnnDelete started: rowid=%lld\n", nodeRowid)); rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); - if( rc != SQLITE_OK ){ + if( rc == DISKANN_ROW_NOT_FOUND ){ + // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + rc = SQLITE_OK; + goto out; + }else if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(delete): failed to create blob for node row"); goto out; } diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 8ceabfc713..887d3322b9 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -213389,7 +213389,11 @@ int 
diskAnnDelete( DiskAnnTrace(("diskAnnDelete started: rowid=%lld\n", nodeRowid)); rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); - if( rc != SQLITE_OK ){ + if( rc == DISKANN_ROW_NOT_FOUND ){ + // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + rc = SQLITE_OK; + goto out; + }else if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(delete): failed to create blob for node row"); goto out; } From 405c71073636dea35feed1cf79d150ee6a835493 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 11:56:04 +0400 Subject: [PATCH 068/121] refine comment --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 3 ++- libsql-ffi/bundled/src/sqlite3.c | 3 ++- libsql-sqlite3/src/vectordiskann.c | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 887d3322b9..1fd0252f0d 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -213390,7 +213390,8 @@ int diskAnnDelete( rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); if( rc == DISKANN_ROW_NOT_FOUND ){ - // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + // as we omit rows with NULL values during insert, it can be the case that there is nothing to delete in the index, while row exists in the base table + // so, we must simply silently stop delete process as there is nothing to delete from index rc = SQLITE_OK; goto out; }else if( rc != SQLITE_OK ){ diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 887d3322b9..1fd0252f0d 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ 
b/libsql-ffi/bundled/src/sqlite3.c @@ -213390,7 +213390,8 @@ int diskAnnDelete( rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); if( rc == DISKANN_ROW_NOT_FOUND ){ - // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + // as we omit rows with NULL values during insert, it can be the case that there is nothing to delete in the index, while row exists in the base table + // so, we must simply silently stop delete process as there is nothing to delete from index rc = SQLITE_OK; goto out; }else if( rc != SQLITE_OK ){ diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 95276f8887..a6c279b259 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -1634,7 +1634,8 @@ int diskAnnDelete( rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); if( rc == DISKANN_ROW_NOT_FOUND ){ - // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + // as we omit rows with NULL values during insert, it can be the case that there is nothing to delete in the index, while row exists in the base table + // so, we must simply silently stop delete process as there is nothing to delete from index rc = SQLITE_OK; goto out; }else if( rc != SQLITE_OK ){ From c135391f60d30c671d164119216ad8545d3768eb Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 12:16:52 +0400 Subject: [PATCH 069/121] fix unstable test --- libsql-sqlite3/test/libsql_vector_index.test | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index dd1f53d0df..c872a20ae0 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test 
@@ -131,21 +131,21 @@ do_execsql_test vector-empty { do_execsql_test vector-null { - CREATE TABLE t_null( v FLOAT32(3)); + CREATE TABLE t_null( v FLOAT32(2)); CREATE INDEX t_null_idx ON t_null( libsql_vector_idx(v) ); - INSERT INTO t_null VALUES(vector('[1,2,3]')); + INSERT INTO t_null VALUES(vector('[1,-1]')); INSERT INTO t_null VALUES(NULL); - INSERT INTO t_null VALUES(vector('[3,4,5]')); - SELECT * FROM vector_top_k('t_null_idx', '[1,2,3]', 2); - UPDATE t_null SET v = vector('[2,3,4]') WHERE rowid = 2; - SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + INSERT INTO t_null VALUES(vector('[-2,1]')); + SELECT * FROM vector_top_k('t_null_idx', '[1,1]', 2); + UPDATE t_null SET v = vector('[1,1]') WHERE rowid = 2; + SELECT rowid FROM vector_top_k('t_null_idx', vector('[1,1]'), 3); UPDATE t_null SET v = NULL WHERE rowid = 3; - SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + SELECT rowid FROM vector_top_k('t_null_idx', vector('[1,1]'), 3); UPDATE t_null SET v = NULL; - SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + SELECT rowid FROM vector_top_k('t_null_idx', vector('[1,1]'), 3); } { 1 3 - 2 3 1 + 2 1 3 2 1 } From 0f4531a10e7cd2ff5cdb069ff6d07fa3721754c7 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 13:45:38 +0400 Subject: [PATCH 070/121] add support for vector indices over f64 embeddings --- libsql-sqlite3/src/vectorIndex.c | 4 ++-- libsql-sqlite3/src/vectordiskann.c | 10 +++++----- libsql-sqlite3/test/libsql_vector_index.test | 9 ++++++++- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 11a9585ed2..35c7b6908b 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -967,8 +967,8 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - if( type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): only f32 vectors are supported"); + if( 
type != VECTOR_TYPE_FLOAT32 && type != VECTOR_TYPE_FLOAT64 ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): unsupported vector type: only FLOAT32/FLOAT64 are available for indexing"); rc = SQLITE_ERROR; goto out; } diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 1f9973d38c..c28cf075e9 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -1428,12 +1428,12 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): k must be a non-negative integer"); return SQLITE_ERROR; } - if( pIndex->nVectorDims != pVector->dims ){ + if( pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(search): dimensions are different: %d != %d", pVector->dims, pIndex->nVectorDims); return SQLITE_ERROR; } - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): only f32 vectors are supported"); + if( pVector->type != pIndex->nNodeVectorType ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): vector type differs from column type: %d != %d", pVector->type, pIndex->nNodeVectorType); return SQLITE_ERROR; } @@ -1498,8 +1498,8 @@ int diskAnnInsert( *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); return SQLITE_ERROR; } - if( pVectorInRow->pVector->type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(insert): only f32 vectors are supported"); + if( pVectorInRow->pVector->type != pIndex->nNodeVectorType ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): vector type differs from column type: %d != %d", pVectorInRow->pVector->type, pIndex->nNodeVectorType); return SQLITE_ERROR; } diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 281943e7c2..98d11208fa 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ 
b/libsql-sqlite3/test/libsql_vector_index.test @@ -290,6 +290,13 @@ do_execsql_test vector-all-params { SELECT * FROM vector_top_k('t_all_params_idx', vector('[1,2]'), 2); } {1 2} +do_execsql_test vector-f64-index { + CREATE TABLE t_f64 ( emb FLOAT64(2) ); + CREATE INDEX t_f64_idx ON t_f64(libsql_vector_idx(emb)); + INSERT INTO t_f64 VALUES (vector64('[1,2]')), (vector64('[3,4]')); + SELECT * FROM vector_top_k('t_f64_idx', vector64('[1,2]'), 2); +} {1 2} + do_execsql_test vector-partial { CREATE TABLE t_partial( name TEXT, type INT, v FLOAT32(3)); INSERT INTO t_partial VALUES ( 'a', 0, vector('[1,2,3]') ); @@ -368,7 +375,7 @@ do_test vector-errors { {vector index: unsupported for tables without ROWID and composite primary key} {vector index(insert): dimensions are different: 1 != 4} {vector index(insert): dimensions are different: 5 != 4} - {vector index(insert): only f32 vectors are supported} + {vector index(insert): vector type differs from column type: 2 != 1} {vector index(search): dimensions are different: 2 != 4} {vector index(insert): dimensions are different: 1 != 3} }] From 626ff8c62c48786afcc5535b94b1cbbf7c73999d Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 15:12:18 +0400 Subject: [PATCH 071/121] prepare conversion function for mooore types --- libsql-sqlite3/src/vector.c | 115 ++++++++++++++++++++++++++++++++---- 1 file changed, 105 insertions(+), 10 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index c622d977e1..5903e62294 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -481,26 +481,121 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo } } -void vectorConvert(const Vector *pFrom, Vector *pTo){ +static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){ int i; - u8 *bitData; - float *floatData; + float *src; + + u8 *dst1Bit; + double *dstF64; assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + 
assert( pFrom->type == VECTOR_TYPE_FLOAT32 ); - if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ - floatData = pFrom->data; - bitData = pTo->data; + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF64[i] = src[i]; + } + }else if( pTo->type == VECTOR_TYPE_1BIT ){ + dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ - bitData[i / 8] = 0; + dst1Bit[i / 8] = 0; } for(i = 0; i < pFrom->dims; i++){ - if( floatData[i] > 0 ){ - bitData[i / 8] |= (1 << (i & 7)); + if( src[i] > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); } } }else{ - assert(0); + assert( 0 ); + } +} + +static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){ + int i; + double *src; + + u8 *dst1Bit; + float *dstF32; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_FLOAT64 ); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF32[i] = src[i]; + } + }else if( pTo->type == VECTOR_TYPE_1BIT ){ + dst1Bit = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + dst1Bit[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( src[i] > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert( 0 ); + } +} + +static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ + int i; + u8 *src; + + float *dstF32; + double *dstF64; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_1BIT ); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){ + dstF32[i] = +1; + }else{ + dstF32[i] = -1; + } + } + }else if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){ + dstF64[i] = +1; + }else{ + dstF64[i] = 
-1; + } + } + }else{ + assert( 0 ); + } +} + +void vectorConvert(const Vector *pFrom, Vector *pTo){ + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == pTo->type ){ + memcpy(pTo->data, pFrom->data, vectorDataSize(pFrom->type, pFrom->dims)); + return; + } + + if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + vectorConvertFromF32(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ + vectorConvertFromF64(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_1BIT ){ + vectorConvertFrom1Bit(pFrom, pTo); + }else{ + assert( 0 ); } } From 4a57f02b78764b056cb5cc95a07bd28ed4cc0ad9 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 15:20:52 +0400 Subject: [PATCH 072/121] add simple conversion test --- libsql-sqlite3/src/vector.c | 36 ++++++++++++++++++-------- libsql-sqlite3/test/libsql_vector.test | 17 ++++++++++++ 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 5903e62294..b1289e2e85 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -610,31 +610,45 @@ static void vectorFuncHintedType( sqlite3_context *context, int argc, sqlite3_value **argv, - int typeHint + int targetType ){ char *pzErrMsg = NULL; - Vector *pVector; + Vector *pVector = NULL, *pTarget = NULL; int type, dims; if( argc < 1 ){ - return; + goto out; } - if( detectVectorParameters(argv[0], typeHint, &type, &dims, &pzErrMsg) != 0 ){ + if( detectVectorParameters(argv[0], targetType, &type, &dims, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - return; + goto out; } pVector = vectorContextAlloc(context, type, dims); - if( pVector==NULL ){ - return; + if( pVector == NULL ){ + goto out; } if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - goto out_free_vec; + goto out; + } + if( type == targetType ){ + vectorSerializeWithType(context, pVector); + }else{ + 
pTarget = vectorContextAlloc(context, targetType, dims); + if( pTarget == NULL ){ + goto out; + } + vectorConvert(pVector, pTarget); + vectorSerializeWithType(context, pTarget); + } +out: + if( pVector != NULL ){ + vectorFree(pVector); + } + if( pTarget != NULL ){ + vectorFree(pTarget); } - vectorSerializeWithType(context, pVector); -out_free_vec: - vectorFree(pVector); } static void vector32Func( diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index cf91a7fa18..e541e9e977 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -67,6 +67,23 @@ do_execsql_test vector-1-func-valid { {0.200000002980232} } +do_execsql_test vector-1-conversion { + SELECT hex(vector32('[]')); + SELECT hex(vector64(vector32('[]'))); + + SELECT hex(vector32(vector32('[0.000001,1e-100,1e100,1e10,1e-10,0,1.5]'))); + SELECT hex(vector32(vector64('[0.000001,1e-100,1e100,1e10,1e-10,0,1.5]'))); + SELECT hex(vector64(vector32('[0.000001,1e-100,1e100,1e10,1e-10,0,1.5]'))); + SELECT hex(vector64(vector64('[0.000001,1e-100,1e100,1e10,1e-10,0,1.5]'))); +} { + {} + 02 + BD378635000000000000807FF9021550FFE6DB2E000000000000C03F + BD378635000000000000807FF9021550FFE6DB2E000000000000C03F + 000000A0F7C6B03E0000000000000000000000000000F07F000000205FA00242000000E0DF7CDB3D0000000000000000000000000000F83F02 + 8DEDB5A0F7C6B03E30058EE42EFF2B2B7DC39425AD49B254000000205FA00242BBBDD7D9DF7CDB3D0000000000000000000000000000F83F02 +} + proc error_messages {sql} { set ret "" set stmt [sqlite3_prepare db $sql -1 dummy] From 4b562ed9055b4b108d81f102fa65489807831519 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 15:22:47 +0400 Subject: [PATCH 073/121] small refactoring --- libsql-sqlite3/src/vectorIndex.c | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 35c7b6908b..4ee9dbe313 100644 --- 
a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -24,6 +24,7 @@ ** ** libSQL vector search. */ +#include "vectorInt.h" #ifndef SQLITE_OMIT_VECTOR #include "sqlite3.h" #include "vdbeInt.h" @@ -373,14 +374,14 @@ void vectorOutRowsFree(sqlite3 *db, VectorOutRows *pRows) { */ struct VectorColumnType { const char *zName; - int nBits; + int type; }; static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { - { "FLOAT32", 32 }, - { "FLOAT64", 64 }, - { "F32_BLOB", 32 }, - { "F64_BLOB", 64 } + { "FLOAT32", VECTOR_TYPE_FLOAT32 }, + { "FLOAT64", VECTOR_TYPE_FLOAT64 }, + { "F32_BLOB", VECTOR_TYPE_FLOAT32 }, + { "F64_BLOB", VECTOR_TYPE_FLOAT64 } }; /* @@ -569,14 +570,7 @@ int vectorIdxParseColumnType(const char *zType, int *pType, int *pDims, const ch } *pDims = dimensions; - if( VECTOR_COLUMN_TYPES[i].nBits == 32 ) { - *pType = VECTOR_TYPE_FLOAT32; - } else if( VECTOR_COLUMN_TYPES[i].nBits == 64 ) { - *pType = VECTOR_TYPE_FLOAT64; - } else { - *pErrMsg = "unsupported vector type"; - return -1; - } + *pType = VECTOR_COLUMN_TYPES[i].type; return 0; } *pErrMsg = "unexpected vector column type"; From 181464f14eff80b12c95864d0259749f72ee884e Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 16:10:52 +0400 Subject: [PATCH 074/121] add support for 1bit vector functions --- libsql-sqlite3/src/vector.c | 186 +++++++++++++++++++++++++------ libsql-sqlite3/src/vector1bit.c | 14 +++ libsql-sqlite3/src/vectorIndex.c | 10 +- libsql-sqlite3/src/vectorInt.h | 7 +- 4 files changed, 176 insertions(+), 41 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index b1289e2e85..6d86619409 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -42,7 +42,6 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: - assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -253,33 +252,84 @@ static int 
vectorParseSqliteText( return -1; } +static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pType, int *pDims, size_t *pDataSize, char **pzErrMsg){ + int nLeftoverBits; + + if( nBlobSize % 2 == 0 ){ + *pType = VECTOR_TYPE_FLOAT32; + *pDims = nBlobSize / sizeof(float); + *pDataSize = nBlobSize; + return SQLITE_OK; + } + *pType = pBlob[nBlobSize - 1]; + nBlobSize--; + + if( *pType == VECTOR_TYPE_FLOAT32 ){ + if( nBlobSize % 4 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + *pDims = nBlobSize / sizeof(float); + *pDataSize = nBlobSize; + }else if( *pType == VECTOR_TYPE_FLOAT64 ){ + if( nBlobSize % 8 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + *pDims = nBlobSize / sizeof(double); + *pDataSize = nBlobSize; + }else if( *pType == VECTOR_TYPE_1BIT ){ + if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + nLeftoverBits = pBlob[nBlobSize - 1]; + *pDims = nBlobSize * 8 - nLeftoverBits; + *pDataSize = (*pDims + 7) / 8; + }else{ + *pzErrMsg = sqlite3_mprintf("invalid vector: unexpected type: %d", *pType); + return SQLITE_ERROR; + } + return SQLITE_OK; +} + int vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ const unsigned char *pBlob; - size_t nBlobSize; + size_t nBlobSize, nDataSize; + int type, dims; assert( sqlite3_value_type(arg) == SQLITE_BLOB ); pBlob = sqlite3_value_blob(arg); nBlobSize = sqlite3_value_bytes(arg); - if( nBlobSize % 2 == 1 ){ - nBlobSize--; + if( vectorParseMeta(pBlob, nBlobSize, &type, &dims, &nDataSize, pzErrMsg) != SQLITE_OK ){ + return 
SQLITE_ERROR; } - if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize); + if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ + *pzErrMsg = sqlite3_mprintf( + "invalid vector: unexpected data size bytes: type=%d, dims=%d, %ull != %ull", + pVector->type, + pVector->dims, + nDataSize, + vectorDataSize(pVector->type, pVector->dims) + ); return SQLITE_ERROR; } switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); + vectorF32DeserializeFromBlob(pVector, pBlob, nDataSize); return 0; case VECTOR_TYPE_FLOAT64: - vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + vectorF64DeserializeFromBlob(pVector, pBlob, nDataSize); + return 0; + case VECTOR_TYPE_1BIT: + vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; default: assert(0); @@ -298,15 +348,22 @@ int detectBlobVectorParameters(sqlite3_value *arg, int *pType, int *pDims, char if( nBlobSize % 2 != 0 ){ // we have trailing byte with explicit type definition *pType = pBlob[nBlobSize - 1]; + nBlobSize--; } else { // else, fallback to FLOAT32 *pType = VECTOR_TYPE_FLOAT32; } if( *pType == VECTOR_TYPE_FLOAT32 ){ *pDims = nBlobSize / sizeof(float); - } else if( *pType == VECTOR_TYPE_FLOAT64 ){ + }else if( *pType == VECTOR_TYPE_FLOAT64 ){ *pDims = nBlobSize / sizeof(double); - } else{ + }else if( *pType == VECTOR_TYPE_1BIT ){ + if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector: malformed 1bit float: blob size must has even size (without last byte): size=%d", nBlobSize); + return -1; + } + *pDims = nBlobSize * 8 - pBlob[nBlobSize - 1]; + }else{ *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: got %d, expected %d or %d", *pType, VECTOR_TYPE_FLOAT32, VECTOR_TYPE_FLOAT64); return -1; } @@ -411,21 +468,55 @@ void vectorMarshalToText( } } -void 
vectorSerializeWithType( +static int vectorMetaSize(VectorType type, VectorDims dims){ + int nMetaSize = 0; + int nDataSize; + if( type == VECTOR_TYPE_FLOAT32 ){ + return 0; + }else if( type == VECTOR_TYPE_FLOAT64 ){ + return 1; + }else if( type == VECTOR_TYPE_1BIT ){ + nDataSize = vectorDataSize(type, dims); + nMetaSize++; // one byte which specify amount of leftover bits + if( nDataSize % 2 == 0 ){ + nMetaSize++; // pad "leftover-bits" byte to the even length + } + nMetaSize++; // one byte for vector type + return nMetaSize; + }else{ + assert( 0 ); + } +} + +static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigned char *pBlob, size_t nBlobSize){ + if( pVector->type == VECTOR_TYPE_FLOAT32 ){ + // no meta for f32 type as this is "default" vector type + }else if( pVector->type == VECTOR_TYPE_FLOAT64 ){ + assert( nDataSize % 2 == 0 ); + assert( nBlobSize == nDataSize + 1 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; + }else if( pVector->type == VECTOR_TYPE_1BIT ){ + assert( nBlobSize % 2 == 1 ); + assert( nBlobSize >= 3 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_1BIT; + pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; + }else{ + assert( 0 ); + } +} + +void vectorSerializeWithMeta( sqlite3_context *context, const Vector *pVector ){ unsigned char *pBlob; - size_t nBlobSize, nDataSize; + size_t nBlobSize, nDataSize, nMetaSize; assert( pVector->dims <= MAX_VECTOR_SZ ); nDataSize = vectorDataSize(pVector->type, pVector->dims); - nBlobSize = nDataSize; - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - nBlobSize += (nBlobSize % 2 == 0 ? 
1 : 2); - } - + nMetaSize = vectorMetaSize(pVector->type, pVector->dims); + nBlobSize = nDataSize + nMetaSize; if( nBlobSize == 0 ){ sqlite3_result_zeroblob(context, 0); return; @@ -437,10 +528,6 @@ void vectorSerializeWithType( return; } - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - pBlob[nBlobSize - 1] = pVector->type; - } - switch (pVector->type) { case VECTOR_TYPE_FLOAT32: vectorF32SerializeToBlob(pVector, pBlob, nDataSize); @@ -448,9 +535,13 @@ void vectorSerializeWithType( case VECTOR_TYPE_FLOAT64: vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; + case VECTOR_TYPE_1BIT: + vector1BitSerializeToBlob(pVector, pBlob, nDataSize); + break; default: assert(0); } + vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize); sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } @@ -614,11 +705,15 @@ static void vectorFuncHintedType( ){ char *pzErrMsg = NULL; Vector *pVector = NULL, *pTarget = NULL; - int type, dims; + int type, dims, typeHint = VECTOR_TYPE_FLOAT32; if( argc < 1 ){ goto out; } - if( detectVectorParameters(argv[0], targetType, &type, &dims, &pzErrMsg) != 0 ){ + // simplification in order to support only parsing from text to f32 and f64 vectors + if( targetType == VECTOR_TYPE_FLOAT64 ){ + typeHint = targetType; + } + if( detectVectorParameters(argv[0], typeHint, &type, &dims, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out; @@ -633,14 +728,14 @@ static void vectorFuncHintedType( goto out; } if( type == targetType ){ - vectorSerializeWithType(context, pVector); + vectorSerializeWithMeta(context, pVector); }else{ pTarget = vectorContextAlloc(context, targetType, dims); if( pTarget == NULL ){ goto out; } vectorConvert(pVector, pTarget); - vectorSerializeWithType(context, pTarget); + vectorSerializeWithMeta(context, pTarget); } out: if( pVector != NULL ){ @@ -666,6 +761,14 @@ static void vector64Func( vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT64); } +static void 
vector1BitFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_1BIT); +} + /* ** Implementation of vector_extract(X) function. */ @@ -675,30 +778,44 @@ static void vectorExtractFunc( sqlite3_value **argv ){ char *pzErrMsg = NULL; - Vector *pVector; + Vector *pVector = NULL, *pTarget = NULL; unsigned i; int type, dims; if( argc < 1 ){ - return; + goto out; } if( detectVectorParameters(argv[0], 0, &type, &dims, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - return; + goto out; } pVector = vectorContextAlloc(context, type, dims); - if( pVector==NULL ){ - return; + if( pVector == NULL ){ + goto out; } if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - goto out_free; + goto out; + } + if( pVector->type == VECTOR_TYPE_FLOAT32 || pVector->type == VECTOR_TYPE_FLOAT64 ){ + vectorMarshalToText(context, pVector); + }else{ + pTarget = vectorContextAlloc(context, VECTOR_TYPE_FLOAT32, dims); + if( pTarget == NULL ){ + goto out; + } + vectorConvert(pVector, pTarget); + vectorMarshalToText(context, pTarget); + } +out: + if( pVector != NULL ){ + vectorFree(pVector); + } + if( pTarget != NULL ){ + vectorFree(pTarget); } - vectorMarshalToText(context, pVector); -out_free: - vectorFree(pVector); } /* @@ -782,6 +899,7 @@ void sqlite3RegisterVectorFunctions(void){ FUNCTION(vector, 1, 0, 0, vector32Func), FUNCTION(vector32, 1, 0, 0, vector32Func), FUNCTION(vector64, 1, 0, 0, vector64Func), + FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc), FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc), FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc), diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index f4fd5f9100..b80c166522 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -124,4 +124,18 @@ int 
vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ return diff; } +void vector1BitDeserializeFromBlob( + Vector *pVector, + const unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + + assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + memcpy(elems, pBlob, (pVector->dims + 7) / 8); +} + #endif /* !defined(SQLITE_OMIT_VECTOR) */ diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 4ee9dbe313..b8eb17262a 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -378,10 +378,12 @@ struct VectorColumnType { }; static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { - { "FLOAT32", VECTOR_TYPE_FLOAT32 }, - { "FLOAT64", VECTOR_TYPE_FLOAT64 }, - { "F32_BLOB", VECTOR_TYPE_FLOAT32 }, - { "F64_BLOB", VECTOR_TYPE_FLOAT64 } + { "FLOAT32", VECTOR_TYPE_FLOAT32 }, + { "F32_BLOB", VECTOR_TYPE_FLOAT32 }, + { "FLOAT64", VECTOR_TYPE_FLOAT64 }, + { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, + { "FLOAT1BIT", VECTOR_TYPE_1BIT }, + { "F1BIT_BLOB", VECTOR_TYPE_1BIT }, }; /* diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index efe8f3cf38..e703585224 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -92,15 +92,16 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. 
This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerializeWithType(sqlite3_context *, const Vector *); +void vectorSerializeWithMeta(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); -void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t); +void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); From 307139fda9bd0140ea62709f3a3e97d6485d98cd Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 16:15:24 +0400 Subject: [PATCH 075/121] add conversion tests --- libsql-sqlite3/test/libsql_vector.test | 33 +++++++++++++++++++------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index e541e9e977..c6480be9a3 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -71,17 +71,32 @@ do_execsql_test vector-1-conversion { SELECT hex(vector32('[]')); SELECT hex(vector64(vector32('[]'))); - SELECT hex(vector32(vector32('[0.000001,1e-100,1e100,1e10,1e-10,0,1.5]'))); - SELECT hex(vector32(vector64('[0.000001,1e-100,1e100,1e10,1e-10,0,1.5]'))); - SELECT hex(vector64(vector32('[0.000001,1e-100,1e100,1e10,1e-10,0,1.5]'))); - SELECT hex(vector64(vector64('[0.000001,1e-100,1e100,1e10,1e-10,0,1.5]'))); + SELECT vector_extract(vector32(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), 
hex(vector32(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + SELECT vector_extract(vector32(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector32(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + SELECT vector_extract(vector32(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector32(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + + SELECT vector_extract(vector64(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector64(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + SELECT vector_extract(vector64(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector64(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + SELECT vector_extract(vector64(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector64(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + + SELECT vector_extract(vector1bit(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + SELECT vector_extract(vector1bit(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + SELECT vector_extract(vector1bit(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); } { {} - 02 - BD378635000000000000807FF9021550FFE6DB2E000000000000C03F - BD378635000000000000807FF9021550FFE6DB2E000000000000C03F - 000000A0F7C6B03E0000000000000000000000000000F07F000000205FA00242000000E0DF7CDB3D0000000000000000000000000000F83F02 - 8DEDB5A0F7C6B03E30058EE42EFF2B2B7DC39425AD49B254000000205FA00242BBBDD7D9DF7CDB3D0000000000000000000000000000F83F02 + 02 + + {[-1,-1,1,-1,1,-1,1]} 000080BF000080BF0000803F000080BF0000803F000080BF0000803F + {[-1e-06,0,Inf,-1e+10,1e-10,0,1.5]} BD3786B5000000000000807FF90215D0FFE6DB2E000000000000C03F + {[-1e-06,0,Inf,-1e+10,1e-10,0,1.5]} 
BD3786B5000000000000807FF90215D0FFE6DB2E000000000000C03F + + {[-1,-1,1,-1,1,-1,1]} 000000000000F0BF000000000000F0BF000000000000F03F000000000000F0BF000000000000F03F000000000000F0BF000000000000F03F02 + {[-1e-06,0,Inf,-1e+10,1e-10,0,1.5]} 000000A0F7C6B0BE0000000000000000000000000000F07F000000205FA002C2000000E0DF7CDB3D0000000000000000000000000000F83F02 + {[-1e-06,1e-100,1e+100,-1e+10,1e-10,0,1.5]} 8DEDB5A0F7C6B0BE30058EE42EFF2B2B7DC39425AD49B254000000205FA002C2BBBDD7D9DF7CDB3D0000000000000000000000000000F83F02 + + {[-1,-1,1,-1,1,-1,1]} 540903 + {[-1,-1,1,-1,1,-1,1]} 540903 + {[-1,1,1,-1,1,-1,1]} 560903 } proc error_messages {sql} { From 1216f17e0ea911ec62d3797870a4d734f09c5f65 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 17:19:34 +0400 Subject: [PATCH 076/121] remove unused function --- libsql-sqlite3/src/vector.c | 41 +++++------------------------- libsql-sqlite3/src/vectorIndex.c | 8 ++---- libsql-sqlite3/src/vectorInt.h | 2 -- libsql-sqlite3/src/vectorfloat32.c | 5 ---- libsql-sqlite3/src/vectorfloat64.c | 5 ---- 5 files changed, 9 insertions(+), 52 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 6d86619409..b5c43901e8 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -339,39 +339,21 @@ int vectorParseSqliteBlobWithType( int detectBlobVectorParameters(sqlite3_value *arg, int *pType, int *pDims, char **pzErrMsg) { const u8 *pBlob; - int nBlobSize; + size_t nBlobSize, nDataSize; assert( sqlite3_value_type(arg) == SQLITE_BLOB ); pBlob = sqlite3_value_blob(arg); nBlobSize = sqlite3_value_bytes(arg); - if( nBlobSize % 2 != 0 ){ - // we have trailing byte with explicit type definition - *pType = pBlob[nBlobSize - 1]; - nBlobSize--; - } else { - // else, fallback to FLOAT32 - *pType = VECTOR_TYPE_FLOAT32; - } - if( *pType == VECTOR_TYPE_FLOAT32 ){ - *pDims = nBlobSize / sizeof(float); - }else if( *pType == VECTOR_TYPE_FLOAT64 ){ - *pDims = nBlobSize / sizeof(double); - }else if( 
*pType == VECTOR_TYPE_1BIT ){ - if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: malformed 1bit float: blob size must has even size (without last byte): size=%d", nBlobSize); - return -1; - } - *pDims = nBlobSize * 8 - pBlob[nBlobSize - 1]; - }else{ - *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: got %d, expected %d or %d", *pType, VECTOR_TYPE_FLOAT32, VECTOR_TYPE_FLOAT64); - return -1; + + if( vectorParseMeta(pBlob, nBlobSize, pType, pDims, &nDataSize, pzErrMsg) != SQLITE_OK ){ + return SQLITE_ERROR; } if( *pDims > MAX_VECTOR_SZ ){ *pzErrMsg = sqlite3_mprintf("vector: max size exceeded: %d > %d", *pDims, MAX_VECTOR_SZ); - return -1; + return SQLITE_ERROR; } - return 0; + return SQLITE_OK; } int detectTextVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pDims, char **pzErrMsg) { @@ -560,16 +542,7 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t } void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - vectorF32InitFromBlob(pVector, pBlob, nBlobSize); - break; - case VECTOR_TYPE_FLOAT64: - vectorF64InitFromBlob(pVector, pBlob, nBlobSize); - break; - default: - assert(0); - } + pVector->data = (void*)pBlob; } static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){ diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index b8eb17262a..6413bf0822 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -883,7 +883,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: %s: %s", pzErrMsg, zEmbeddingColumnTypeName); return CREATE_FAIL; } - // schema is locked while db is initializing and we need to just proceed here if( db->init.busy == 1 ){ return CREATE_OK; @@ -963,11 +962,8 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - if( type 
!= VECTOR_TYPE_FLOAT32 && type != VECTOR_TYPE_FLOAT64 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): unsupported vector type: only FLOAT32/FLOAT64 are available for indexing"); - rc = SQLITE_ERROR; - goto out; - } + assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_1BIT ); + pVector = vectorAlloc(type, dims); if( pVector == NULL ){ rc = SQLITE_NOMEM_BKPT; diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index e703585224..350a9ae9bd 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -105,8 +105,6 @@ void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); -void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); -void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); void vectorConvert(const Vector *, Vector *); diff --git a/libsql-sqlite3/src/vectorfloat32.c b/libsql-sqlite3/src/vectorfloat32.c index d53d10d593..9749a84835 100644 --- a/libsql-sqlite3/src/vectorfloat32.c +++ b/libsql-sqlite3/src/vectorfloat32.c @@ -168,11 +168,6 @@ float vectorF32DistanceL2(const Vector *v1, const Vector *v2){ return sqrt(sum); } -void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - pVector->dims = nBlobSize / sizeof(float); - pVector->data = (void*)pBlob; -} - void vectorF32DeserializeFromBlob( Vector *pVector, const unsigned char *pBlob, diff --git a/libsql-sqlite3/src/vectorfloat64.c b/libsql-sqlite3/src/vectorfloat64.c index 885306c8c6..9f854793ab 100644 --- a/libsql-sqlite3/src/vectorfloat64.c +++ b/libsql-sqlite3/src/vectorfloat64.c @@ -175,11 +175,6 @@ double vectorF64DistanceL2(const Vector *v1, const Vector *v2){ return sqrt(sum); } -void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - pVector->dims = nBlobSize / 
sizeof(double); - pVector->data = (void*)pBlob; -} - void vectorF64DeserializeFromBlob( Vector *pVector, const unsigned char *pBlob, From 1a704d8ba48fab105028d834decc7bf449fae061 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 17:19:40 +0400 Subject: [PATCH 077/121] fix test --- libsql-sqlite3/test/libsql_vector.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index c6480be9a3..5118cae6a0 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -131,7 +131,7 @@ do_test vector-1-func-errors { {vector: invalid float at position 0: '[1'} {vector: invalid float at position 2: '1.1.1'} {vector: must end with ']'} - {vector: unexpected binary type: got 0, expected 1 or 2} + {invalid vector: unexpected type: 0} {vector_distance_cos: vectors must have the same length: 3 != 2} {vector_distance_cos: vectors must have the same type: 1 != 2} }] From e0035d6e17645681d71bdd8dc394cfe8fd69ec02 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 17:25:29 +0400 Subject: [PATCH 078/121] add more tests --- libsql-sqlite3/test/libsql_vector.test | 6 ++++++ libsql-sqlite3/test/libsql_vector_index.test | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index 5118cae6a0..7afb0f8bb0 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -50,6 +50,9 @@ do_execsql_test vector-1-func-valid { SELECT vector_distance_cos('[1,1]', '[-1,-1]'); SELECT vector_distance_cos('[1,1]', '[-1,1]'); SELECT vector_distance_cos('[1,2]', '[2,1]'); + SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[-5,4]')); + SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[20,4]')); + SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[20,-2]')); } { {[]} {[]} @@ -65,6 
+68,9 @@ do_execsql_test vector-1-func-valid { {2.0} {1.0} {0.200000002980232} + {2.0} + {1.0} + {0.0} } do_execsql_test vector-1-conversion { diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 98d11208fa..0756566914 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -327,6 +327,15 @@ do_execsql_test vector-partial { 2 3 5 6 8 9 } +do_execsql_test vector-1bit-table { + CREATE TABLE t_1bit_table( v FLOAT1BIT(4) ); + INSERT INTO t_1bit_table VALUES ( vector1bit('[1,-1,1,-1]') ); + CREATE INDEX t_1bit_table_idx ON t_1bit_table( libsql_vector_idx(v) ); + INSERT INTO t_1bit_table VALUES ( vector1bit('[-1,1,1,-1]') ); + INSERT INTO t_1bit_table VALUES ( vector1bit('[1,-1,-1,1]') ); + SELECT * FROM vector_top_k('t_1bit_table_idx', vector1bit('[10,-10,-20,20]'), 4); +} {3 1 2} + proc error_messages {sql} { set ret "" catch { From 51fc1daa1e16157d42dae26808942cf6d841b38b Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 17:27:15 +0400 Subject: [PATCH 079/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 423 +++++++++++++----- libsql-ffi/bundled/src/sqlite3.c | 423 +++++++++++++----- 2 files changed, 618 insertions(+), 228 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 8ceabfc713..3568559303 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -85311,20 +85311,19 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. 
This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerializeWithType(sqlite3_context *, const Vector *); +void vectorSerializeWithMeta(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); -void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t); +void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); -void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); -void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); void vectorConvert(const Vector *, Vector *); @@ -210981,7 +210980,6 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: - assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -211192,33 +211190,84 @@ static int vectorParseSqliteText( return -1; } +static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pType, int *pDims, size_t *pDataSize, char **pzErrMsg){ + int nLeftoverBits; + + if( nBlobSize % 2 == 0 ){ + *pType = VECTOR_TYPE_FLOAT32; + *pDims = nBlobSize / sizeof(float); + *pDataSize = nBlobSize; + return SQLITE_OK; + } + *pType = pBlob[nBlobSize - 1]; + nBlobSize--; + + if( *pType == VECTOR_TYPE_FLOAT32 ){ + if( nBlobSize % 4 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", 
nBlobSize); + return SQLITE_ERROR; + } + *pDims = nBlobSize / sizeof(float); + *pDataSize = nBlobSize; + }else if( *pType == VECTOR_TYPE_FLOAT64 ){ + if( nBlobSize % 8 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + *pDims = nBlobSize / sizeof(double); + *pDataSize = nBlobSize; + }else if( *pType == VECTOR_TYPE_1BIT ){ + if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + nLeftoverBits = pBlob[nBlobSize - 1]; + *pDims = nBlobSize * 8 - nLeftoverBits; + *pDataSize = (*pDims + 7) / 8; + }else{ + *pzErrMsg = sqlite3_mprintf("invalid vector: unexpected type: %d", *pType); + return SQLITE_ERROR; + } + return SQLITE_OK; +} + int vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ const unsigned char *pBlob; - size_t nBlobSize; + size_t nBlobSize, nDataSize; + int type, dims; assert( sqlite3_value_type(arg) == SQLITE_BLOB ); pBlob = sqlite3_value_blob(arg); nBlobSize = sqlite3_value_bytes(arg); - if( nBlobSize % 2 == 1 ){ - nBlobSize--; + if( vectorParseMeta(pBlob, nBlobSize, &type, &dims, &nDataSize, pzErrMsg) != SQLITE_OK ){ + return SQLITE_ERROR; } - if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize); + if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ + *pzErrMsg = sqlite3_mprintf( + "invalid vector: unexpected data size bytes: type=%d, dims=%d, %ull != %ull", + pVector->type, + pVector->dims, + nDataSize, + vectorDataSize(pVector->type, pVector->dims) + ); return SQLITE_ERROR; } switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - 
vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); + vectorF32DeserializeFromBlob(pVector, pBlob, nDataSize); return 0; case VECTOR_TYPE_FLOAT64: - vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + vectorF64DeserializeFromBlob(pVector, pBlob, nDataSize); + return 0; + case VECTOR_TYPE_1BIT: + vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; default: assert(0); @@ -211228,32 +211277,21 @@ int vectorParseSqliteBlobWithType( int detectBlobVectorParameters(sqlite3_value *arg, int *pType, int *pDims, char **pzErrMsg) { const u8 *pBlob; - int nBlobSize; + size_t nBlobSize, nDataSize; assert( sqlite3_value_type(arg) == SQLITE_BLOB ); pBlob = sqlite3_value_blob(arg); nBlobSize = sqlite3_value_bytes(arg); - if( nBlobSize % 2 != 0 ){ - // we have trailing byte with explicit type definition - *pType = pBlob[nBlobSize - 1]; - } else { - // else, fallback to FLOAT32 - *pType = VECTOR_TYPE_FLOAT32; - } - if( *pType == VECTOR_TYPE_FLOAT32 ){ - *pDims = nBlobSize / sizeof(float); - } else if( *pType == VECTOR_TYPE_FLOAT64 ){ - *pDims = nBlobSize / sizeof(double); - } else{ - *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: got %d, expected %d or %d", *pType, VECTOR_TYPE_FLOAT32, VECTOR_TYPE_FLOAT64); - return -1; + + if( vectorParseMeta(pBlob, nBlobSize, pType, pDims, &nDataSize, pzErrMsg) != SQLITE_OK ){ + return SQLITE_ERROR; } if( *pDims > MAX_VECTOR_SZ ){ *pzErrMsg = sqlite3_mprintf("vector: max size exceeded: %d > %d", *pDims, MAX_VECTOR_SZ); - return -1; + return SQLITE_ERROR; } - return 0; + return SQLITE_OK; } int detectTextVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pDims, char **pzErrMsg) { @@ -211350,21 +211388,55 @@ void vectorMarshalToText( } } -void vectorSerializeWithType( +static int vectorMetaSize(VectorType type, VectorDims dims){ + int nMetaSize = 0; + int nDataSize; + if( type == VECTOR_TYPE_FLOAT32 ){ + return 0; + }else if( type == VECTOR_TYPE_FLOAT64 ){ + return 1; + }else if( type == 
VECTOR_TYPE_1BIT ){ + nDataSize = vectorDataSize(type, dims); + nMetaSize++; // one byte which specify amount of leftover bits + if( nDataSize % 2 == 0 ){ + nMetaSize++; // pad "leftover-bits" byte to the even length + } + nMetaSize++; // one byte for vector type + return nMetaSize; + }else{ + assert( 0 ); + } +} + +static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigned char *pBlob, size_t nBlobSize){ + if( pVector->type == VECTOR_TYPE_FLOAT32 ){ + // no meta for f32 type as this is "default" vector type + }else if( pVector->type == VECTOR_TYPE_FLOAT64 ){ + assert( nDataSize % 2 == 0 ); + assert( nBlobSize == nDataSize + 1 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; + }else if( pVector->type == VECTOR_TYPE_1BIT ){ + assert( nBlobSize % 2 == 1 ); + assert( nBlobSize >= 3 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_1BIT; + pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; + }else{ + assert( 0 ); + } +} + +void vectorSerializeWithMeta( sqlite3_context *context, const Vector *pVector ){ unsigned char *pBlob; - size_t nBlobSize, nDataSize; + size_t nBlobSize, nDataSize, nMetaSize; assert( pVector->dims <= MAX_VECTOR_SZ ); nDataSize = vectorDataSize(pVector->type, pVector->dims); - nBlobSize = nDataSize; - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - nBlobSize += (nBlobSize % 2 == 0 ? 
1 : 2); - } - + nMetaSize = vectorMetaSize(pVector->type, pVector->dims); + nBlobSize = nDataSize + nMetaSize; if( nBlobSize == 0 ){ sqlite3_result_zeroblob(context, 0); return; @@ -211376,10 +211448,6 @@ void vectorSerializeWithType( return; } - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - pBlob[nBlobSize - 1] = pVector->type; - } - switch (pVector->type) { case VECTOR_TYPE_FLOAT32: vectorF32SerializeToBlob(pVector, pBlob, nDataSize); @@ -211387,9 +211455,13 @@ void vectorSerializeWithType( case VECTOR_TYPE_FLOAT64: vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; + case VECTOR_TYPE_1BIT: + vector1BitSerializeToBlob(pVector, pBlob, nDataSize); + break; default: assert(0); } + vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize); sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } @@ -211408,38 +211480,124 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t } void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - vectorF32InitFromBlob(pVector, pBlob, nBlobSize); - break; - case VECTOR_TYPE_FLOAT64: - vectorF64InitFromBlob(pVector, pBlob, nBlobSize); - break; - default: - assert(0); + pVector->data = (void*)pBlob; +} + +static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){ + int i; + float *src; + + u8 *dst1Bit; + double *dstF64; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_FLOAT32 ); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF64[i] = src[i]; + } + }else if( pTo->type == VECTOR_TYPE_1BIT ){ + dst1Bit = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + dst1Bit[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( src[i] > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert( 0 ); } } -void vectorConvert(const Vector *pFrom, 
Vector *pTo){ +static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){ int i; - u8 *bitData; - float *floatData; + double *src; + + u8 *dst1Bit; + float *dstF32; assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_FLOAT64 ); - if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ - floatData = pFrom->data; - bitData = pTo->data; + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF32[i] = src[i]; + } + }else if( pTo->type == VECTOR_TYPE_1BIT ){ + dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ - bitData[i / 8] = 0; + dst1Bit[i / 8] = 0; } for(i = 0; i < pFrom->dims; i++){ - if( floatData[i] > 0 ){ - bitData[i / 8] |= (1 << (i & 7)); + if( src[i] > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); } } }else{ - assert(0); + assert( 0 ); + } +} + +static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ + int i; + u8 *src; + + float *dstF32; + double *dstF64; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_1BIT ); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){ + dstF32[i] = +1; + }else{ + dstF32[i] = -1; + } + } + }else if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){ + dstF64[i] = +1; + }else{ + dstF64[i] = -1; + } + } + }else{ + assert( 0 ); + } +} + +void vectorConvert(const Vector *pFrom, Vector *pTo){ + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == pTo->type ){ + memcpy(pTo->data, pFrom->data, vectorDataSize(pFrom->type, pFrom->dims)); + return; + } + + if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + vectorConvertFromF32(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ + 
vectorConvertFromF64(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_1BIT ){ + vectorConvertFrom1Bit(pFrom, pTo); + }else{ + assert( 0 ); } } @@ -211454,31 +211612,49 @@ static void vectorFuncHintedType( sqlite3_context *context, int argc, sqlite3_value **argv, - int typeHint + int targetType ){ char *pzErrMsg = NULL; - Vector *pVector; - int type, dims; + Vector *pVector = NULL, *pTarget = NULL; + int type, dims, typeHint = VECTOR_TYPE_FLOAT32; if( argc < 1 ){ - return; + goto out; + } + // simplification in order to support only parsing from text to f32 and f64 vectors + if( targetType == VECTOR_TYPE_FLOAT64 ){ + typeHint = targetType; } if( detectVectorParameters(argv[0], typeHint, &type, &dims, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - return; + goto out; } pVector = vectorContextAlloc(context, type, dims); - if( pVector==NULL ){ - return; + if( pVector == NULL ){ + goto out; } if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - goto out_free_vec; + goto out; + } + if( type == targetType ){ + vectorSerializeWithMeta(context, pVector); + }else{ + pTarget = vectorContextAlloc(context, targetType, dims); + if( pTarget == NULL ){ + goto out; + } + vectorConvert(pVector, pTarget); + vectorSerializeWithMeta(context, pTarget); + } +out: + if( pVector != NULL ){ + vectorFree(pVector); + } + if( pTarget != NULL ){ + vectorFree(pTarget); } - vectorSerializeWithType(context, pVector); -out_free_vec: - vectorFree(pVector); } static void vector32Func( @@ -211496,6 +211672,14 @@ static void vector64Func( vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT64); } +static void vector1BitFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_1BIT); +} + /* ** Implementation of vector_extract(X) function. 
*/ @@ -211505,30 +211689,44 @@ static void vectorExtractFunc( sqlite3_value **argv ){ char *pzErrMsg = NULL; - Vector *pVector; + Vector *pVector = NULL, *pTarget = NULL; unsigned i; int type, dims; if( argc < 1 ){ - return; + goto out; } if( detectVectorParameters(argv[0], 0, &type, &dims, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - return; + goto out; } pVector = vectorContextAlloc(context, type, dims); - if( pVector==NULL ){ - return; + if( pVector == NULL ){ + goto out; } if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - goto out_free; + goto out; + } + if( pVector->type == VECTOR_TYPE_FLOAT32 || pVector->type == VECTOR_TYPE_FLOAT64 ){ + vectorMarshalToText(context, pVector); + }else{ + pTarget = vectorContextAlloc(context, VECTOR_TYPE_FLOAT32, dims); + if( pTarget == NULL ){ + goto out; + } + vectorConvert(pVector, pTarget); + vectorMarshalToText(context, pTarget); + } +out: + if( pVector != NULL ){ + vectorFree(pVector); + } + if( pTarget != NULL ){ + vectorFree(pTarget); } - vectorMarshalToText(context, pVector); -out_free: - vectorFree(pVector); } /* @@ -211612,6 +211810,7 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ FUNCTION(vector, 1, 0, 0, vector32Func), FUNCTION(vector32, 1, 0, 0, vector32Func), FUNCTION(vector64, 1, 0, 0, vector64Func), + FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc), FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc), FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc), @@ -211750,6 +211949,20 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ return diff; } +void vector1BitDeserializeFromBlob( + Vector *pVector, + const unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + + assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + 
memcpy(elems, pBlob, (pVector->dims + 7) / 8); +} + #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vector1bit.c ******************************************/ @@ -213184,12 +213397,12 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): k must be a non-negative integer"); return SQLITE_ERROR; } - if( pIndex->nVectorDims != pVector->dims ){ + if( pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(search): dimensions are different: %d != %d", pVector->dims, pIndex->nVectorDims); return SQLITE_ERROR; } - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): only f32 vectors are supported"); + if( pVector->type != pIndex->nNodeVectorType ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): vector type differs from column type: %d != %d", pVector->type, pIndex->nNodeVectorType); return SQLITE_ERROR; } @@ -213254,8 +213467,8 @@ int diskAnnInsert( *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); return SQLITE_ERROR; } - if( pVectorInRow->pVector->type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(insert): only f32 vectors are supported"); + if( pVectorInRow->pVector->type != pIndex->nNodeVectorType ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): vector type differs from column type: %d != %d", pVectorInRow->pVector->type, pIndex->nNodeVectorType); return SQLITE_ERROR; } @@ -213703,11 +213916,6 @@ float vectorF32DistanceL2(const Vector *v1, const Vector *v2){ return sqrt(sum); } -void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - pVector->dims = nBlobSize / sizeof(float); - pVector->data = (void*)pBlob; -} - void vectorF32DeserializeFromBlob( Vector *pVector, const unsigned char *pBlob, @@ -213907,11 +214115,6 @@ double vectorF64DistanceL2(const Vector *v1, const Vector *v2){ return sqrt(sum); } -void 
vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - pVector->dims = nBlobSize / sizeof(double); - pVector->data = (void*)pBlob; -} - void vectorF64DeserializeFromBlob( Vector *pVector, const unsigned char *pBlob, @@ -213960,6 +214163,7 @@ void vectorF64DeserializeFromBlob( ** ** libSQL vector search. */ +/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "sqlite3.h" */ /* #include "vdbeInt.h" */ @@ -214309,14 +214513,16 @@ void vectorOutRowsFree(sqlite3 *db, VectorOutRows *pRows) { */ struct VectorColumnType { const char *zName; - int nBits; + int type; }; static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { - { "FLOAT32", 32 }, - { "FLOAT64", 64 }, - { "F32_BLOB", 32 }, - { "F64_BLOB", 64 } + { "FLOAT32", VECTOR_TYPE_FLOAT32 }, + { "F32_BLOB", VECTOR_TYPE_FLOAT32 }, + { "FLOAT64", VECTOR_TYPE_FLOAT64 }, + { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, + { "FLOAT1BIT", VECTOR_TYPE_1BIT }, + { "F1BIT_BLOB", VECTOR_TYPE_1BIT }, }; /* @@ -214505,14 +214711,7 @@ int vectorIdxParseColumnType(const char *zType, int *pType, int *pDims, const ch } *pDims = dimensions; - if( VECTOR_COLUMN_TYPES[i].nBits == 32 ) { - *pType = VECTOR_TYPE_FLOAT32; - } else if( VECTOR_COLUMN_TYPES[i].nBits == 64 ) { - *pType = VECTOR_TYPE_FLOAT64; - } else { - *pErrMsg = "unsupported vector type"; - return -1; - } + *pType = VECTOR_COLUMN_TYPES[i].type; return 0; } *pErrMsg = "unexpected vector column type"; @@ -214823,7 +215022,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: %s: %s", pzErrMsg, zEmbeddingColumnTypeName); return CREATE_FAIL; } - // schema is locked while db is initializing and we need to just proceed here if( db->init.busy == 1 ){ return CREATE_OK; @@ -214903,11 +215101,8 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - if( type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): only f32 vectors are supported"); - rc = 
SQLITE_ERROR; - goto out; - } + assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_1BIT ); + pVector = vectorAlloc(type, dims); if( pVector == NULL ){ rc = SQLITE_NOMEM_BKPT; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 8ceabfc713..3568559303 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -85311,20 +85311,19 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerializeWithType(sqlite3_context *, const Vector *); +void vectorSerializeWithMeta(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); -void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t); +void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); -void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); -void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); void vectorConvert(const Vector *, Vector *); @@ -210981,7 +210980,6 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: - assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -211192,33 +211190,84 @@ static int vectorParseSqliteText( return -1; } +static 
int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pType, int *pDims, size_t *pDataSize, char **pzErrMsg){ + int nLeftoverBits; + + if( nBlobSize % 2 == 0 ){ + *pType = VECTOR_TYPE_FLOAT32; + *pDims = nBlobSize / sizeof(float); + *pDataSize = nBlobSize; + return SQLITE_OK; + } + *pType = pBlob[nBlobSize - 1]; + nBlobSize--; + + if( *pType == VECTOR_TYPE_FLOAT32 ){ + if( nBlobSize % 4 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + *pDims = nBlobSize / sizeof(float); + *pDataSize = nBlobSize; + }else if( *pType == VECTOR_TYPE_FLOAT64 ){ + if( nBlobSize % 8 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + *pDims = nBlobSize / sizeof(double); + *pDataSize = nBlobSize; + }else if( *pType == VECTOR_TYPE_1BIT ){ + if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + nLeftoverBits = pBlob[nBlobSize - 1]; + *pDims = nBlobSize * 8 - nLeftoverBits; + *pDataSize = (*pDims + 7) / 8; + }else{ + *pzErrMsg = sqlite3_mprintf("invalid vector: unexpected type: %d", *pType); + return SQLITE_ERROR; + } + return SQLITE_OK; +} + int vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ const unsigned char *pBlob; - size_t nBlobSize; + size_t nBlobSize, nDataSize; + int type, dims; assert( sqlite3_value_type(arg) == SQLITE_BLOB ); pBlob = sqlite3_value_blob(arg); nBlobSize = sqlite3_value_bytes(arg); - if( nBlobSize % 2 == 1 ){ - nBlobSize--; + if( vectorParseMeta(pBlob, nBlobSize, &type, &dims, &nDataSize, pzErrMsg) != SQLITE_OK ){ + return SQLITE_ERROR; } - if( nBlobSize < 
vectorDataSize(pVector->type, pVector->dims) ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize); + if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ + *pzErrMsg = sqlite3_mprintf( + "invalid vector: unexpected data size bytes: type=%d, dims=%d, %ull != %ull", + pVector->type, + pVector->dims, + nDataSize, + vectorDataSize(pVector->type, pVector->dims) + ); return SQLITE_ERROR; } switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); + vectorF32DeserializeFromBlob(pVector, pBlob, nDataSize); return 0; case VECTOR_TYPE_FLOAT64: - vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + vectorF64DeserializeFromBlob(pVector, pBlob, nDataSize); + return 0; + case VECTOR_TYPE_1BIT: + vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; default: assert(0); @@ -211228,32 +211277,21 @@ int vectorParseSqliteBlobWithType( int detectBlobVectorParameters(sqlite3_value *arg, int *pType, int *pDims, char **pzErrMsg) { const u8 *pBlob; - int nBlobSize; + size_t nBlobSize, nDataSize; assert( sqlite3_value_type(arg) == SQLITE_BLOB ); pBlob = sqlite3_value_blob(arg); nBlobSize = sqlite3_value_bytes(arg); - if( nBlobSize % 2 != 0 ){ - // we have trailing byte with explicit type definition - *pType = pBlob[nBlobSize - 1]; - } else { - // else, fallback to FLOAT32 - *pType = VECTOR_TYPE_FLOAT32; - } - if( *pType == VECTOR_TYPE_FLOAT32 ){ - *pDims = nBlobSize / sizeof(float); - } else if( *pType == VECTOR_TYPE_FLOAT64 ){ - *pDims = nBlobSize / sizeof(double); - } else{ - *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: got %d, expected %d or %d", *pType, VECTOR_TYPE_FLOAT32, VECTOR_TYPE_FLOAT64); - return -1; + + if( vectorParseMeta(pBlob, nBlobSize, pType, pDims, &nDataSize, pzErrMsg) != SQLITE_OK ){ + return SQLITE_ERROR; } if( *pDims > MAX_VECTOR_SZ ){ *pzErrMsg = sqlite3_mprintf("vector: max size 
exceeded: %d > %d", *pDims, MAX_VECTOR_SZ); - return -1; + return SQLITE_ERROR; } - return 0; + return SQLITE_OK; } int detectTextVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pDims, char **pzErrMsg) { @@ -211350,21 +211388,55 @@ void vectorMarshalToText( } } -void vectorSerializeWithType( +static int vectorMetaSize(VectorType type, VectorDims dims){ + int nMetaSize = 0; + int nDataSize; + if( type == VECTOR_TYPE_FLOAT32 ){ + return 0; + }else if( type == VECTOR_TYPE_FLOAT64 ){ + return 1; + }else if( type == VECTOR_TYPE_1BIT ){ + nDataSize = vectorDataSize(type, dims); + nMetaSize++; // one byte which specify amount of leftover bits + if( nDataSize % 2 == 0 ){ + nMetaSize++; // pad "leftover-bits" byte to the even length + } + nMetaSize++; // one byte for vector type + return nMetaSize; + }else{ + assert( 0 ); + } +} + +static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigned char *pBlob, size_t nBlobSize){ + if( pVector->type == VECTOR_TYPE_FLOAT32 ){ + // no meta for f32 type as this is "default" vector type + }else if( pVector->type == VECTOR_TYPE_FLOAT64 ){ + assert( nDataSize % 2 == 0 ); + assert( nBlobSize == nDataSize + 1 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; + }else if( pVector->type == VECTOR_TYPE_1BIT ){ + assert( nBlobSize % 2 == 1 ); + assert( nBlobSize >= 3 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_1BIT; + pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; + }else{ + assert( 0 ); + } +} + +void vectorSerializeWithMeta( sqlite3_context *context, const Vector *pVector ){ unsigned char *pBlob; - size_t nBlobSize, nDataSize; + size_t nBlobSize, nDataSize, nMetaSize; assert( pVector->dims <= MAX_VECTOR_SZ ); nDataSize = vectorDataSize(pVector->type, pVector->dims); - nBlobSize = nDataSize; - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - nBlobSize += (nBlobSize % 2 == 0 ? 
1 : 2); - } - + nMetaSize = vectorMetaSize(pVector->type, pVector->dims); + nBlobSize = nDataSize + nMetaSize; if( nBlobSize == 0 ){ sqlite3_result_zeroblob(context, 0); return; @@ -211376,10 +211448,6 @@ void vectorSerializeWithType( return; } - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - pBlob[nBlobSize - 1] = pVector->type; - } - switch (pVector->type) { case VECTOR_TYPE_FLOAT32: vectorF32SerializeToBlob(pVector, pBlob, nDataSize); @@ -211387,9 +211455,13 @@ void vectorSerializeWithType( case VECTOR_TYPE_FLOAT64: vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; + case VECTOR_TYPE_1BIT: + vector1BitSerializeToBlob(pVector, pBlob, nDataSize); + break; default: assert(0); } + vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize); sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } @@ -211408,38 +211480,124 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t } void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - vectorF32InitFromBlob(pVector, pBlob, nBlobSize); - break; - case VECTOR_TYPE_FLOAT64: - vectorF64InitFromBlob(pVector, pBlob, nBlobSize); - break; - default: - assert(0); + pVector->data = (void*)pBlob; +} + +static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){ + int i; + float *src; + + u8 *dst1Bit; + double *dstF64; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_FLOAT32 ); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF64[i] = src[i]; + } + }else if( pTo->type == VECTOR_TYPE_1BIT ){ + dst1Bit = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + dst1Bit[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( src[i] > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert( 0 ); } } -void vectorConvert(const Vector *pFrom, 
Vector *pTo){ +static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){ int i; - u8 *bitData; - float *floatData; + double *src; + + u8 *dst1Bit; + float *dstF32; assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_FLOAT64 ); - if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ - floatData = pFrom->data; - bitData = pTo->data; + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF32[i] = src[i]; + } + }else if( pTo->type == VECTOR_TYPE_1BIT ){ + dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ - bitData[i / 8] = 0; + dst1Bit[i / 8] = 0; } for(i = 0; i < pFrom->dims; i++){ - if( floatData[i] > 0 ){ - bitData[i / 8] |= (1 << (i & 7)); + if( src[i] > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); } } }else{ - assert(0); + assert( 0 ); + } +} + +static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ + int i; + u8 *src; + + float *dstF32; + double *dstF64; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_1BIT ); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){ + dstF32[i] = +1; + }else{ + dstF32[i] = -1; + } + } + }else if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){ + dstF64[i] = +1; + }else{ + dstF64[i] = -1; + } + } + }else{ + assert( 0 ); + } +} + +void vectorConvert(const Vector *pFrom, Vector *pTo){ + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == pTo->type ){ + memcpy(pTo->data, pFrom->data, vectorDataSize(pFrom->type, pFrom->dims)); + return; + } + + if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + vectorConvertFromF32(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ + 
vectorConvertFromF64(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_1BIT ){ + vectorConvertFrom1Bit(pFrom, pTo); + }else{ + assert( 0 ); } } @@ -211454,31 +211612,49 @@ static void vectorFuncHintedType( sqlite3_context *context, int argc, sqlite3_value **argv, - int typeHint + int targetType ){ char *pzErrMsg = NULL; - Vector *pVector; - int type, dims; + Vector *pVector = NULL, *pTarget = NULL; + int type, dims, typeHint = VECTOR_TYPE_FLOAT32; if( argc < 1 ){ - return; + goto out; + } + // simplification in order to support only parsing from text to f32 and f64 vectors + if( targetType == VECTOR_TYPE_FLOAT64 ){ + typeHint = targetType; } if( detectVectorParameters(argv[0], typeHint, &type, &dims, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - return; + goto out; } pVector = vectorContextAlloc(context, type, dims); - if( pVector==NULL ){ - return; + if( pVector == NULL ){ + goto out; } if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - goto out_free_vec; + goto out; + } + if( type == targetType ){ + vectorSerializeWithMeta(context, pVector); + }else{ + pTarget = vectorContextAlloc(context, targetType, dims); + if( pTarget == NULL ){ + goto out; + } + vectorConvert(pVector, pTarget); + vectorSerializeWithMeta(context, pTarget); + } +out: + if( pVector != NULL ){ + vectorFree(pVector); + } + if( pTarget != NULL ){ + vectorFree(pTarget); } - vectorSerializeWithType(context, pVector); -out_free_vec: - vectorFree(pVector); } static void vector32Func( @@ -211496,6 +211672,14 @@ static void vector64Func( vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT64); } +static void vector1BitFunc( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_1BIT); +} + /* ** Implementation of vector_extract(X) function. 
*/ @@ -211505,30 +211689,44 @@ static void vectorExtractFunc( sqlite3_value **argv ){ char *pzErrMsg = NULL; - Vector *pVector; + Vector *pVector = NULL, *pTarget = NULL; unsigned i; int type, dims; if( argc < 1 ){ - return; + goto out; } if( detectVectorParameters(argv[0], 0, &type, &dims, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - return; + goto out; } pVector = vectorContextAlloc(context, type, dims); - if( pVector==NULL ){ - return; + if( pVector == NULL ){ + goto out; } if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); - goto out_free; + goto out; + } + if( pVector->type == VECTOR_TYPE_FLOAT32 || pVector->type == VECTOR_TYPE_FLOAT64 ){ + vectorMarshalToText(context, pVector); + }else{ + pTarget = vectorContextAlloc(context, VECTOR_TYPE_FLOAT32, dims); + if( pTarget == NULL ){ + goto out; + } + vectorConvert(pVector, pTarget); + vectorMarshalToText(context, pTarget); + } +out: + if( pVector != NULL ){ + vectorFree(pVector); + } + if( pTarget != NULL ){ + vectorFree(pTarget); } - vectorMarshalToText(context, pVector); -out_free: - vectorFree(pVector); } /* @@ -211612,6 +211810,7 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ FUNCTION(vector, 1, 0, 0, vector32Func), FUNCTION(vector32, 1, 0, 0, vector32Func), FUNCTION(vector64, 1, 0, 0, vector64Func), + FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc), FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc), FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc), @@ -211750,6 +211949,20 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ return diff; } +void vector1BitDeserializeFromBlob( + Vector *pVector, + const unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + + assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + 
memcpy(elems, pBlob, (pVector->dims + 7) / 8); +} + #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vector1bit.c ******************************************/ @@ -213184,12 +213397,12 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): k must be a non-negative integer"); return SQLITE_ERROR; } - if( pIndex->nVectorDims != pVector->dims ){ + if( pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(search): dimensions are different: %d != %d", pVector->dims, pIndex->nVectorDims); return SQLITE_ERROR; } - if( pVector->type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): only f32 vectors are supported"); + if( pVector->type != pIndex->nNodeVectorType ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): vector type differs from column type: %d != %d", pVector->type, pIndex->nNodeVectorType); return SQLITE_ERROR; } @@ -213254,8 +213467,8 @@ int diskAnnInsert( *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); return SQLITE_ERROR; } - if( pVectorInRow->pVector->type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(insert): only f32 vectors are supported"); + if( pVectorInRow->pVector->type != pIndex->nNodeVectorType ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): vector type differs from column type: %d != %d", pVectorInRow->pVector->type, pIndex->nNodeVectorType); return SQLITE_ERROR; } @@ -213703,11 +213916,6 @@ float vectorF32DistanceL2(const Vector *v1, const Vector *v2){ return sqrt(sum); } -void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - pVector->dims = nBlobSize / sizeof(float); - pVector->data = (void*)pBlob; -} - void vectorF32DeserializeFromBlob( Vector *pVector, const unsigned char *pBlob, @@ -213907,11 +214115,6 @@ double vectorF64DistanceL2(const Vector *v1, const Vector *v2){ return sqrt(sum); } -void 
vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - pVector->dims = nBlobSize / sizeof(double); - pVector->data = (void*)pBlob; -} - void vectorF64DeserializeFromBlob( Vector *pVector, const unsigned char *pBlob, @@ -213960,6 +214163,7 @@ void vectorF64DeserializeFromBlob( ** ** libSQL vector search. */ +/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "sqlite3.h" */ /* #include "vdbeInt.h" */ @@ -214309,14 +214513,16 @@ void vectorOutRowsFree(sqlite3 *db, VectorOutRows *pRows) { */ struct VectorColumnType { const char *zName; - int nBits; + int type; }; static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { - { "FLOAT32", 32 }, - { "FLOAT64", 64 }, - { "F32_BLOB", 32 }, - { "F64_BLOB", 64 } + { "FLOAT32", VECTOR_TYPE_FLOAT32 }, + { "F32_BLOB", VECTOR_TYPE_FLOAT32 }, + { "FLOAT64", VECTOR_TYPE_FLOAT64 }, + { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, + { "FLOAT1BIT", VECTOR_TYPE_1BIT }, + { "F1BIT_BLOB", VECTOR_TYPE_1BIT }, }; /* @@ -214505,14 +214711,7 @@ int vectorIdxParseColumnType(const char *zType, int *pType, int *pDims, const ch } *pDims = dimensions; - if( VECTOR_COLUMN_TYPES[i].nBits == 32 ) { - *pType = VECTOR_TYPE_FLOAT32; - } else if( VECTOR_COLUMN_TYPES[i].nBits == 64 ) { - *pType = VECTOR_TYPE_FLOAT64; - } else { - *pErrMsg = "unsupported vector type"; - return -1; - } + *pType = VECTOR_COLUMN_TYPES[i].type; return 0; } *pErrMsg = "unexpected vector column type"; @@ -214823,7 +215022,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: %s: %s", pzErrMsg, zEmbeddingColumnTypeName); return CREATE_FAIL; } - // schema is locked while db is initializing and we need to just proceed here if( db->init.busy == 1 ){ return CREATE_OK; @@ -214903,11 +215101,8 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - if( type != VECTOR_TYPE_FLOAT32 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): only f32 vectors are supported"); - rc = 
SQLITE_ERROR; - goto out; - } + assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_1BIT ); + pVector = vectorAlloc(type, dims); if( pVector == NULL ){ rc = SQLITE_NOMEM_BKPT; From 30a198e962d7e23ac7b7c01ff181452d540f91c8 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 17:36:02 +0400 Subject: [PATCH 080/121] refine error messages --- libsql-sqlite3/src/vector.c | 10 +++++----- libsql-sqlite3/test/libsql_vector.test | 2 +- libsql-sqlite3/test/libsql_vector_index.test | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index b5c43901e8..6c4a424343 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -266,28 +266,28 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT if( *pType == VECTOR_TYPE_FLOAT32 ){ if( nBlobSize % 4 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(float); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT64 ){ if( nBlobSize % 8 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = 
sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } nLeftoverBits = pBlob[nBlobSize - 1]; *pDims = nBlobSize * 8 - nLeftoverBits; *pDataSize = (*pDims + 7) / 8; }else{ - *pzErrMsg = sqlite3_mprintf("invalid vector: unexpected type: %d", *pType); + *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType); return SQLITE_ERROR; } return SQLITE_OK; @@ -312,7 +312,7 @@ int vectorParseSqliteBlobWithType( if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ *pzErrMsg = sqlite3_mprintf( - "invalid vector: unexpected data size bytes: type=%d, dims=%d, %ull != %ull", + "vector: unexpected data part size: type=%d, dims=%d, %ull != %ull", pVector->type, pVector->dims, nDataSize, diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index 7afb0f8bb0..be2edc9397 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -137,7 +137,7 @@ do_test vector-1-func-errors { {vector: invalid float at position 0: '[1'} {vector: invalid float at position 2: '1.1.1'} {vector: must end with ']'} - {invalid vector: unexpected type: 0} + {vector: unexpected binary type: 0} {vector_distance_cos: vectors must have the same length: 3 != 2} {vector_distance_cos: vectors must have the same type: 1 != 2} }] diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 0756566914..242383e376 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -327,7 +327,7 @@ do_execsql_test vector-partial { 2 3 5 6 8 9 } -do_execsql_test vector-1bit-table { +do_execsql_test vector-1bit-index { CREATE TABLE t_1bit_table( v FLOAT1BIT(4) ); INSERT INTO t_1bit_table VALUES ( vector1bit('[1,-1,1,-1]') ); CREATE INDEX t_1bit_table_idx ON t_1bit_table( libsql_vector_idx(v) ); From 
66d374f52bb1b907403f1dc6cb980dd5d0530de7 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 17:45:01 +0400 Subject: [PATCH 081/121] specify binary format for vectors in comment --- libsql-sqlite3/src/vectorInt.h | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index 350a9ae9bd..aaf219f52b 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -19,6 +19,26 @@ typedef u32 VectorDims; */ #define MAX_VECTOR_SZ 65536 +/* + * on-disk binary format for vector of different types: + * 1. float32 + * [data[0] as f32] [data[1] as f32] ... [data[dims - 1] as f32] [1 as u8]? + * - last 'type'-byte is optional for float32 vectors + * + * 2. float64 + * [data[0] as f64] [data[1] as f64] ... [data[dims - 1] as f64] [2 as u8] + * - last 'type'-byte is mandatory for float64 vectors + * + * 3. float1bit + * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? 
[leftover as u8] [3 as u8] + * - every data byte (except for the last) represents exactly 8 components of the vector + * - last data byte represents [1..8] components of the vector + * - optional padding byte ensures that leftover byte will be written at the odd blob position (0-based) + * - leftover byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted + * (so, vector dimensions are equal to 8 * (blob_size - 1) - leftover) + * - last 'type'-byte is mandatory for float1bit vectors +*/ + /* * Enumerate of supported vector types (0 omitted intentionally as we can use zero as "undefined" value) */ From 3c09fcec950c3ab6101d46b9d679d95eb462e44a Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 17:47:44 +0400 Subject: [PATCH 082/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 30 +++++++++++++++---- libsql-ffi/bundled/src/sqlite3.c | 30 +++++++++++++++---- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 3568559303..3bac75359b 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -85238,6 +85238,26 @@ typedef u32 VectorDims; */ #define MAX_VECTOR_SZ 65536 +/* + * on-disk binary format for vector of different types: + * 1. float32 + * [data[0] as f32] [data[1] as f32] ... [data[dims - 1] as f32] [1 as u8]? + * - last 'type'-byte is optional for float32 vectors + * + * 2. float64 + * [data[0] as f64] [data[1] as f64] ... [data[dims - 1] as f64] [2 as u8] + * - last 'type'-byte is mandatory for float64 vectors + * + * 3. float1bit + * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? 
[leftover as u8] [3 as u8] + * - every data byte (except for the last) represents exactly 8 components of the vector + * - last data byte represents [1..8] components of the vector + * - optional padding byte ensures that leftover byte will be written at the odd blob position (0-based) + * - leftover byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted + * (so, vector dimensions are equal to 8 * (blob_size - 1) - leftover) + * - last 'type'-byte is mandatory for float1bit vectors +*/ + /* * Enumerate of supported vector types (0 omitted intentionally as we can use zero as "undefined" value) */ @@ -211204,28 +211224,28 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT if( *pType == VECTOR_TYPE_FLOAT32 ){ if( nBlobSize % 4 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(float); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT64 ){ if( nBlobSize % 8 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): 
length=%d", nBlobSize); return SQLITE_ERROR; } nLeftoverBits = pBlob[nBlobSize - 1]; *pDims = nBlobSize * 8 - nLeftoverBits; *pDataSize = (*pDims + 7) / 8; }else{ - *pzErrMsg = sqlite3_mprintf("invalid vector: unexpected type: %d", *pType); + *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType); return SQLITE_ERROR; } return SQLITE_OK; @@ -211250,7 +211270,7 @@ int vectorParseSqliteBlobWithType( if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ *pzErrMsg = sqlite3_mprintf( - "invalid vector: unexpected data size bytes: type=%d, dims=%d, %ull != %ull", + "vector: unexpected data part size: type=%d, dims=%d, %ull != %ull", pVector->type, pVector->dims, nDataSize, diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 3568559303..3bac75359b 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -85238,6 +85238,26 @@ typedef u32 VectorDims; */ #define MAX_VECTOR_SZ 65536 +/* + * on-disk binary format for vector of different types: + * 1. float32 + * [data[0] as f32] [data[1] as f32] ... [data[dims - 1] as f32] [1 as u8]? + * - last 'type'-byte is optional for float32 vectors + * + * 2. float64 + * [data[0] as f64] [data[1] as f64] ... [data[dims - 1] as f64] [2 as u8] + * - last 'type'-byte is mandatory for float64 vectors + * + * 3. float1bit + * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? 
[leftover as u8] [3 as u8] + * - every data byte (except for the last) represents exactly 8 components of the vector + * - last data byte represents [1..8] components of the vector + * - optional padding byte ensures that leftover byte will be written at the odd blob position (0-based) + * - leftover byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted + * (so, vector dimensions are equal to 8 * (blob_size - 1) - leftover) + * - last 'type'-byte is mandatory for float1bit vectors +*/ + /* * Enumerate of supported vector types (0 omitted intentionally as we can use zero as "undefined" value) */ @@ -211204,28 +211224,28 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT if( *pType == VECTOR_TYPE_FLOAT32 ){ if( nBlobSize % 4 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(float); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT64 ){ if( nBlobSize % 8 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ - *pzErrMsg = sqlite3_mprintf("invalid vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): 
length=%d", nBlobSize); return SQLITE_ERROR; } nLeftoverBits = pBlob[nBlobSize - 1]; *pDims = nBlobSize * 8 - nLeftoverBits; *pDataSize = (*pDims + 7) / 8; }else{ - *pzErrMsg = sqlite3_mprintf("invalid vector: unexpected type: %d", *pType); + *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType); return SQLITE_ERROR; } return SQLITE_OK; @@ -211250,7 +211270,7 @@ int vectorParseSqliteBlobWithType( if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ *pzErrMsg = sqlite3_mprintf( - "invalid vector: unexpected data size bytes: type=%d, dims=%d, %ull != %ull", + "vector: unexpected data part size: type=%d, dims=%d, %ull != %ull", pVector->type, pVector->dims, nDataSize, From 2d325ba8f6c7d74cf3733e134ce952e30f7961c7 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 18:46:14 +0400 Subject: [PATCH 083/121] windows compiler complains about operations with void* pointers - error C2036: 'void *': unknown size ... --- .../bundled/SQLite3MultipleCiphers/src/sqlite3.c | 16 ++++++++-------- libsql-ffi/bundled/src/sqlite3.c | 16 ++++++++-------- libsql-sqlite3/src/vectordiskann.c | 16 ++++++++-------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 1fd0252f0d..2dcd939f21 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -212669,7 +212669,7 @@ int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, fl return nSize < nMaxSize ? 
nSize : -1; } -void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { +void bufferInsert(u8 *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const u8 *pItem, u8 *pLast) { int itemsToMove; assert( nMaxSize > 0 && nItemSize > 0 ); @@ -212687,7 +212687,7 @@ void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItem memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); } -void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { +void bufferDelete(u8 *aBuffer, int nSize, int iDelete, int nItemSize) { int itemsToMove; assert( nItemSize > 0 ); @@ -212850,8 +212850,8 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo if( iInsert < 0 ){ return; } - bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); - bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pNode, NULL); + bufferInsert((u8*)pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } @@ -212872,8 +212872,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); - bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); - bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); + bufferDelete((u8*)pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete((u8*)pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); pCtx->nCandidates--; pCtx->nUnvisited--; @@ 
-212881,8 +212881,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ DiskAnnNode *pLast = NULL; - bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); - bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pCandidate, (u8*)&pLast); + bufferInsert((u8*)pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); if( pLast != NULL && !pLast->visited ){ // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 1fd0252f0d..2dcd939f21 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -212669,7 +212669,7 @@ int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, fl return nSize < nMaxSize ? 
nSize : -1; } -void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { +void bufferInsert(u8 *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const u8 *pItem, u8 *pLast) { int itemsToMove; assert( nMaxSize > 0 && nItemSize > 0 ); @@ -212687,7 +212687,7 @@ void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItem memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); } -void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { +void bufferDelete(u8 *aBuffer, int nSize, int iDelete, int nItemSize) { int itemsToMove; assert( nItemSize > 0 ); @@ -212850,8 +212850,8 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo if( iInsert < 0 ){ return; } - bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); - bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pNode, NULL); + bufferInsert((u8*)pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } @@ -212872,8 +212872,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); - bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); - bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); + bufferDelete((u8*)pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete((u8*)pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); pCtx->nCandidates--; pCtx->nUnvisited--; @@ 
-212881,8 +212881,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ DiskAnnNode *pLast = NULL; - bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); - bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pCandidate, (u8*)&pLast); + bufferInsert((u8*)pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); if( pLast != NULL && !pLast->visited ){ // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index a6c279b259..7b29c6f50f 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -913,7 +913,7 @@ int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, fl return nSize < nMaxSize ? 
nSize : -1; } -void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { +void bufferInsert(u8 *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const u8 *pItem, u8 *pLast) { int itemsToMove; assert( nMaxSize > 0 && nItemSize > 0 ); @@ -931,7 +931,7 @@ void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItem memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); } -void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { +void bufferDelete(u8 *aBuffer, int nSize, int iDelete, int nItemSize) { int itemsToMove; assert( nItemSize > 0 ); @@ -1094,8 +1094,8 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo if( iInsert < 0 ){ return; } - bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); - bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pNode, NULL); + bufferInsert((u8*)pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } @@ -1116,8 +1116,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); - bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); - bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); + bufferDelete((u8*)pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete((u8*)pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); pCtx->nCandidates--; pCtx->nUnvisited--; @@ -1125,8 +1125,8 @@ 
static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ DiskAnnNode *pLast = NULL; - bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); - bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pCandidate, (u8*)&pLast); + bufferInsert((u8*)pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); if( pLast != NULL && !pLast->visited ){ // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node From c75354e6b23724c00adddef748afa421aeca937d Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 11:11:18 +0400 Subject: [PATCH 084/121] generate uniform values from [-1..1] in benchmark workloads --- libsql-sqlite3/benchmark/workload.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/libsql-sqlite3/benchmark/workload.py b/libsql-sqlite3/benchmark/workload.py index 2d413531fa..728e375933 100644 --- a/libsql-sqlite3/benchmark/workload.py +++ b/libsql-sqlite3/benchmark/workload.py @@ -10,10 +10,10 @@ def recall_uniform(dim, n, q): print(f'CREATE TABLE queries ( emb FLOAT32({dim}) );') print(f'BEGIN TRANSACTION;') for i in range(n): - vector = f"[{','.join(map(str, np.random.uniform(size=dim)))}]" + vector = f"[{','.join(map(str, np.random.uniform(-1, 1, size=dim)))}]" print(f'INSERT INTO data VALUES ({i}, vector(\'{vector}\'));') for i in range(q): - vector = f"[{','.join(map(str, np.random.uniform(size=dim)))}]" + vector = f"[{','.join(map(str, np.random.uniform(-1, 1, 
size=dim)))}]" print(f'INSERT INTO queries VALUES (vector(\'{vector}\'));') print(f'COMMIT;') print('---insert everything') @@ -29,7 +29,7 @@ def recall_normal(dim, n, q): vector = f"[{','.join(map(str, np.random.uniform(size=64)))}]" print(f'INSERT INTO data VALUES ({i}, \'{vector}\');') for i in range(q): - vector = f"[{','.join(map(str, np.random.uniform(size=64)))}]" + vector = f"[{','.join(map(str, np.random.uniform(-1, 1, size=64)))}]" print(f'INSERT INTO queries VALUES (\'{vector}\');') print(f'COMMIT;') print('---insert everything') @@ -40,7 +40,7 @@ def no_vectors(n, q): print('PRAGMA journal_mode=WAL;') print(f'CREATE TABLE x ( id INTEGER PRIMARY KEY, value TEXT );') for i in range(n): - vector = f"[{','.join(map(str, np.random.uniform(size=64)))}]" + vector = f"[{','.join(map(str, np.random.uniform(-1, 1, size=64)))}]" print(f'INSERT INTO x VALUES ({i}, \'{vector}\');') print('---inserts') for i in range(q): @@ -54,11 +54,11 @@ def bruteforce(dim, n, q): print('PRAGMA journal_mode=WAL;') print(f'CREATE TABLE x ( id INTEGER PRIMARY KEY, embedding FLOAT32({dim}) );') for i in range(n): - vector = f"[{','.join(map(str, np.random.uniform(size=dim)))}]" + vector = f"[{','.join(map(str, np.random.uniform(-1, 1, size=dim)))}]" print(f'INSERT INTO x VALUES ({i}, vector(\'{vector}\'));') print('---inserts') for i in range(q): - vector = f"[{','.join(map(str, np.random.uniform(size=dim)))}]" + vector = f"[{','.join(map(str, np.random.uniform(-1, 1, size=dim)))}]" print(f'SELECT id FROM x ORDER BY vector_distance_cos(embedding, vector(\'{vector}\')) LIMIT 1;') print('---search') @@ -68,13 +68,13 @@ def diskann(dim, n, q): q = int(q) print('PRAGMA journal_mode=WAL;') print(f'CREATE TABLE x ( id INTEGER PRIMARY KEY, embedding FLOAT32({dim}) );') - print(f'CREATE INDEX x_idx ON x( libsql_vector_idx(embedding) );') + print(f"CREATE INDEX x_idx ON x( libsql_vector_idx(embedding) );") for i in range(n): - vector = f"[{','.join(map(str, np.random.uniform(size=dim)))}]" + 
vector = f"[{','.join(map(str, np.random.uniform(-1, 1, size=dim)))}]" print(f'INSERT INTO x VALUES ({i}, vector(\'{vector}\'));') print('---inserts') for i in range(q): - vector = f"[{','.join(map(str, np.random.uniform(size=dim)))}]" + vector = f"[{','.join(map(str, np.random.uniform(-1, 1, size=dim)))}]" print(f'SELECT id FROM vector_top_k(\'x_idx\', vector(\'{vector}\'), 1);') print('---search') From 96d5ca9235efea77ebd807e6241e3f9d1e9d8703 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 11:11:49 +0400 Subject: [PATCH 085/121] fix formatting --- libsql-sqlite3/src/vector.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 6c4a424343..c469fabdd7 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -312,7 +312,7 @@ int vectorParseSqliteBlobWithType( if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ *pzErrMsg = sqlite3_mprintf( - "vector: unexpected data part size: type=%d, dims=%d, %ull != %ull", + "vector: unexpected data part size: type=%d, dims=%d, %u != %u", pVector->type, pVector->dims, nDataSize, From db154137ad51d6e6cbaedbe1ee7521738203088c Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 11:13:46 +0400 Subject: [PATCH 086/121] fix format specifier + build bundles --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 2 +- libsql-ffi/bundled/src/sqlite3.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 3bac75359b..ad032de20e 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -211270,7 +211270,7 @@ int vectorParseSqliteBlobWithType( if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ *pzErrMsg = sqlite3_mprintf( - "vector: unexpected data part size: type=%d, 
dims=%d, %ull != %ull", + "vector: unexpected data part size: type=%d, dims=%d, %u != %u", pVector->type, pVector->dims, nDataSize, diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 3bac75359b..ad032de20e 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -211270,7 +211270,7 @@ int vectorParseSqliteBlobWithType( if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){ *pzErrMsg = sqlite3_mprintf( - "vector: unexpected data part size: type=%d, dims=%d, %ull != %ull", + "vector: unexpected data part size: type=%d, dims=%d, %u != %u", pVector->type, pVector->dims, nDataSize, From 7cafb80b461eb24724a95a29da9a5bfa107747de Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 11:24:46 +0400 Subject: [PATCH 087/121] use float1bit instead of 1bit everywhere in the code in index settings --- libsql-sqlite3/src/vector.c | 30 ++++++++++---------- libsql-sqlite3/src/vector1bit.c | 10 +++---- libsql-sqlite3/src/vectorIndex.c | 15 +++++----- libsql-sqlite3/src/vectorInt.h | 6 ++-- libsql-sqlite3/src/vectordiskann.c | 4 +-- libsql-sqlite3/test/libsql_vector_index.test | 13 +++++++-- 6 files changed, 44 insertions(+), 34 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index c469fabdd7..01bf402aa0 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -41,7 +41,7 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return (dims + 7) / 8; default: assert(0); @@ -114,7 +114,7 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceCos(pVector1, pVector2); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return vector1BitDistanceHamming(pVector1, pVector2); 
default: assert(0); @@ -278,7 +278,7 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; - }else if( *pType == VECTOR_TYPE_1BIT ){ + }else if( *pType == VECTOR_TYPE_FLOAT1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ *pzErrMsg = sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; @@ -328,7 +328,7 @@ int vectorParseSqliteBlobWithType( case VECTOR_TYPE_FLOAT64: vectorF64DeserializeFromBlob(pVector, pBlob, nDataSize); return 0; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; default: @@ -426,7 +426,7 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT64: vectorF64Dump(pVector); break; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitDump(pVector); break; default: @@ -457,7 +457,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){ return 0; }else if( type == VECTOR_TYPE_FLOAT64 ){ return 1; - }else if( type == VECTOR_TYPE_1BIT ){ + }else if( type == VECTOR_TYPE_FLOAT1BIT ){ nDataSize = vectorDataSize(type, dims); nMetaSize++; // one byte which specify amount of leftover bits if( nDataSize % 2 == 0 ){ @@ -477,10 +477,10 @@ static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigne assert( nDataSize % 2 == 0 ); assert( nBlobSize == nDataSize + 1 ); pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - }else if( pVector->type == VECTOR_TYPE_1BIT ){ + }else if( pVector->type == VECTOR_TYPE_FLOAT1BIT ){ assert( nBlobSize % 2 == 1 ); assert( nBlobSize >= 3 ); - pBlob[nBlobSize - 1] = VECTOR_TYPE_1BIT; + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT1BIT; pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; }else{ assert( 0 ); @@ -517,7 +517,7 @@ void vectorSerializeWithMeta( case VECTOR_TYPE_FLOAT64: vectorF64SerializeToBlob(pVector, pBlob, 
nDataSize); break; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitSerializeToBlob(pVector, pBlob, nDataSize); break; default: @@ -533,7 +533,7 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); @@ -562,7 +562,7 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){ for(i = 0; i < pFrom->dims; i++){ dstF64[i] = src[i]; } - }else if( pTo->type == VECTOR_TYPE_1BIT ){ + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ dst1Bit[i / 8] = 0; @@ -594,7 +594,7 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){ for(i = 0; i < pFrom->dims; i++){ dstF32[i] = src[i]; } - }else if( pTo->type == VECTOR_TYPE_1BIT ){ + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ dst1Bit[i / 8] = 0; @@ -618,7 +618,7 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ assert( pFrom->dims == pTo->dims ); assert( pFrom->type != pTo->type ); - assert( pFrom->type == VECTOR_TYPE_1BIT ); + assert( pFrom->type == VECTOR_TYPE_FLOAT1BIT ); src = pFrom->data; if( pTo->type == VECTOR_TYPE_FLOAT32 ){ @@ -656,7 +656,7 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){ vectorConvertFromF32(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ vectorConvertFromF64(pFrom, pTo); - }else if( pFrom->type == VECTOR_TYPE_1BIT ){ + }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ vectorConvertFrom1Bit(pFrom, pTo); }else{ assert( 0 ); @@ -739,7 +739,7 @@ static void vector1BitFunc( int argc, sqlite3_value **argv ){ - vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_1BIT); + vectorFuncHintedType(context, 
argc, argv, VECTOR_TYPE_FLOAT1BIT); } /* diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index b80c166522..86e367c2de 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -39,7 +39,7 @@ void vector1BitDump(const Vector *pVec){ u8 *elems = pVec->data; unsigned i; - assert( pVec->type == VECTOR_TYPE_1BIT ); + assert( pVec->type == VECTOR_TYPE_FLOAT1BIT ); printf("f1bit: ["); for(i = 0; i < pVec->dims; i++){ @@ -61,7 +61,7 @@ size_t vector1BitSerializeToBlob( u8 *pPtr = pBlob; unsigned i; - assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( pVector->dims <= MAX_VECTOR_SZ ); assert( nBlobSize >= (pVector->dims + 7) / 8 ); @@ -108,8 +108,8 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ int i, len8, len32, offset8; assert( v1->dims == v2->dims ); - assert( v1->type == VECTOR_TYPE_1BIT ); - assert( v2->type == VECTOR_TYPE_1BIT ); + assert( v1->type == VECTOR_TYPE_FLOAT1BIT ); + assert( v2->type == VECTOR_TYPE_FLOAT1BIT ); len8 = (v1->dims + 7) / 8; len32 = v1->dims / 32; @@ -131,7 +131,7 @@ void vector1BitDeserializeFromBlob( ){ u8 *elems = pVector->data; - assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); assert( nBlobSize >= (pVector->dims + 7) / 8 ); diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 6413bf0822..a9801c4e60 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -382,8 +382,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { { "F32_BLOB", VECTOR_TYPE_FLOAT32 }, { "FLOAT64", VECTOR_TYPE_FLOAT64 }, { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, - { "FLOAT1BIT", VECTOR_TYPE_1BIT }, - { "F1BIT_BLOB", VECTOR_TYPE_1BIT }, + { "FLOAT1BIT", VECTOR_TYPE_FLOAT1BIT }, + { "F1BIT_BLOB", VECTOR_TYPE_FLOAT1BIT }, }; /* @@ -399,10 +399,11 @@ struct 
VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 }, { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, @@ -962,7 +963,7 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_1BIT ); + assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT1BIT ); pVector = vectorAlloc(type, dims); if( pVector == NULL ){ diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index aaf219f52b..a17ff1d59a 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -42,9 +42,9 @@ typedef u32 VectorDims; /* * Enumerate of supported vector types (0 omitted intentionally as we can use zero as "undefined" value) */ -#define VECTOR_TYPE_FLOAT32 1 -#define VECTOR_TYPE_FLOAT64 2 -#define VECTOR_TYPE_1BIT 3 +#define VECTOR_TYPE_FLOAT32 1 +#define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_FLOAT1BIT 3 #define VECTOR_FLAGS_STATIC 1 diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 
c28cf075e9..c6e2d5156f 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -505,7 +505,7 @@ int diskAnnCreateIndex( } } neighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); - if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){ + if( neighbours == VECTOR_TYPE_FLOAT1BIT && metric != VECTOR_METRIC_TYPE_COS ){ *pzErrMsg = "1-bit compression available only for cosine metric"; return SQLITE_ERROR; } @@ -1749,7 +1749,7 @@ int diskAnnOpenIndex( if( compressNeighbours == 0 ){ pIndex->nEdgeVectorType = pIndex->nNodeVectorType; pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; - }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + }else if( compressNeighbours == VECTOR_TYPE_FLOAT1BIT ){ pIndex->nEdgeVectorType = compressNeighbours; pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); }else{ diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 242383e376..f066275833 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -275,7 +275,7 @@ do_execsql_test vector-transaction { do_execsql_test vector-1bit { CREATE TABLE t_1bit( v FLOAT32(3) ); - CREATE INDEX t_1bit_idx ON t_1bit( libsql_vector_idx(v, 'compress_neighbors=1bit') ); + CREATE INDEX t_1bit_idx ON t_1bit( libsql_vector_idx(v, 'compress_neighbors=float1bit') ); INSERT INTO t_1bit VALUES (vector('[-1,-1,1]')); INSERT INTO t_1bit VALUES (vector('[-1,1,-1.5]')); INSERT INTO t_1bit VALUES (vector('[1,-1,-1]')); @@ -285,7 +285,7 @@ do_execsql_test vector-1bit { do_execsql_test vector-all-params { CREATE TABLE t_all_params ( emb FLOAT32(2) ); - CREATE INDEX t_all_params_idx ON t_all_params(libsql_vector_idx(emb, 'type=diskann', 'metric=cos', 'alpha=1.2', 'search_l=200', 'insert_l=70', 'max_neighbors=6', 'compress_neighbors=1bit')); + CREATE INDEX t_all_params_idx ON 
t_all_params(libsql_vector_idx(emb, 'type=diskann', 'metric=cos', 'alpha=1.2', 'search_l=200', 'insert_l=70', 'max_neighbors=6', 'compress_neighbors=float1bit')); INSERT INTO t_all_params VALUES (vector('[1,2]')), (vector('[3,4]')); SELECT * FROM vector_top_k('t_all_params_idx', vector('[1,2]'), 2); } {1 2} @@ -336,6 +336,15 @@ do_execsql_test vector-1bit-index { SELECT * FROM vector_top_k('t_1bit_table_idx', vector1bit('[10,-10,-20,20]'), 4); } {3 1 2} +do_execsql_test vector-f64-compress-f32 { + CREATE TABLE t_f64_f32( v FLOAT64(4) ); + CREATE INDEX t_f64_f32_idx ON t_f64_f32( libsql_vector_idx(v, 'compress_neighbors=float32') ); + INSERT INTO t_f64_f32 VALUES ( vector64('[1,-1,1,-1]') ); + INSERT INTO t_f64_f32 VALUES ( vector64('[-1,1,1,-1]') ); + INSERT INTO t_f64_f32 VALUES ( vector64('[1,-1,-1,1]') ); + SELECT * FROM vector_top_k('t_f64_f32_idx', vector64('[10,-10,-20,20]'), 4); +} {3 1 2} + proc error_messages {sql} { set ret "" catch { From acafcc8277958f6d08674ec32c858a5b4eb94eae Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 11:26:32 +0400 Subject: [PATCH 088/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 65 ++++++++++--------- libsql-ffi/bundled/src/sqlite3.c | 65 ++++++++++--------- 2 files changed, 66 insertions(+), 64 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index ad032de20e..8895733cb8 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -85261,9 +85261,9 @@ typedef u32 VectorDims; /* * Enumerate of supported vector types (0 omitted intentionally as we can use zero as "undefined" value) */ -#define VECTOR_TYPE_FLOAT32 1 -#define VECTOR_TYPE_FLOAT64 2 -#define VECTOR_TYPE_1BIT 3 +#define VECTOR_TYPE_FLOAT32 1 +#define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_FLOAT1BIT 3 #define VECTOR_FLAGS_STATIC 1 @@ -210999,7 +210999,7 @@ 
size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return (dims + 7) / 8; default: assert(0); @@ -211072,7 +211072,7 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceCos(pVector1, pVector2); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return vector1BitDistanceHamming(pVector1, pVector2); default: assert(0); @@ -211236,7 +211236,7 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; - }else if( *pType == VECTOR_TYPE_1BIT ){ + }else if( *pType == VECTOR_TYPE_FLOAT1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ *pzErrMsg = sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; @@ -211286,7 +211286,7 @@ int vectorParseSqliteBlobWithType( case VECTOR_TYPE_FLOAT64: vectorF64DeserializeFromBlob(pVector, pBlob, nDataSize); return 0; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; default: @@ -211384,7 +211384,7 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT64: vectorF64Dump(pVector); break; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitDump(pVector); break; default: @@ -211415,7 +211415,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){ return 0; }else if( type == VECTOR_TYPE_FLOAT64 ){ return 1; - }else if( type == VECTOR_TYPE_1BIT ){ + }else if( type == VECTOR_TYPE_FLOAT1BIT ){ nDataSize = vectorDataSize(type, dims); nMetaSize++; // one byte which specify amount of leftover bits if( nDataSize % 2 == 0 ){ @@ -211435,10 +211435,10 @@ static void vectorSerializeMeta(const 
Vector *pVector, size_t nDataSize, unsigne assert( nDataSize % 2 == 0 ); assert( nBlobSize == nDataSize + 1 ); pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - }else if( pVector->type == VECTOR_TYPE_1BIT ){ + }else if( pVector->type == VECTOR_TYPE_FLOAT1BIT ){ assert( nBlobSize % 2 == 1 ); assert( nBlobSize >= 3 ); - pBlob[nBlobSize - 1] = VECTOR_TYPE_1BIT; + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT1BIT; pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; }else{ assert( 0 ); @@ -211475,7 +211475,7 @@ void vectorSerializeWithMeta( case VECTOR_TYPE_FLOAT64: vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitSerializeToBlob(pVector, pBlob, nDataSize); break; default: @@ -211491,7 +211491,7 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); @@ -211520,7 +211520,7 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){ for(i = 0; i < pFrom->dims; i++){ dstF64[i] = src[i]; } - }else if( pTo->type == VECTOR_TYPE_1BIT ){ + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ dst1Bit[i / 8] = 0; @@ -211552,7 +211552,7 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){ for(i = 0; i < pFrom->dims; i++){ dstF32[i] = src[i]; } - }else if( pTo->type == VECTOR_TYPE_1BIT ){ + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ dst1Bit[i / 8] = 0; @@ -211576,7 +211576,7 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ assert( pFrom->dims == pTo->dims ); assert( pFrom->type != pTo->type ); - assert( pFrom->type == VECTOR_TYPE_1BIT ); 
+ assert( pFrom->type == VECTOR_TYPE_FLOAT1BIT ); src = pFrom->data; if( pTo->type == VECTOR_TYPE_FLOAT32 ){ @@ -211614,7 +211614,7 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){ vectorConvertFromF32(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ vectorConvertFromF64(pFrom, pTo); - }else if( pFrom->type == VECTOR_TYPE_1BIT ){ + }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ vectorConvertFrom1Bit(pFrom, pTo); }else{ assert( 0 ); @@ -211697,7 +211697,7 @@ static void vector1BitFunc( int argc, sqlite3_value **argv ){ - vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_1BIT); + vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT1BIT); } /* @@ -211884,7 +211884,7 @@ void vector1BitDump(const Vector *pVec){ u8 *elems = pVec->data; unsigned i; - assert( pVec->type == VECTOR_TYPE_1BIT ); + assert( pVec->type == VECTOR_TYPE_FLOAT1BIT ); printf("f1bit: ["); for(i = 0; i < pVec->dims; i++){ @@ -211906,7 +211906,7 @@ size_t vector1BitSerializeToBlob( u8 *pPtr = pBlob; unsigned i; - assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( pVector->dims <= MAX_VECTOR_SZ ); assert( nBlobSize >= (pVector->dims + 7) / 8 ); @@ -211953,8 +211953,8 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ int i, len8, len32, offset8; assert( v1->dims == v2->dims ); - assert( v1->type == VECTOR_TYPE_1BIT ); - assert( v2->type == VECTOR_TYPE_1BIT ); + assert( v1->type == VECTOR_TYPE_FLOAT1BIT ); + assert( v2->type == VECTOR_TYPE_FLOAT1BIT ); len8 = (v1->dims + 7) / 8; len32 = v1->dims / 32; @@ -211976,7 +211976,7 @@ void vector1BitDeserializeFromBlob( ){ u8 *elems = pVector->data; - assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); assert( nBlobSize >= (pVector->dims + 7) / 8 ); @@ -212494,7 +212494,7 @@ int diskAnnCreateIndex( } } neighbours = vectorIdxParamsGetU64(pParams, 
VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); - if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){ + if( neighbours == VECTOR_TYPE_FLOAT1BIT && metric != VECTOR_METRIC_TYPE_COS ){ *pzErrMsg = "1-bit compression available only for cosine metric"; return SQLITE_ERROR; } @@ -213738,7 +213738,7 @@ int diskAnnOpenIndex( if( compressNeighbours == 0 ){ pIndex->nEdgeVectorType = pIndex->nNodeVectorType; pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; - }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + }else if( compressNeighbours == VECTOR_TYPE_FLOAT1BIT ){ pIndex->nEdgeVectorType = compressNeighbours; pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); }else{ @@ -214541,8 +214541,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { { "F32_BLOB", VECTOR_TYPE_FLOAT32 }, { "FLOAT64", VECTOR_TYPE_FLOAT64 }, { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, - { "FLOAT1BIT", VECTOR_TYPE_1BIT }, - { "F1BIT_BLOB", VECTOR_TYPE_1BIT }, + { "FLOAT1BIT", VECTOR_TYPE_FLOAT1BIT }, + { "F1BIT_BLOB", VECTOR_TYPE_FLOAT1BIT }, }; /* @@ -214558,10 +214558,11 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 }, { "alpha", 
VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, @@ -215121,7 +215122,7 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_1BIT ); + assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT1BIT ); pVector = vectorAlloc(type, dims); if( pVector == NULL ){ diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index ad032de20e..8895733cb8 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -85261,9 +85261,9 @@ typedef u32 VectorDims; /* * Enumerate of supported vector types (0 omitted intentionally as we can use zero as "undefined" value) */ -#define VECTOR_TYPE_FLOAT32 1 -#define VECTOR_TYPE_FLOAT64 2 -#define VECTOR_TYPE_1BIT 3 +#define VECTOR_TYPE_FLOAT32 1 +#define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_FLOAT1BIT 3 #define VECTOR_FLAGS_STATIC 1 @@ -210999,7 +210999,7 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return (dims + 7) / 8; default: assert(0); @@ -211072,7 +211072,7 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceCos(pVector1, pVector2); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return vector1BitDistanceHamming(pVector1, pVector2); default: assert(0); @@ -211236,7 +211236,7 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; - }else if( *pType == VECTOR_TYPE_1BIT ){ + }else if( *pType == VECTOR_TYPE_FLOAT1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ *pzErrMsg = 
sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; @@ -211286,7 +211286,7 @@ int vectorParseSqliteBlobWithType( case VECTOR_TYPE_FLOAT64: vectorF64DeserializeFromBlob(pVector, pBlob, nDataSize); return 0; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; default: @@ -211384,7 +211384,7 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT64: vectorF64Dump(pVector); break; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitDump(pVector); break; default: @@ -211415,7 +211415,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){ return 0; }else if( type == VECTOR_TYPE_FLOAT64 ){ return 1; - }else if( type == VECTOR_TYPE_1BIT ){ + }else if( type == VECTOR_TYPE_FLOAT1BIT ){ nDataSize = vectorDataSize(type, dims); nMetaSize++; // one byte which specify amount of leftover bits if( nDataSize % 2 == 0 ){ @@ -211435,10 +211435,10 @@ static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigne assert( nDataSize % 2 == 0 ); assert( nBlobSize == nDataSize + 1 ); pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - }else if( pVector->type == VECTOR_TYPE_1BIT ){ + }else if( pVector->type == VECTOR_TYPE_FLOAT1BIT ){ assert( nBlobSize % 2 == 1 ); assert( nBlobSize >= 3 ); - pBlob[nBlobSize - 1] = VECTOR_TYPE_1BIT; + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT1BIT; pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; }else{ assert( 0 ); @@ -211475,7 +211475,7 @@ void vectorSerializeWithMeta( case VECTOR_TYPE_FLOAT64: vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: vector1BitSerializeToBlob(pVector, pBlob, nDataSize); break; default: @@ -211491,7 +211491,7 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, 
nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); - case VECTOR_TYPE_1BIT: + case VECTOR_TYPE_FLOAT1BIT: return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); @@ -211520,7 +211520,7 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){ for(i = 0; i < pFrom->dims; i++){ dstF64[i] = src[i]; } - }else if( pTo->type == VECTOR_TYPE_1BIT ){ + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ dst1Bit[i / 8] = 0; @@ -211552,7 +211552,7 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){ for(i = 0; i < pFrom->dims; i++){ dstF32[i] = src[i]; } - }else if( pTo->type == VECTOR_TYPE_1BIT ){ + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ dst1Bit = pTo->data; for(i = 0; i < pFrom->dims; i += 8){ dst1Bit[i / 8] = 0; @@ -211576,7 +211576,7 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ assert( pFrom->dims == pTo->dims ); assert( pFrom->type != pTo->type ); - assert( pFrom->type == VECTOR_TYPE_1BIT ); + assert( pFrom->type == VECTOR_TYPE_FLOAT1BIT ); src = pFrom->data; if( pTo->type == VECTOR_TYPE_FLOAT32 ){ @@ -211614,7 +211614,7 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){ vectorConvertFromF32(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ vectorConvertFromF64(pFrom, pTo); - }else if( pFrom->type == VECTOR_TYPE_1BIT ){ + }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ vectorConvertFrom1Bit(pFrom, pTo); }else{ assert( 0 ); @@ -211697,7 +211697,7 @@ static void vector1BitFunc( int argc, sqlite3_value **argv ){ - vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_1BIT); + vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT1BIT); } /* @@ -211884,7 +211884,7 @@ void vector1BitDump(const Vector *pVec){ u8 *elems = pVec->data; unsigned i; - assert( pVec->type == VECTOR_TYPE_1BIT ); + assert( pVec->type == VECTOR_TYPE_FLOAT1BIT ); printf("f1bit: ["); for(i = 
0; i < pVec->dims; i++){ @@ -211906,7 +211906,7 @@ size_t vector1BitSerializeToBlob( u8 *pPtr = pBlob; unsigned i; - assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( pVector->dims <= MAX_VECTOR_SZ ); assert( nBlobSize >= (pVector->dims + 7) / 8 ); @@ -211953,8 +211953,8 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ int i, len8, len32, offset8; assert( v1->dims == v2->dims ); - assert( v1->type == VECTOR_TYPE_1BIT ); - assert( v2->type == VECTOR_TYPE_1BIT ); + assert( v1->type == VECTOR_TYPE_FLOAT1BIT ); + assert( v2->type == VECTOR_TYPE_FLOAT1BIT ); len8 = (v1->dims + 7) / 8; len32 = v1->dims / 32; @@ -211976,7 +211976,7 @@ void vector1BitDeserializeFromBlob( ){ u8 *elems = pVector->data; - assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); assert( nBlobSize >= (pVector->dims + 7) / 8 ); @@ -212494,7 +212494,7 @@ int diskAnnCreateIndex( } } neighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); - if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){ + if( neighbours == VECTOR_TYPE_FLOAT1BIT && metric != VECTOR_METRIC_TYPE_COS ){ *pzErrMsg = "1-bit compression available only for cosine metric"; return SQLITE_ERROR; } @@ -213738,7 +213738,7 @@ int diskAnnOpenIndex( if( compressNeighbours == 0 ){ pIndex->nEdgeVectorType = pIndex->nNodeVectorType; pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; - }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + }else if( compressNeighbours == VECTOR_TYPE_FLOAT1BIT ){ pIndex->nEdgeVectorType = compressNeighbours; pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); }else{ @@ -214541,8 +214541,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { { "F32_BLOB", VECTOR_TYPE_FLOAT32 }, { "FLOAT64", VECTOR_TYPE_FLOAT64 }, { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, - { 
"FLOAT1BIT", VECTOR_TYPE_1BIT }, - { "F1BIT_BLOB", VECTOR_TYPE_1BIT }, + { "FLOAT1BIT", VECTOR_TYPE_FLOAT1BIT }, + { "F1BIT_BLOB", VECTOR_TYPE_FLOAT1BIT }, }; /* @@ -214558,10 +214558,11 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 }, { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, @@ -215121,7 +215122,7 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_1BIT ); + assert( type == VECTOR_TYPE_FLOAT32 || type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT1BIT ); pVector = vectorAlloc(type, dims); if( pVector == NULL ){ From 7a0fa8c19fd4d0349cf970efb09773f93bd0b021 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 11:33:58 +0400 Subject: [PATCH 089/121] rename vector1bit.c -> vectorfloat1bit.c --- .../SQLite3MultipleCiphers/src/sqlite3.c | 289 +++++++++--------- libsql-ffi/bundled/src/sqlite3.c | 289 +++++++++--------- libsql-sqlite3/Makefile.in | 8 +- libsql-sqlite3/src/vectorIndex.c | 1 - .../src/{vector1bit.c => vectorfloat1bit.c} | 0 
libsql-sqlite3/tool/mksqlite3c.tcl | 2 +- libsql-sqlite3/tool/showwal | Bin 0 -> 54760 bytes 7 files changed, 293 insertions(+), 296 deletions(-) rename libsql-sqlite3/src/{vector1bit.c => vectorfloat1bit.c} (100%) create mode 100755 libsql-sqlite3/tool/showwal diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 8895733cb8..477758b4cc 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -211842,150 +211842,6 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vector.c **********************************************/ -/************** Begin file vector1bit.c **************************************/ -/* -** 2024-07-04 -** -** Copyright 2024 the libSQL authors -** -** Permission is hereby granted, free of charge, to any person obtaining a copy of -** this software and associated documentation files (the "Software"), to deal in -** the Software without restriction, including without limitation the rights to -** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -** the Software, and to permit persons to whom the Software is furnished to do so, -** subject to the following conditions: -** -** The above copyright notice and this permission notice shall be included in all -** copies or substantial portions of the Software. -** -** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-** -****************************************************************************** -** -** 1-bit vector format utilities. -*/ -#ifndef SQLITE_OMIT_VECTOR -/* #include "sqliteInt.h" */ - -/* #include "vectorInt.h" */ - -/* #include */ - -/************************************************************************** -** Utility routines for debugging -**************************************************************************/ - -void vector1BitDump(const Vector *pVec){ - u8 *elems = pVec->data; - unsigned i; - - assert( pVec->type == VECTOR_TYPE_FLOAT1BIT ); - - printf("f1bit: ["); - for(i = 0; i < pVec->dims; i++){ - printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); - } - printf("]\n"); -} - -/************************************************************************** -** Utility routines for vector serialization and deserialization -**************************************************************************/ - -size_t vector1BitSerializeToBlob( - const Vector *pVector, - unsigned char *pBlob, - size_t nBlobSize -){ - u8 *elems = pVector->data; - u8 *pPtr = pBlob; - unsigned i; - - assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); - - for(i = 0; i < (pVector->dims + 7) / 8; i++){ - pPtr[i] = elems[i]; - } - return (pVector->dims + 7) / 8; -} - -// [sum(map(int, bin(i)[2:])) for i in range(256)] -static int BitsCount[256] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 
5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, -}; - -static inline int sqlite3PopCount32(u32 a){ -#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) - return __builtin_popcount(a); -#else - return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; -#endif -} - -int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ - int diff = 0; - u8 *e1U8 = v1->data; - u32 *e1U32 = v1->data; - u8 *e2U8 = v2->data; - u32 *e2U32 = v2->data; - int i, len8, len32, offset8; - - assert( v1->dims == v2->dims ); - assert( v1->type == VECTOR_TYPE_FLOAT1BIT ); - assert( v2->type == VECTOR_TYPE_FLOAT1BIT ); - - len8 = (v1->dims + 7) / 8; - len32 = v1->dims / 32; - offset8 = len32 * 4; - - for(i = 0; i < len32; i++){ - diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); - } - for(i = offset8; i < len8; i++){ - diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); - } - return diff; -} - -void vector1BitDeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - u8 *elems = pVector->data; - - assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); - assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); - - memcpy(elems, pBlob, (pVector->dims + 7) / 8); -} - -#endif /* !defined(SQLITE_OMIT_VECTOR) */ - -/************** End of vector1bit.c ******************************************/ /************** Begin file vectordiskann.c ***********************************/ /* ** 2024-03-23 @@ -213765,6 +213621,150 @@ void diskAnnCloseIndex(DiskAnnIndex *pIndex){ #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vectordiskann.c ***************************************/ +/************** Begin file vectorfloat1bit.c 
*********************************/ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +** +****************************************************************************** +** +** 1-bit vector format utilities. +*/ +#ifndef SQLITE_OMIT_VECTOR +/* #include "sqliteInt.h" */ + +/* #include "vectorInt.h" */ + +/* #include */ + +/************************************************************************** +** Utility routines for debugging +**************************************************************************/ + +void vector1BitDump(const Vector *pVec){ + u8 *elems = pVec->data; + unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT1BIT ); + + printf("f1bit: ["); + for(i = 0; i < pVec->dims; i++){ + printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? 
+1 : -1); + } + printf("]\n"); +} + +/************************************************************************** +** Utility routines for vector serialization and deserialization +**************************************************************************/ + +size_t vector1BitSerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + u8 *pPtr = pBlob; + unsigned i; + + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + for(i = 0; i < (pVector->dims + 7) / 8; i++){ + pPtr[i] = elems[i]; + } + return (pVector->dims + 7) / 8; +} + +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +static inline int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + +int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ + int diff = 0; + u8 *e1U8 = v1->data; + u32 *e1U32 = v1->data; + u8 *e2U8 = v2->data; + 
u32 *e2U32 = v2->data; + int i, len8, len32, offset8; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_FLOAT1BIT ); + assert( v2->type == VECTOR_TYPE_FLOAT1BIT ); + + len8 = (v1->dims + 7) / 8; + len32 = v1->dims / 32; + offset8 = len32 * 4; + + for(i = 0; i < len32; i++){ + diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); + } + for(i = offset8; i < len8; i++){ + diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); + } + return diff; +} + +void vector1BitDeserializeFromBlob( + Vector *pVector, + const unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); + assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + memcpy(elems, pBlob, (pVector->dims + 7) / 8); +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ + +/************** End of vectorfloat1bit.c *************************************/ /************** Begin file vectorfloat32.c ***********************************/ /* ** 2024-07-04 @@ -214183,7 +214183,6 @@ void vectorF64DeserializeFromBlob( ** ** libSQL vector search. 
*/ -/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "sqlite3.h" */ /* #include "vdbeInt.h" */ diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 8895733cb8..477758b4cc 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -211842,150 +211842,6 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vector.c **********************************************/ -/************** Begin file vector1bit.c **************************************/ -/* -** 2024-07-04 -** -** Copyright 2024 the libSQL authors -** -** Permission is hereby granted, free of charge, to any person obtaining a copy of -** this software and associated documentation files (the "Software"), to deal in -** the Software without restriction, including without limitation the rights to -** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -** the Software, and to permit persons to whom the Software is furnished to do so, -** subject to the following conditions: -** -** The above copyright notice and this permission notice shall be included in all -** copies or substantial portions of the Software. -** -** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -** -****************************************************************************** -** -** 1-bit vector format utilities. 
-*/ -#ifndef SQLITE_OMIT_VECTOR -/* #include "sqliteInt.h" */ - -/* #include "vectorInt.h" */ - -/* #include */ - -/************************************************************************** -** Utility routines for debugging -**************************************************************************/ - -void vector1BitDump(const Vector *pVec){ - u8 *elems = pVec->data; - unsigned i; - - assert( pVec->type == VECTOR_TYPE_FLOAT1BIT ); - - printf("f1bit: ["); - for(i = 0; i < pVec->dims; i++){ - printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); - } - printf("]\n"); -} - -/************************************************************************** -** Utility routines for vector serialization and deserialization -**************************************************************************/ - -size_t vector1BitSerializeToBlob( - const Vector *pVector, - unsigned char *pBlob, - size_t nBlobSize -){ - u8 *elems = pVector->data; - u8 *pPtr = pBlob; - unsigned i; - - assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); - - for(i = 0; i < (pVector->dims + 7) / 8; i++){ - pPtr[i] = elems[i]; - } - return (pVector->dims + 7) / 8; -} - -// [sum(map(int, bin(i)[2:])) for i in range(256)] -static int BitsCount[256] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 
4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, -}; - -static inline int sqlite3PopCount32(u32 a){ -#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) - return __builtin_popcount(a); -#else - return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; -#endif -} - -int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ - int diff = 0; - u8 *e1U8 = v1->data; - u32 *e1U32 = v1->data; - u8 *e2U8 = v2->data; - u32 *e2U32 = v2->data; - int i, len8, len32, offset8; - - assert( v1->dims == v2->dims ); - assert( v1->type == VECTOR_TYPE_FLOAT1BIT ); - assert( v2->type == VECTOR_TYPE_FLOAT1BIT ); - - len8 = (v1->dims + 7) / 8; - len32 = v1->dims / 32; - offset8 = len32 * 4; - - for(i = 0; i < len32; i++){ - diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); - } - for(i = offset8; i < len8; i++){ - diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); - } - return diff; -} - -void vector1BitDeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - u8 *elems = pVector->data; - - assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); - assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); - - memcpy(elems, pBlob, (pVector->dims + 7) / 8); -} - -#endif /* !defined(SQLITE_OMIT_VECTOR) */ - -/************** End of vector1bit.c ******************************************/ /************** Begin file vectordiskann.c ***********************************/ /* ** 2024-03-23 @@ -213765,6 +213621,150 @@ void diskAnnCloseIndex(DiskAnnIndex *pIndex){ #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vectordiskann.c ***************************************/ +/************** Begin file vectorfloat1bit.c *********************************/ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of 
charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +** +****************************************************************************** +** +** 1-bit vector format utilities. +*/ +#ifndef SQLITE_OMIT_VECTOR +/* #include "sqliteInt.h" */ + +/* #include "vectorInt.h" */ + +/* #include */ + +/************************************************************************** +** Utility routines for debugging +**************************************************************************/ + +void vector1BitDump(const Vector *pVec){ + u8 *elems = pVec->data; + unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT1BIT ); + + printf("f1bit: ["); + for(i = 0; i < pVec->dims; i++){ + printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? 
+1 : -1); + } + printf("]\n"); +} + +/************************************************************************** +** Utility routines for vector serialization and deserialization +**************************************************************************/ + +size_t vector1BitSerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + u8 *pPtr = pBlob; + unsigned i; + + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + for(i = 0; i < (pVector->dims + 7) / 8; i++){ + pPtr[i] = elems[i]; + } + return (pVector->dims + 7) / 8; +} + +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +static inline int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + +int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ + int diff = 0; + u8 *e1U8 = v1->data; + u32 *e1U32 = v1->data; + u8 *e2U8 = v2->data; + 
u32 *e2U32 = v2->data; + int i, len8, len32, offset8; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_FLOAT1BIT ); + assert( v2->type == VECTOR_TYPE_FLOAT1BIT ); + + len8 = (v1->dims + 7) / 8; + len32 = v1->dims / 32; + offset8 = len32 * 4; + + for(i = 0; i < len32; i++){ + diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); + } + for(i = offset8; i < len8; i++){ + diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); + } + return diff; +} + +void vector1BitDeserializeFromBlob( + Vector *pVector, + const unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + + assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); + assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + memcpy(elems, pBlob, (pVector->dims + 7) / 8); +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ + +/************** End of vectorfloat1bit.c *************************************/ /************** Begin file vectorfloat32.c ***********************************/ /* ** 2024-07-04 @@ -214183,7 +214183,6 @@ void vectorF64DeserializeFromBlob( ** ** libSQL vector search. 
*/ -/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "sqlite3.h" */ /* #include "vdbeInt.h" */ diff --git a/libsql-sqlite3/Makefile.in b/libsql-sqlite3/Makefile.in index 0afadd458f..db1e2c55ce 100644 --- a/libsql-sqlite3/Makefile.in +++ b/libsql-sqlite3/Makefile.in @@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \ sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \ table.lo threads.lo tokenize.lo treeview.lo trigger.lo \ update.lo userauth.lo upsert.lo util.lo vacuum.lo \ - vector.lo vectorfloat32.lo vectorfloat64.lo vector1bit.lo \ + vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo \ vectorIndex.lo vectordiskann.lo vectorvtab.lo \ vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \ vdbetrace.lo vdbevtab.lo \ @@ -302,8 +302,8 @@ SRC = \ $(TOP)/src/util.c \ $(TOP)/src/vacuum.c \ $(TOP)/src/vector.c \ - $(TOP)/src/vector1bit.c \ $(TOP)/src/vectorInt.h \ + $(TOP)/src/vectorfloat1bit.c \ $(TOP)/src/vectorfloat32.c \ $(TOP)/src/vectorfloat64.c \ $(TOP)/src/vectorIndexInt.h \ @@ -1139,8 +1139,8 @@ vacuum.lo: $(TOP)/src/vacuum.c $(HDR) vector.lo: $(TOP)/src/vector.c $(HDR) $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector.c -vector1bit.lo: $(TOP)/src/vector1bit.c $(HDR) - $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector1bit.c +vectorfloat1bit.lo: $(TOP)/src/vectorfloat1bit.c $(HDR) + $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat1bit.c vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR) $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index a9801c4e60..a253d0d12b 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -24,7 +24,6 @@ ** ** libSQL vector search. 
*/ -#include "vectorInt.h" #ifndef SQLITE_OMIT_VECTOR #include "sqlite3.h" #include "vdbeInt.h" diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vectorfloat1bit.c similarity index 100% rename from libsql-sqlite3/src/vector1bit.c rename to libsql-sqlite3/src/vectorfloat1bit.c diff --git a/libsql-sqlite3/tool/mksqlite3c.tcl b/libsql-sqlite3/tool/mksqlite3c.tcl index 3a04459e31..fe76abc0a8 100644 --- a/libsql-sqlite3/tool/mksqlite3c.tcl +++ b/libsql-sqlite3/tool/mksqlite3c.tcl @@ -468,8 +468,8 @@ set flist { json.c vector.c - vector1bit.c vectordiskann.c + vectorfloat1bit.c vectorfloat32.c vectorfloat64.c vectorIndex.c diff --git a/libsql-sqlite3/tool/showwal b/libsql-sqlite3/tool/showwal new file mode 100755 index 0000000000000000000000000000000000000000..0687a4159570659c43bc62bd066c82c673c18e26 GIT binary patch literal 54760 zcmeFad3aPs)<0f%?d~ky2}vMqrXdN0EhG>ifdq&I>4cyli$TN?Ivo-TNlZEzFaac> z_Jk<7`>NoKGcNP0;4%!NGDLRA6_;@z6x5EOATA(`=Jz>u>)v#teBbx^KEK}|zvuDh zxqa)LI_K1>b55PATYb}a_0+;kWJzM?p|FXJpr~X6Gg(xmydny?OlFxZ5}$+FAf^LW zInJn`Y+O)l9Pi0puXce~a3#KYD*VQ^pMhrP)zpx##CM_No^Yeays{_*CO(ad0`7jz z6%CnJv)*bi&5*;oU5ji6-n@qLk!|90jQaQ)U9U{^J5%tRd}J=-OA&l2g3r935{59Z zraj3fbY_}zyIs4f&ARiI3;OQYJjkTOyiyx<%>LY8`hU~jOu-j1#*owfdX<4@=GAO( z7Wl|6|J%cK^W~zw3;7`^;?uke^RkhZ6^qAajI1mjQCU${-!P&fbL@z*8L4&EsiU|G zq?a2f#niM*=Q4=dO*mmjCa*DIQy$gDih82I{M%=;hqu0q2HUn*Oq0SZZT1uH}qo_fF7X;Q66#u zdnkZ`%StLMtG%eeNc9%i`Acg3 z#mh=6s(>jAp)ISa_j_5Lzt+2~hA#fc%~z!q9T%^{tJ`|_spidkfwBYO2Gc!6%NXp=LHCOf zql4CQDg@ut!ZLO?1m8CVZ;lI^Hz;SDhZ+Bb4^7%>LJlnxgb6rHW(c0<4fB{Bf+tz# z;R(U>c|%mp48faA5vnQ*!JD=t{Gt#%9Ah2jA^0#0%UDeaUM$UcZ9@n?B7}ZT2;Lcj zzc~c&3c+s-!BhO0$2}qVXcLL+gCY3P{QGDKJ|={ITL_-!Z}Zp@f{!zixb62&+G7EBY;>Gl!}ritiu`z#Bl3Fvftt_9P?bGkj< zf@#7z-9EsAX+l5U9&N!i@tkg#EZD{2Z+|iSOA`(HZ^1Odp#K(36ASup!8Djp3E2G8vV&*%nE?FQ#` zgU57(M|6XeyTSdt!ExQ-$ZoLK4gNX1tG#{*g>$dSotHZ|ch20|o*PaSAy@;ANi#hG ze^Oehr$tRpfj2yXbx9u7r0;I-kX*Ms3XghP@{-chfAq98*wc3h^AHRH?GuO%>|4Lv 
zb@Mz>r0@2$KAW@#?17VI5I}3x-~teOTGl0PBdN`26I{1+pp<@i7X@wLWN-nz4!qr8 zz#H(ijwx|7_756Do|c)^)FDvpFyK45Ql*|@v0O{n`mvHmAf$a>fGWz zqy{y6y906G;Vv>b9>QAUf+)6L+_C=AM<8PcGtQmPyytyT9W^R0US6G z$V=+y3GB%Wd=Y#IcLSc5%A^|4`mbH<{y|KwW9CpxXysurQ7i8P?b<$-Wv9B<`ie7?en_2WjA{q)Z}T3zC*ONtsz6xbEMboAtKq{+H};xYoaee762W*DbH& z+I(CpY|V)!X73&k^hqh>dBo$5OCAsI!lZt?;A*IrGXQYw#qpkdJg%*v^R$+`3tJ1L zGSheG2ZEl~xTh|H2^!ISi}qC(SEjWlKDYqMt*})KpeN~vFgyZijoV0>ZkX`P=^mg-7aM`SFor8Ftt1E)oCvm<4+a)<0aJ^=HpC-Bowc$pOjI`W(MNt`M8<++ZI!mQ)|v?;CXa0ple z0a^PR-^z7upTUG&qK4Bh!p&yzfI(mXMj-;wetZbzf)sJwqI&}WqBx$sBNdIf^3D=l z@XSw!{DFr2F*Yv0>7V~_gNReSV^Gw8;%OPvJ`RJJrdUt&o~YpSsA)#mqtKCf4@&l^ ztbMLqRX*w)5fDQSf_@SR`Vm2jG$Ai3({)P)Mi;@8B{IU~th}U|u3NnnZ#m1+L~s@u zks5wPXFBbPE-rwXJpov>{W{7At&x{=1bow8#u0J(7R6kiF_6e9q+4<>gyq58k6arvP{<3Ef`1|S*EC6~uJyf9^0YR|=|3WgJPbA@4W2+U zgXI`v?XkS2Ytd5kdC9fzkFYl)j3yBl*xx$417VENvwK?mkHLMW$9tHMg@H4{QjBsSC_@88mdFRC5&ikyt2|Ui zgCv-WD)j6UCH)Y%P@1QLbt=`p82Uc*vsQR z1u!@8UfU5ZxB18!6ML{Q>s{A|C&8Ax<}Az{220HdyhDBGnHc7sgQ@y4_-cBqI^?p} zyo8SQ4?ThX+z^HpJZ(p9h7CN@KRw=qUCi(iHv>rQ&yg8C(w_X*;k`XW{yC#H+u@TEsJ6VW2W;krA7_A*Qfbho|ibntHAIi`5?K)Vf2izzbmT zG`|FspL5iI2%^jEps{I_AYJLxCv9JYZqj_eL7!pYsZ-|w61c_S=e2mHO)2aVs2BV?C`^|f#;86&r?C+5g zZ}II1C>=JRwYk>q1wSp2UF+U3N)Fe$!$v9GbxS4EN@45d@1gV_+*akkd3Hu%|D%%M z2E>a%P%lIW(3RlVXr)m4p?wV4%5o1(mWHzSAt+Oh47}Ms2OWSNXGt(%esh~@aJJ@L z!nKVY4}sDfMDPpdW6&4?%v6|^2{5l!joKN+P(@44)KZ`$xZpH0{CSOrMrM>GYPF`f?w+PZ|YIiimkp$vac$8h#|+axq0hBSJ= z8M6i!jF$iA&`>UuTsYy)||SNf#zW zt!umN+Fl@mmkUDlT}+`U9RyaU);WOo2cvzIbEVy|M!jfmhS?lEovE75@yX5O#e`be z(ukGOP#U4Fk0-fl7)r+ry_K}4`e1r%#7Ixz^(CFKA zKZG@J?*2TjV+JKO#z#&iIn}XK76!hu9cYki5iDihH zx>~RMICqCJb7+{(yJV}@GSV*aM)OJ4GvuJh+d=mlh{fSy`B^^{x-M(ynQT4X<+mqL zTpgH7UsP3kn)lgr_fJ(>`!*ODK{yr95VZfWLHkEA3RZ0QkKk)%;aPO$DC`pa0^+zy zcz+99aho+Y{TSv-A8KjF8iVUq-O#0TE_f)bING&+VJueS zDw1(-+ZXY)(spfcvb7zJ&z1Irphw%M@ok?0&y%=;i!m;7;F0!#rY8*}D)i!?8sXaj z>V&6)l-C-O1%oVwR#f>CaPsWjc{CkF5)mOHa#7WGBtFlWCY%7V(BM;s3)h(-u@i#> zG(;-u!b0f63D>HQDh!GTjwN6d0BC7BXK9vI-ACg 
zj2VZosv+u65XCisyp4#D!j{R{rKI(G5X^@6ccJ8Iae~@!QqwXsZWqEn*k+O4jZ*ON z38Eb{421mdB`5*gzL+Kk>u#-#?`NieogaAfYER%ZPxG-;Gv`cA-<^I4f&3KcJK`Qf zviiQ01Nb+A@w93$(SCtH-(SM8F*m^qNYmo4ZgkdumH>B-^RT0=X?K)ZCV9?P#|px+Pbg zy8c7|EVL1LtL-?&9md6bD1T+g^6ZAs7#58;9L3ExBlMm;>7`t(AUg^7v;~PUxeFbl zR66ZxZnI1OlFsG^R;bec+`zT=+^m=DH!RrG>7PAion=?TWR`mcvP3l9n9a-O(Zx3g z13%E#;G^7k0f!$rj&D1Rntr4W|rb3WyOnBj_D-3*vIg+3pqHNx4FAN+J+oO9t zEsCsvI+VnCn7hjjy4jw<`E&?E+e@J(CIz}>Sob<66tTaw zezHqFS&vMleL}D+58J81Z-bE3nsWk}U^2>F678U! z#QL}rA;~wxf>&ZaB4$aW>Hrvw4YWz!t9DQo+{}VW-K*%%bQjVl3p!AR7Ip$5ls{=~ zyU~p--@iodL%*uUN}UB?Mdwhv;a9cP``}iib_Ht1hTlov!Y*w2!M>>)LFahOF-8T`2L`?htOSQ(j^#l(d1iLWx>Msx4s~s8GFuo zftO(O6)*(M{D&`%5HQFd=VX&H`4h=lkZ@0dZZYVN-voy&z}z0Z19adJK0%j1M2Hv; zIYirZSXdjIny4BJ2E!pwcCQ*tRoo%BcdtsID(;X5ql&i$aRX2h;*cx5)`nTN+#%Ur zYrjKIHh8&1F6vr)#H!^EVN`1w@V3t}@;0vTSfyp1SuO3EiHV>kXG9V*BW9%5xasIW z&jK*jbuf-GQf8pUzoX*c8A8?`c>4F_cMJS(f&YJ6fS&JjC#Q{W7?00U4b1IMYe-JZ zq=Gv+tCS_z;aV5QF01#utIOPYaU{2+k|jQO9X;g9 zWz7d?vc8YQI5iv`!o(4E->I7I94^^qe)qA!LYx!Z_gmWj=Ru2D*?uwv?Yn;@}}0 z8oma-d#P{5cz0z(PNlERKY|B{yDF!qWJP6lN$Ch0H167*+Un~nN)3AM(vsRG^%P?v zZ?pBLLJq|e%R4L_h$Z&Iy#=kbses+;QIE*9UaGT8=zw6%n$f`T#w;; z3|o#W@awP*Jd^6Nld%EU@39s42ClStqeEJGIg=Wqq`nb0`$mDM#o5i+(HcYqc{cw2 z64qfX1Lv9#Iyx4Cjya<)amE+8!mqQhVUzldA38c|Ah2BSRk-eezMPSeIH*0+ZyWHW zk9kmA)Gs^q{*;jNbZ1nvEdRl2pH6Ez56;JNzK8ldi0@)&)U9&9GyWFE<8-&G7duln zXt~a`8}%v9%x2p(=MhOh9p=mgD%Y8EvD1CAGaf7#JMDSlnP5MQb2v7GSKvyrZjnLI zs!VmdH>hB^Q7dq!)rC1z@$YbSt+@Vb1tG zXH>4!?g_urD(8nQ5t3QKpC@mCjDGm~3SV0>R^W_UgbVp42)iUWCQ#C)O=fa4zQwb3Ci?1-p07iP;mxvk0tQqL3)}3hl$@N6gKrfWmGdk#?1EVn_D_ey7L`I|Ii#q;oBeH-^qdi zyRIBI?T+OLhjmepN47X@_QR0=|K6wh|JL0lrU`#f7T1~Lx=37W#Fd_Pq+_GFJ}9o+ z#C5m0z9Fth#PyW8;-4b<5hbo}aZMA~$>KUwTo;LJjkp@z|NBq#|Kecx4yNM&kEcu- z?@pP!7&mqG?)3E3jMTId>Gd2;Uo~n}YFfq+fo9ogg;r>mbNM=!vC$$%kI*s;2VZU( z#V(&4#?n6oO5CxNE=Qo_;Z=Mii1RHd5_&|Q2SR_d7Z_Nq-S!6rN33eQ6_?3&dRU&t zO7=czb?9^|%eK>0zJlZRekjMv+9pWIRW)n`K!SE9!YA@teAqViCZUlLXx?^P?;I*$ z3{uA-e(U(WjiTCm;IyMcIDV0_r_?aRyMgnzIPFaJm~ 
zxAy1d-bC|kKVFup-Iomdhe@A(0|@U-M!N}psh^UbNWCFM@+nZiXHWk-6w#jrnSHeF zT3nb;D`opw2Qk?7>(Q2d90#KGM5@W;K)iklTCitvz^zY)So?Smr0DBN+5`@y=^LP+ zJ(~lW`Yfu+;lO0wO9Ccyz@tA*t>toHrcVDiZNHcUMfyv`lh>1;<5{GSAZeGxz5}3K z-%n&7u62z*m1^=i;Mb3#Z}x)l17K~?7gEiP9^}F`dQa-lr7kakwL0Y|`(^eIAgx)S zN<6b%UjevT4-zmx@+}Z-)UTv=7WDcaHTUSRlYpyw{{rAa{XP<~kgM~k{u%Y>YDXS$ z+w>m^C=MG6V26$`l9_!G*KN1nOzJGQD1i<3RKC1YEh#Ox{*SJIj=^wb2!8$&@=GKPeMjic=#tM z>(jv+9u@I#P%!;%kc7v@xp1-5I9-nd2HRX>T7&Y&h(7edp#BlWMkGY8=1hZ78<7NLL85!9Epr1aDYSJR_2H@6Dl7LYT8VE^x4*DLE z!9}O&^dHC(S&?MSH2pDrM&v{v$0bw0AKFAriaG;evL1^*L=?pR9dbPSeNZrBItPmM zy(FNJ1LgY9)Xw}~)Pn}rhdN1alZIXb_v;a06W+x~-tZi-#7B^34kdObE;y!^ zl`^TfKt>7V@J7`15lGJ*PRIt#hLFA-c}d0C3?}vOMgJ=hspCU(_nZQ?lc^jb4Y(F% zy%<)J2J(<)S#@dXE##k>8}adhz}q|?He%?G^9us(U6|JX=t<@GE~FH zfi87yJjhw{_-lbpOr(c^TxAmP)#GM6EHMdz<&ss7K4`I1Xt8Nmf+cd89Qh&PsgJC0 zKSDhfO}YF$VadL>H=!ZdhItmEqQ_;4Hb})50vKKamGgcselU=R=G@GQ5vX^~mZGRJ z57luy6J!YMyBa-!Ry~!N^8|SF&yzAWa^&@9B`b(PQNpxHqPox1(*1i98B3iA zu2|Vg6qIyn;hUO+9>mIhDbA+?o`7@k05lONyQn2))2XJxPogJ1G*$ya_-Vkz|n3D3hc)Qo1C~l2@a6xwh%%raQpudJz-B1R+gg zvw-@_u3;@Kzk&+di=ya@AiVl6Faba};S`*Wf>U8M6pq^`nVk24<}TtSJm;h*qaW$y zGrdeAfzC*e8gU4ZW3oW00X(~_up3JbH8 zU?A98Z8b?Due}WVgE$SZl?j$muQ`;~biv2C+;1$oW~ryB|4};xxrHQGfAFg@z#3vZ z1ji|$Bx$aMG{0eE}}u3%WW~&XHH42{X7z%4H!^W}(&)pjvV>hB z$G?@x*p+}T$7%l$0Q!F~<2Z+wy|SYb2stGQ&y#_K*1vK;0DnT^ew>ad0J04qF9X_( zGlKZk3}C(i^aW0JG5~stgPzJz3jq|PFpYqD01lwAkAOu0K0x7Z0xAK>(E2O^4FDzr z7==?^2cQ{+I-K^uAclP6#$RIb==#4TB?~(T^0TO+%>Ob@r9qI;3`>!PG&9GcypnG8 z6rrMR+Q75TjE#_<0x<&+tm-BJUxEKKoazGr;*i^;ajK62$VGt`MfSG<+|t*y+f;CW zFFDr6LJ1`qHUmi=PWcSTUO?el0%%NV$cgO)d=ufkuwBL127Q)zNuvDYyk66xEyDC6vBK7+PF`eiyJ|YOIjzP%}vP$Bc2;e>xZpVp!9*0SLnY${C7L=PMjh+jYb&S|uAbJ6(d=6yq zq3||NWq_dn$r#hd@HEoAqXVPMp?J{<9g4YPCGeq zn^0Y1uR@PB$u!cLru zTX0-BN}A9UD(~K1(D3m5`4>#iIVLQ)Zq<; zA-wGr;-uC^DD=Z=jJYcZ7~~Q=14m4kY`>UHBS16EWXeTh5;5^sYb++SWVU*#WSi7O zv^o>?mzn}fQCMsVXcPiuW|m9=GbP&^Q$QW)uQdhSgu;3fprnXC)18hnE1R9k5vvP@ 
z*)Ut(Eoo;>QFnp!4pY>VC_Ii+nIY)S(4fVku})ONf(p+d3#{1S$*m9M{zV!U(2DX5x=n%!{nVb6o%tePIS|aRvlsu zE0SBKrCN49vW?44ZDBwg7nU6hMV6z&199~K*y>sUyHMDHlPBCXx1eQo%ez(5=>NUA zPM?G1C{C^ueN%lNr?NoMoA&y(PW6^fidPVuZlZOKHbmqVLe4>B!hhOa(Obe-3jAuc z_iB(0rvrlsX)6%c;;vnVK=HmuFeP7W3+B4FnJdW?SdR=7R{n%q6;cKg*htSY}q}XsQPUH!w`3vO!ks2Ba;CU3D zBOnpLD<~Yq$&^n8cc_nzh0_*mgl~sjS+-S?TyoEGaJJ)Ab_$YE_sES13D{vWl$MuI zhBe6K@<23(1347J+9i$t<(ZpY=>aTGZt@-|L=iwHPex$?P9;t-3zPFZK-fE#+G&JP z-oj8p#2D?deI;pOy)eYdsa)~|2+A@g7ot#rlSPqk>4HvVsbWv@CK7}3knYqK%CoAcoL^_W!D}?Z#vs}hg!!5 z+TBTZ_n_4xz-FSuk$fqd4m8!E2eig5w%RsJhqgh?ZSM_7+5smGS-R1gBc-8rTAJaW3B!y6U3(;ixUmZ{-Kk5XOk#)fMY`F* zOCdbly6}wd#=}OLJf}iRqqUuL!dRgQq04!UJk-VSNdV8{G1 zT2LU@^Wz=B{fD7=jqIW=>pt*dJe{)zA?=ni`+q})AaV>HGFDaC{*WEC*pmN-nyFxz zgfoIF)HYz20;<5Nz5<{bg%vnm@lfjrj7{Sx1@VVUCQ`n92`c^uimf;u0|Ch6@$m_u zcX7(803u_tGlnyQ*wo8`Ne46xr@8>Zl_<<4pcKF!6xwj|?VybqDs;$f4LBcgybFKx zjfJ;CaTuqu1Mrw2D>mZR$OHUtK^iS2EfORW)(*I9(gq<4xTF)1@|7iNjUWr%qTrr} z&&=e?HG9#LM|(e894k$EKSS1eoHA`x#lh?`IFU34g@+KwBKQvslUKT>}4?i*>T=v%mmX|oLreg6bf)E!$Jg5aM|66YTDCD zv*dgLIge@FKnuV!AE#mJsa<(#vqdQ36u}`O$HFWjMJKH=(c|r;;Sd zLoMx4-T*xv+}w0;O&Mi55ZV|)Ps8IGh$h8Rq0JR|8fae%d>v{VVUtLG~|c&cnD8(7oKdZ9fN19#q)9q&)>W7 zOzy_B-Qo#`@Eq^Llh=)>&E!$agd@e?l)7mn-9xh|ii4w~PPR&B$|;Z?iz#!EU+;n& z@4#sUhF#>GuKs*nq<7PeT>JyptYns;Ge2r&N1G}BNr-;j62DXMbQOP3XSnbQUyB6; z9kiSN1tb|!f@p9ke=%p8_yK~5z6a2z3V^;U=!GSz#X;L=WIb#IE)+EHCE}|pgT~nT zq2*m8gsco{c44RxhqB*nR{Ki`(b}#=FIZs~E6hlngZCOlq`@BqnzVf0N+n;zvOi0i2A6&5UqY0ycD^s+a?fHZB;A+fLV*;lq_x#8ivQL`T!@5L#% z0Qdlfw{a@N1wDP|!!;0#jq5ErW(mWQA46oVm1NRl;|xUnKw|C|_nR_zbXu4C+hWNyOQxP9WNi}Sg?sEi$h((h zKJuF~qaesQxD7wBWJX~c$dr|vOgYa$##1Ec=x@rY@9Zz-pxba`Mu%*bFe3QTU)GM0 zOv*w3guFvIjU2Q?@N`WQZOdbr@D5c;Vl_}{-N_#Vre+FTJ{ z^le}fWQ2s+sV1h?i)#8Jab;*Veb+!KCVq=ZxpyKyHB&Jfs+$dR3|FmPXhG<|S?GB4 zdy@LGFi)T9_@M21H$e(>5}flbToEDDMt$^nG!P*tUJG<|JV&QendKy_*gll@ppDTb z_m(w*1(R6x6!d7l6irq^H*Yk$NvhmrG#~vXTDe&Y zqc2p^g+BQ74Xe)*)~^q5Q7 zZ6Gfb%|C$VmxwZb1yLTNR)tY!86!ZH$mQNMh}$K}unXy)Y!F40?t=`sgpoB;jDEu! 
zLqzEsl(?H=x1k1Cj|P;&xYd@U#J_P0qc1OqN%7QF*aAo$Zj^h}`D)k*$xdBJ3A-0j zaJdvttoSnOa(v8LiVsPbBBO>74jbl>Lx*(#e60OapGQ%blBsONYWpV)-u`bvQJd6* ze_aI}{ahJOQsqJfTn(;pNgF;`B)dO>lOe|>lG%Nb6DcXTm~vi# zj2$?&>xCS;r>;dUGJ``|Y>H7LhJg`>b<1s4F*|*Qy>zHi&)5?v^pfHZQ491P_9u}3 z5zYvz;NQ@m1oSmdbdp$aw^&~Us*e<>rU=&FNH)E2Mouv@T7%q+p0 z{OBv$ibxDi<#WJ5zxb!TE6BUr$^#vYgUbwBGR+ci4;1l|?Qv6PC6L#U%wK*}W{M?K zOBRNu8-hb%r!@^WAh7@AX4)m9lD?e%)~GaA__Fa8uZQlzky7sxbaeMX0}v44 z2V;eP)>@)H6~Z&W3s3jEhbq&}+AASEtzCFJ?;dFP#G$OTc)kqbY3st%dG}!Ov|2ox zup0H}hb}ymyWI(N-lZgl@C-q@(+1t|9y;$*azl6)cHx0##T}Gkx6Zqi;t-xqU3dz+ z@mO~$V_p{c*t|i zNlz4wrQ3_oXg4Z68aZD2ESc5>;gq+f&oU@)5BLJ*qkNVfK&kr}&)mK7ffN6xle}1yN~fDO49ZzD|CR~6WVmx8 z(Q`4S&>{18eeF<+Q-x^SV0_CI{WS{5N%UetZhj#_`}*`q3`w?^BC$FYRU{%wI0HNar+ZMCo|Dc_7X2Y@A^y`7sP!0~GcF&B4TESI_#C?Y6iz-A8`5Yf9s&3P&Uw7`6})vB{`n7Zz4-QlW!!zD zgstLqaq``R;V29xfOZcippb=Ac}sAcy9fM>5AMrZcr*o#k*=Rv zjhQ83V0L|v*U=b%yv?jqEb{(BqblWj(unSA3(%$qr*T)?A>@kPZHJO##gpN4DHkd_ z#A<Py7AcLREe7LpVh=hv zWF?b|Q)%5EJ`@tA;<1G7y__xL85h1T;l5^0^4CQA&!$Ygn35l#ie5_hZ^jZX=6Qfw zVU+XB(9cCUc^)3mU3fFVwS=Moo4_;wR)CKX>Hvgh^Hx3q_yNv=C8I~NiP0pM-q-Xa zP-&5(RoIdr{ZZMwgiYo*lA~yGY*MKFG1|Odd`X=O{^2+yBEZGh#W{c`;KWfhT?k#r zK{VWu#xCtav(a^+sKz<4A)Q^uF`EJ1h12yA+PX$CO8k9Nvm#aiCqIpkJ)n3ICmTg% zcM39@!7S|Z2u^Yo)gKcHCQ)i?+^sskJ)xH}FzUt-Jg}QSgdjFR6a2oE}HjvLu4T7h_iq@C}Fqbw~o$^{E+Ua z=TGIMs8aglX)QQNJ_#Th0KG2YN@8-58~O1nKBWjblG98>eJ)PUU4eoZr=7ZDGRW*> zFgRuH9cqLU!c8DqkJCs9LxkY2+g#;BIm|E0Lr4Hp5iWC};AlY}k{MW>x>*imNmszP zGjJxo4m<6`N%*07UU9QVKbXqFH2#^+k|=mGaK@!>9{#`o5ec-?Pr(F^GwB_4>z_E| zw9QJgm&qj5eBkZ5~A{PrEU?*}{H{s!i@F)u%E%mU4v>MIDcQ)4p z&D%qok1^|cFGSB{lUR88#Bc?3Sln1x22b^;Kp*;Npzbl z$B+T~rb?4~o4p#`J3l77H=+AMJnP@ZDngBmLKp@}#Ni&GWN@a0aop7I0l18W=?9qn zqS1j}lo0X~EqSngrvgc-Y*O#+-c?TWOZ0>RP$xM?Qk>!8+_?knq~#!P0JJpNY>E4i zY?DI&axesTw41?znSBt~Z&(b>mW}o0aFfe0Gq^@r$^cd6>{KV!^oY@b5sP?!Cd6%` zKWX6-d7IQrC%2g4Ho!rsQI@9+Jp6>b!qw#AFrWv;Q${L!YQ)2U2*bn@agQN)u9bRI z{Eszh!m|cM4({z9z$wOcV!1H|I@9QB7>30FXMPM>7Beu03Xi%0iRNG#PgbXGDZ;k 
zpDQBU&cCfOgAq?zsnW#oa12Ppg``3rW;8aH+g*%ta`q({m?3_iW_gi^lo8P$(KXIy zI2I3?cBQ;YLWilDg0A>{-fOcX(?#=EtQGS71U+Vufec$+O6dttl#;*s=npqj&P*!= zI`Q~=$za9I&+{6D&E~@hqk9gwX*lJk!M#H*I#*O6R;fdwRr9#RO%oa6dPS<0z7Q|4 zdyx^U_<}!GXlZuwN=rwCDc2moPue*m;e(&xJ`MlFb~nnBB%OCg(gkCV;Y1<9=0bTq z&Pg~I<7~nCEY3EZvXp>_R?s&l@sFc~J3${m(8n+ILAwM>AAHCI@u3c(k2HK}nN%vk zhg?Kp34IW+HVY-4RxhlcaGJ~57IRcpJ9M;fv zuCTIP(j;ApDB4Myf5=Yjk}7Lq%IQ_9s-(%vEs|feXP~gM?3k!15*0-PE-(=$XMtZ( zn24fm69JpsN5+r|IPN-RrF>p zlC$bWT@bl(ZgnRaO*XIg7#+1J;!l*;*ta zxJ*aUbTl`C)YO%ut0a=7Y?iVo=*_&5BUU5!jfRqcl1MX2d3d#{zMa&Ul=qorf7e|` z&Z{rehKs%jR~4W(O1W2BON}UBuA<%ySMHH^YJ>6=MKAin5aC6B zbo@AP?~_hV-sSB*1fsFJvURn+2an~k2=VYjZR~h0yqtd)qG$8L4?XAWn-8f9lzha} znuNgS?_o4bdd%9?RJbcsi&ScXHXO9V>L975=G4WCDC4j)dR^C}obk#x8ajJke4+!IaYIW9V8ZMFSTg2ZXJO9%( z8o$czAWba5Q2bgd(BgFEwp9h(9D^{ld;e-sXBHp~zm`&==mO&2XmX!2xNrTRxle&x z(XS$_X-d8(<+wE^0bX0p-Sn2Fxt(GjeR>jp>Q1Yy+{mfl?WF#{ut1-#a*H(l@;SRqDZ(;0IM6a%zsN& zZdZ{??XawFe*$@v2bF$@sv@Lr;R553S9@trUGf*CQpXY15Uy%Y<-Z!sD9OMQUn(@_ zsCZ)x270unq-hcN16(&>3kP0Xi}FJ-C5rFrOSroGxOB=A_QxRwa(6-sWeG`nhZT|7 zCsxZtBp%&oucy8d@mh)-O}UXb{HGW#2CXF-BXD=)T44JtyOo{V=pwEES}h_QO@%4D z8)+a%pn@D+&YOS1YF>DA4_c?8nXyqzD$p`0`_^;Hoff63=suw+bgtgaXSNg=wsOHiPrj?H}j7-sb2xkkz)!;}N)N8fDCU4V1}y$!Q%=W4T(-H`TxL$NE5a{&=2sEW$q`a85S|35^q^nYP%pg(vV1z=gwY;l| z8?{*7el<*kbaMLoY&#`Whpya;OQiBHbutF(4;9Jwp4HUBp-PaQ*4#x}SQ+d=xkys( zwDO)gERQ4cQAZRzy8qvrlB0#+0P^|sk!3`N(aSJ!l*d;4H6=bKmRroBzhHE~+tv9Z zc~S^*D7~xwL?y zvX5ZYZ1d57X?mSLgZIlrEQXM=hL`y~4Q@)Y#7zUi;&v$E$Fz&c*(4l($fX2K&FM;( z!9R@)07n)n+S#6A$0=P@%rf$NHx{j!tG3HBIpjwwX*;<>Gr9O)x{@a8MKEuH{w^%1 z?7{c^o5`~eAT(XNgy_;e`e_Y>(@75<@d2_4(i1egkySP}+I!ez@z(?Y;>|;92M=@n zrY@rC0g4PNAsO&mj+T>qkf{vv&+wZbAec=b^+;UGtBpnW9?ElQpD3`F=)aM!FD4=1who9q?%}+Axf|%xo!;g=GUV zT|bzLAXje2-DM17Ct8bcHZtU$jUo$vDX--|zPrhc2ftrb1!c*eLo>Cpl$)>hx&rGD z&Gnq-$kvpRnggId9ivRMqd^TFxKY#Q0CzK*hBSRE`l2jGm2wf5Mfx*9+OcHQ4{F-K zVlWIb$7sFgYtDSFe~}hfriB-1Nqq3qsy7LhXeRdJG6O1?YT?S5&0KzjG8MKnybLqycWS*si8%o_!kRQGTXT|Y 
zbF}_IM<^NDm;ei5Cn~v>LYR^rBMOd(5R2OLzbo4#u|6ECd4bG7TT5JL4JI?UG!r6>v--p&nwezKs5wR3b8fJ8i zxMFld89;qr0{<2vJ8#?!Qz&bZ1vr&5k<84UXv`kUDQmHU0PU&#BHQ2PLB6($-|u00 z{fkU#9l7Kq_#qO^;rbuN^}#Mz7>(@5+27`4CtP_PV~5rqaJZQ5Ub8ykf0xn?y#6!> z1q5ePqE*U_dGrg5jJn!aC(Q?O>cf(hqopcO;D!_(RkndS0b2C}cMflRQ&WMRR;&G# zZ%{WZl_PI!qQ(+|%GOe{wQ;F5D~`MZ5)cTWkd#)dMLyp`^l3vcOsnt?#!wvOfx)E07aXSn!SK!scHPa2R1FDeFD5D(xd1|eF@9NQcBo9=8a z{tBd3%>MnA{4!cRP*!q7w3)zj4QFYIz>ULp65=&m8-ZXWKk;ORmo8WHoL1 zVE@93VQ3MlJj|!~Fy(_LzG4I6!x+qx&Sxjgp6CEKndm9zQ7+}}E){6#1DJqX3VR+N zIoX)XCpz4w8>=_`_f5(%BLYq~@@8S5O%Q;EsJ7l%ARb{-Gwi6e^USXFv2Wx_ zR?|P{mVF~sUWD?5@;JBUTa7#=A%M0T=GK)r`C>)`;vjMA%Il3-G4MV?l?Y{secD?mI9Qz6vuUBD$kF!`}J#|^`*i>cUfG1Ukc|5X?y zU>VR(t5r0663%z(c5LNbq0iSPC9h~LC2Y)0+~Pl2H@!TwqhQ{GbidNbTHRCx^Ri|{2B0soRW!{L9XhWW%c6!lm%&?Fm%Q5}vgGTfja03v6; zoEU?R$$wy4$JRnm5GEFBJqxgZT0pxkJqxr6C=~F$RKBk|q0HP-&DOB_0uhKIWjAKk zIAxb1g(4N~-L_vj@kkr5FEYZRbB{qDBm(GfilnSGb|7B2LU(f!wPZv!nd}S3L*@)+PGo~` zJ2mQo?_wpXxR_qGQH<9w)Kye1DPBe|_9(_XGiv>d>&sXvfBnV0l3Mz4&nn+_v++^^ zR-AuXF}>)b*w0qZs9VD7{I!+7DqQeBl3HF1)?#{lMscy%zoN!hyqM@&4M{1!uB5WK z94|q@3rLnV@K+0Pmg{OO{65aKtfZl^+Kcb3i|LgbwZ6LI%8EKa#1}8FT*}{RQe0M9 zvZRjHOre*fu;OJUwM%`qb!csIeMP0eqN=!zO9P^+vZ9LgteEXDsr3sDs;hiJ`yd|g zTVUwxbyM&{j48&OI9P?D)3OG7T>?Z@R8^AzBAMy0rCwIo`+=tyITYhHBqeAV2B|Aq zZm7Sq5N~7PD0*{q@6!63VtT&|_)C4TNpb${%IfQ2$zpodL{&BLk1%bPY{ zC;@LQ!Fwg}HVH5D@-6|xgQsey@d64)-eZerFp(pz*S02(63p zN)42&$+NhkWnalM;V*jKMlpYz2lH1f^VQX$YhdHyR95UQUkb?dJXk~V(AHD^9U7IQ}3^1b!1Vfi#Wh=sPooVEcThUtbrMgh=g1t zg1L=4{Yh@L!ozaiau^$C_Eur!K=}}jt0>NROwNGamQ_@Vp-^16q7LS(EUqKj2D{-T z^tjf`7Ux%?GgWy|kd>Hjtg9*UqR-|qqK=@k68v*Ezl_0UNoB=iFT7jlt6lCZWx{x@ zcsYN$5Bxzdxe;lgn7;xC0x&MT)reej|8f#iQA!#WSMe94uo_=&A;uT0DXgFu#2^Wj z5%pEq;XO(txC`n~$-6nCrncJe z^ZKi6M__bN!Z0hml_hoMW;NbmgbHt22+WGBD@#pRQZUl%Z`jI<&HxTd*TRC&d?3Jv%w01V1h^SK# z32HC-3wH4?L(}=Q7gv^4EhT)ZZ#i3e87z1Svdqd{YRd>Pt9LX)(QR+d3MA?ZOk?%6 zHI?;zHmO{S5eoU%OOX(2M1ltDYY^qs0A9+4A;2f3Q6Yg*hTMTH<*lzR_EuNcFRNlD zbsk@XVMfviv`~I*23e%Os=m%w3U=5MgVadzR<5Ch0Bco_q?rn{XR4QtEU#YX8(CMe 
zynbnUMb$`ueQjO!NQ~sVYb(WN1d5nYM{`a#vc9f%WJQ&?vcA+e5}{W-He&>QSl=*W zNmc#G#T7_%oCG39F7bMQO_H8&vi~n+jifnWNED6!4_ruYuw`jQ^{CWx#2u#lQWVNy z{Yt{Xue3J7i>Le*WfeZE14XH?3p*x;1G@w zhFocmnmbLJsZJxV2Z`(3MBu*VTo2+0EQt@#@Zk{%^Cu=%{6I42UP3|&r?5CzEz}oOoGOGqg9U|_q$p>En{EcYNN{hZBvzf zDM_={59Cuaw&{B?yIFCs6sF+mCVp(e($sGLFlVl^JPH~}c2hcoOr;ZySC+@}dQ%dE z;4r5pm?_KEEi5twF=)AZhh*>Q<4Ibs@xtV!fy;GX%0$V=O9QtsXOB)H=wfsTIA}Tk z|HKd>St4Xh1a|XL2@%+6vD8wQM{HsE4#`3cS{`8viL`_`Eg>#bh#uC-qb#>?VU7^Q zpyhT`h{F;RW(f&5g-Fb9{laUQ^=qv#oJ_H%pNTb0uddm%WY<3jxP!pvs z>Q>vUeF<$4b@|jUk26`r5EWrJkbm!z?!`_1h9>m^{Rth|xAjdB3;2NTJ)`iR?FTyu z&pTSeaQSDS(|zE!&O?97lKPsa;1>|oAGD_-5{01tlTmoY_L5DyO8w0Cg-x2tqg-J+ z{o4#3D$@tntv3Ekq$b|DUc`TQDbZ~}0~6!7sCslKmH3~M-=Ml%dPGv}rN%%AX<`Tl z?w?Rc59snS8&Y?-q@X#~WK|!LzKkXe*N#5)qn4@~Or%|dA+r+0;%fCr*VEAma^uj= zA7xeX%gHmh$-Sro0iei9{5Srm( z$}8eO6t<{23*n??k7La8wnM56KBgbXKWoq-a>J*Q+}$@jpN1|QL{>8>%ugfmrr<$; zI#d=Vr)^QA_QJPFZ+rDaIvP8q{{!Iy_#d`UZ1|6xPi*@lIJ__7qX@tsMSK$hvTq`O zjF4ujKSrF1Q0gXcQO`sKB7r{a{J<$voc)-aH^ZF5ZW7aztfohT`1QzFoVc8De(RL8 zw=g{GE{+hJju5ZgL_#k&v@v#c2bqc%(ArhB>z zxTmSxWYfhd3#p}LZ^?9dOC|wpNWg7egWID1jHpom9Cb2EDppTM{Uw^n{t``ktsxiK z%^%B+fdkVGC6Q}|wt9(WkHN~O9Zl-sbVy5o1P0&Oq(19>0k+)OWQ6&2`j-+MDjOAM z{U9IGg!(-8IL#mEL;3@W-bowo;y-0#3_tv=9x=S8c1PlAK`xu==FjS?>_ts~aCRg$ z_=|SJVV2K^oeMY1H#u*0it=v@;7%iMULaOpAU^IkqfLDBhaWA77eX|-4gUtG-X?vA z(fwYN`kk~!2K>Fg#)iup+g6*hqLI?&E`oR2K1cK-WqoctZtDf(-J$Xf)X%WMDmE}t z-J-s2|JV-VkL_pe_&32O^{o9B2jEv6q$W-xqah2}7<^Xfiq(KAO+`Ovt4r{V$4o3epOlmud5iid<$DG24Jri8#K%2!Ic4AMux55{ z5q#*){f(lE-}fj|6UT+~3(_kRo-f{_z9PLx@Ox5FQdX_nq6Ve!3I1L>0~o<}Mrx*I zM6-OO%;6hleC>I)dYgQw0pBU(%iOEgr{#|++ia#Al6mS0q&i&A+JCky%Nw?+KimJy z!Nb*13mrCp-5Pg3Fp5@zsH(V6!D8lRnmcR#Q0-^}ci4(V)(PG*= zPF6gx>iJ8IMPs=mR-Wy=Dlxne-o(PR*c@}-v$ObzBQ?AA0<(KuGWP>g?M@k_=P)=73Zr?2NY4ij{2CMdinucBPhXJ#KCYO*`&_$^61vqN5lfi$@=YZ%&7 zz8HwM)sYy2;h#;kpj}rdDYl*Gr+l&2Z0RR#N<8JQqYCWtIDJ8resVzx!u22wwWVd< z@1C}4q*K0FMK;USP+|zyem2n&X=&WQyIl2rfvp1fMnP*QlTAyp;a60jcem4OG1kz> 
z#3b^}lz)G89Ef$;8caDz)&>1!&X!9cQqwVdr`LFSuC>@KuJDt`aIFgf;*p)*QRVmI8%hxR%r>ICR1}z^{UTVF`zeABDfg z5)L~=G>_OGeGN|{TM_ocwMUwe-CJJ6&)sVhw;?5LOkSj00)(Fkeh*I%HE@WDTO2p( zz0!~T!#>Z%_s~y?pY-A)ggv-%J#Qo`kY>my1Iz-5Ioty|JnTvWQKHRs!2bY;U$l=N z!edSHUIL9Oc6RYL>*j>svEG)rC!xEUiA5pBN-fraT(&9A)t4&lb$3&v-YPV9H{MCP z)g{2btGp8Ejo$Z4z{o8rCBpvVa^AotZ?HvdO!xq z6K5bMZ?QA$m);@PY_TgFiepo-Jj6f}L$!7GLqLp4A2>4QtjOvd1+PPnsKFs9*0GzH z^2Io*9uCeFY04L?qMBW-V>a;J=lK}zFBqn(#8{FlrX+@wG^H9Dpd8NlWy%?|)CoBQ zme^)ULW1Fr6TZxa1$zw^EH$+{z=;WgnTV2Npb`%cF*&Kzc4j8=P>XEOlg?CR$`{j* zn%cTQTg0LWdKV zaWABsAdPOXv3N{6!)*SA88AhF#$nYB!vcIZ+~AH=XUcIZSe;=u|H4Mi>VOFAEXUd` ztmqG7534@qcXyaZ(KkX&V_~?pIh^A|v3f)>{kMtwcjG0%fy)eMG_nW%t_LTuo9POV zX$~5i2zv*f${if_!%0ymV~2J!M?E|#Y9-c4D>>?M$yAq0Y|@a@N5Av1(x2BJUZ#e5 zAKgB|JB7QwJFo#TBgL$)B24^NBh`k|wzc~gV%)Uf!9~|Bo27DM)`-;oe6oDNdy>6> zdk^3#0Ggy&RKj8*CFhkz_%tyVbQ7L}`}tO>q$D~e<%}E^=N>`vLK8Xnh5(7n8dWfAN-6Nbep_H=)RE$qb%(j~f#S{-tR?%guPDjh0zTOPi z_2cdV8zIY}^JLawf4ycLWE*7)2*v&kshJplSTCPNi#Fj=#uJ-0O*0j%@}i(#C>3Jy z#zrN@Xbz?sv3=w65cOVCTTs|7FY+ndi{9(l!AU8w zqvM|a8}Ep_y>%n15w$|c)alxDm_n+x=>4^?bE@mPE5Mlp>Mqt)ha=U9d>b!k!%vrc zH+udEcCI&ikm*`aw|eO7YZX?pVvt7R2^U6!tz+`SXw1DZ8Z&EmymK0hE!#2Rdn4Jd zqsC_R`R0zq2tG3PMQ@!c5&lE|ViY^V#~0Vq(iI?Eu{kk9wE0|phd3@2%Qd8G(I&Ms zKs2KrX_;!OP)yd?)oUH%>spZrGsA_VeMqVXH>aizHE!$BJ;@M6DO%7-Ep2FZO;PAX zViko{Jt*p%LD9URsGAiiQphFMj9fIuCOtAGHkYfiyjC6wPgg!7SG_ zt(ad_u*-!L6d?Sb3s=xZU-f8j;s*)3FY#c49!or*P(j^%Q#)+psd6e!ucukFY4M5F zxmpExxj3JuwyB`oJ<0+~*HCu83L?)Q{m@Z?SWe_+LVAu0uGJQE^g>aGrHZH&!2(ZE z6wbm#PmL64I%OJE@V*PDskM!A9V&Rpy~az^tXaAog*SV9D(FGH$5TN8`kFw`yVq@X zRCh~2uLiq9dNh0q-^PeW?M+-)Htv`H<32qF@o5#vXBtRl3Da31%SEnPkfy3e+%>y^ zV~8qMAj?J5+A5H>As*dSRi}b~xB^yD^KmLT=w6;xP}@0X4&SxuV>XFnihmKyX0=7HF4ihfbET$7qqAtAI44)6_<;p{q^j77blnVAbB2@eP zKE3U~Ms%R^Z)gHJ#Uu!T1_7l`QNb*CKD2~h@n7-jPyS!fzH-c>dlITUwOMkmp@t^< z3^)sC7>bf8@6)?WF^;h{wIhJ^G{f zmPhY-A9!?EVsC<;NW6|gFH#3M6{qS9%~?Q;x>T@yj+BOIMguKDwYnqFv>Iy4(yCRo z>Kqkph(y#iQdbvk_yYBQiAFC{fxPn5R 
z8W$0F7}uQsJF#>=0`93gm5I?`4ZPI-sz*kis#O~oM-sP}rRdpayQgNTwpjXXi3JAv zRQkB$be3uQjLNJj#sOy*kuLyb$S1-}~OzEpWy%HSPk@Rh*1ed6U=Uxt1M z@NA(N?&_K{^tS_Vaa4tiA2WR-&5cpy4wOmfS>Qa5I_-UwadfB*z1ncz$k;JC))wgL z-JTzV#C<`uG%63T1+pPR%Zdy}<~O2806||9A5rL-tXtqD`1nKmlUZ{&~h0yb8zGvq}5IY|j-qAM(5m1(4 zFF^rCh{5fUDIVz=hce07$T-vxpyiR@x^rNFH60=$Nc3-qpm&5cpI^DYYxRo!iq9|4 z=T(0Bmd|&sK5ZGeAkv**(QRY9m#mjgb}e1CBCjb`OHGOqc?p57gOEZz!Ks-^j4a4MNc%2g&dpR1P5K4i%vS#NyFgkA^6|fx4ES20evsN|X0P4?*R& zjt?RW<3lAt$6N^gENCVa1@ed+>D$r64EAbx$C%3Ly5?}1+ctzCU$9z5ZVXeC%IV&o z8yV7lKiAi9M>JFdlna?O%+NL*a*eEd2DT$H!;TK3a+p&FFjWBS>)p8xls$vc>XG2~ z!SjdYx~I2y1k(PNP%A;$YqVe#Mz1S`ht6aDHJ9MXAargx{C_*%fnv#mwWM(w%=hfX zx7$nu;-6`wZVmz!_zC|K%P+ot{5uwTt_Z+~=~4cE%P;<1A+)}%BLwFCK$KthDdJ~^ z(E4%yF984Nr=Jc<`S;uW#h3DjmOozp9^f#Y=<5HP={RRH1a35yBfKYste z4J^u^wS3~AYx9=`Y_!Yf-WdPo)rL}hp)be3KgPey@{7OrzvTa5j9=Eb;>*Wh{IXBa{*U}0gO~fR@XNls&GJk6dH!H0 z`AWaq2aKyI{Ibu^oWj)j)UWZ}zwnD6f;r0X_KBT_Mdm3q_1~X%le|D|q>ooqt z*bU1X*!{CrSI&Q-L(D~{)A;Q%oN3vK&jw-T+QgXuWu3;io(bSTPRnLw;}{cP*~$O- z5CuP=A}&sQTp9kFFB_&kvHl7>+4Tt%kc+~hG$0FnT^K=~}(TsbZH2hzWF~oxE EKTKO&M*si- literal 0 HcmV?d00001 From 8604065adb101a44043d73b70e4c838788236a93 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 12:11:58 +0400 Subject: [PATCH 090/121] allow any neighbor compression types --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 4 +--- libsql-ffi/bundled/src/sqlite3.c | 4 +--- libsql-sqlite3/src/vectordiskann.c | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 477758b4cc..e48ad61292 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -213594,11 +213594,9 @@ int diskAnnOpenIndex( if( compressNeighbours == 0 ){ pIndex->nEdgeVectorType = pIndex->nNodeVectorType; pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; - }else if( 
compressNeighbours == VECTOR_TYPE_FLOAT1BIT ){ + }else{ pIndex->nEdgeVectorType = compressNeighbours; pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); - }else{ - return SQLITE_ERROR; } *ppIndex = pIndex; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 477758b4cc..e48ad61292 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -213594,11 +213594,9 @@ int diskAnnOpenIndex( if( compressNeighbours == 0 ){ pIndex->nEdgeVectorType = pIndex->nNodeVectorType; pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; - }else if( compressNeighbours == VECTOR_TYPE_FLOAT1BIT ){ + }else{ pIndex->nEdgeVectorType = compressNeighbours; pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); - }else{ - return SQLITE_ERROR; } *ppIndex = pIndex; diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index c6e2d5156f..5f30806db3 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -1749,11 +1749,9 @@ int diskAnnOpenIndex( if( compressNeighbours == 0 ){ pIndex->nEdgeVectorType = pIndex->nNodeVectorType; pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; - }else if( compressNeighbours == VECTOR_TYPE_FLOAT1BIT ){ + }else{ pIndex->nEdgeVectorType = compressNeighbours; pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); - }else{ - return SQLITE_ERROR; } *ppIndex = pIndex; From 8a2a019b1b1f42d3ed8d5adab598be33a3a7fe36 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 16:19:44 +0400 Subject: [PATCH 091/121] add implementation of float8 vector type (int8 quantization) --- libsql-sqlite3/Makefile.in | 6 +- libsql-sqlite3/src/vector.c | 188 ++++++++++++++++++++++--- libsql-sqlite3/src/vectorIndex.c | 2 + libsql-sqlite3/src/vectorInt.h | 46 ++++-- libsql-sqlite3/src/vectorfloat1bit.c | 7 +- libsql-sqlite3/src/vectorfloat32.c | 25 +--- 
libsql-sqlite3/src/vectorfloat64.c | 7 +- libsql-sqlite3/src/vectorfloat8.c | 150 ++++++++++++++++++++ libsql-sqlite3/test/libsql_vector.test | 30 ++++ libsql-sqlite3/tool/mksqlite3c.tcl | 1 + 10 files changed, 403 insertions(+), 59 deletions(-) create mode 100644 libsql-sqlite3/src/vectorfloat8.c diff --git a/libsql-sqlite3/Makefile.in b/libsql-sqlite3/Makefile.in index db1e2c55ce..7316257fa4 100644 --- a/libsql-sqlite3/Makefile.in +++ b/libsql-sqlite3/Makefile.in @@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \ sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \ table.lo threads.lo tokenize.lo treeview.lo trigger.lo \ update.lo userauth.lo upsert.lo util.lo vacuum.lo \ - vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo \ + vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo \ vectorIndex.lo vectordiskann.lo vectorvtab.lo \ vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \ vdbetrace.lo vdbevtab.lo \ @@ -306,6 +306,7 @@ SRC = \ $(TOP)/src/vectorfloat1bit.c \ $(TOP)/src/vectorfloat32.c \ $(TOP)/src/vectorfloat64.c \ + $(TOP)/src/vectorfloat8.c \ $(TOP)/src/vectorIndexInt.h \ $(TOP)/src/vectorIndex.c \ $(TOP)/src/vectordiskann.c \ @@ -1148,6 +1149,9 @@ vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR) vectorfloat64.lo: $(TOP)/src/vectorfloat64.c $(HDR) $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat64.c +vectorfloat8.lo: $(TOP)/src/vectorfloat8.c $(HDR) + $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat8.c + vectorIndex.lo: $(TOP)/src/vectorIndex.c $(HDR) $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorIndex.c diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 01bf402aa0..cf102aeb43 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -43,6 +43,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(double); case VECTOR_TYPE_FLOAT1BIT: return (dims + 7) / 8; + case VECTOR_TYPE_FLOAT8: + 
return ALIGN(dims, sizeof(float)) + sizeof(float) /* alpha */ + sizeof(float) /* shift */; default: assert(0); } @@ -116,6 +118,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF64DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT1BIT: return vector1BitDistanceHamming(pVector1, pVector2); + case VECTOR_TYPE_FLOAT8: + return vectorF8DistanceCos(pVector1, pVector2); default: assert(0); } @@ -253,7 +257,8 @@ static int vectorParseSqliteText( } static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pType, int *pDims, size_t *pDataSize, char **pzErrMsg){ - int nLeftoverBits; + int nTrailingBits; + int nTrailingBytes; if( nBlobSize % 2 == 0 ){ *pType = VECTOR_TYPE_FLOAT32; @@ -266,26 +271,34 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT if( *pType == VECTOR_TYPE_FLOAT32 ){ if( nBlobSize % 4 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(float); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT64 ){ if( nBlobSize % 8 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float1bit 
vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } - nLeftoverBits = pBlob[nBlobSize - 1]; - *pDims = nBlobSize * 8 - nLeftoverBits; + nTrailingBits = pBlob[nBlobSize - 1]; + *pDims = nBlobSize * 8 - nTrailingBits; *pDataSize = (*pDims + 7) / 8; + }else if( *pType == VECTOR_TYPE_FLOAT8 ){ + if( nBlobSize < 2 || nBlobSize % 2 != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector: float8 vector blob length must be divisible by 2 and has at least 2 bytes (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + nTrailingBytes = pBlob[nBlobSize - 1]; + *pDims = (nBlobSize - 2) - sizeof(float) - sizeof(float) - nTrailingBytes; + *pDataSize = nBlobSize - 2; }else{ *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType); return SQLITE_ERROR; @@ -331,6 +344,9 @@ int vectorParseSqliteBlobWithType( case VECTOR_TYPE_FLOAT1BIT: vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; + case VECTOR_TYPE_FLOAT8: + vectorF8DeserializeFromBlob(pVector, pBlob, nDataSize); + return 0; default: assert(0); } @@ -429,6 +445,9 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT1BIT: vector1BitDump(pVector); break; + case VECTOR_TYPE_FLOAT8: + vectorF8Dump(pVector); + break; default: assert(0); } @@ -451,7 +470,6 @@ void vectorMarshalToText( } static int vectorMetaSize(VectorType type, VectorDims dims){ - int nMetaSize = 0; int nDataSize; if( type == VECTOR_TYPE_FLOAT32 ){ return 0; @@ -459,12 +477,13 @@ static int vectorMetaSize(VectorType type, VectorDims dims){ return 1; }else if( type == VECTOR_TYPE_FLOAT1BIT ){ nDataSize = vectorDataSize(type, dims); - nMetaSize++; // one byte which specify amount of leftover bits - if( nDataSize % 2 == 0 ){ - nMetaSize++; // pad "leftover-bits" byte to the even length - } - nMetaSize++; // one byte for vector type - return nMetaSize; + // optional padding byte + "trailing-bits" byte + "vector-type" byte + return 
(nDataSize % 2 == 0 ? 1 : 0) + 1 + 1; + }else if( type == VECTOR_TYPE_FLOAT8 ){ + nDataSize = vectorDataSize(type, dims); + assert( nDataSize % 2 == 0 ); + /* padding byte + "trailing-bytes" byte + "vector-type" byte */ + return 1 + 1 + 1; }else{ assert( 0 ); } @@ -482,6 +501,15 @@ static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigne assert( nBlobSize >= 3 ); pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT1BIT; pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; + if( vectorMetaSize(pVector->type, pVector->dims) == 3 ){ + pBlob[nBlobSize - 3] = 0; + } + }else if( pVector->type == VECTOR_TYPE_FLOAT8 ){ + assert( nBlobSize % 2 == 1 ); + assert( nDataSize % 2 == 0 ); + assert( nBlobSize == nDataSize + 3 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT8; + pBlob[nBlobSize - 2] = ALIGN(pVector->dims, sizeof(float)) - pVector->dims; }else{ assert( 0 ); } @@ -520,6 +548,9 @@ void vectorSerializeWithMeta( case VECTOR_TYPE_FLOAT1BIT: vector1BitSerializeToBlob(pVector, pBlob, nDataSize); break; + case VECTOR_TYPE_FLOAT8: + vectorF8SerializeToBlob(pVector, pBlob, nDataSize); + break; default: assert(0); } @@ -527,18 +558,20 @@ void vectorSerializeWithMeta( sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } -size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ +void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); + vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); + break; case VECTOR_TYPE_FLOAT64: - return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); + vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); + break; case VECTOR_TYPE_FLOAT1BIT: - return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); + vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); + break; default: assert(0); } - return 0; } void vectorInitFromBlob(Vector 
*pVector, const unsigned char *pBlob, size_t nBlobSize){ @@ -644,6 +677,110 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ } } +static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){ + int i; + u8 *src; + float alpha, shift; + + float *dstF32; + double *dstF64; + u8 *dst1Bit; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(pFrom->data, pFrom->dims, &alpha, &shift); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF32[i] = alpha * src[i] + shift; + } + }else if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF64[i] = alpha * src[i] + shift; + } + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ + dst1Bit = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + dst1Bit[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( (alpha * src[i] + shift) > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert( 0 ); + } +} + +static inline int clip(float f, int minF, int maxF){ + if( f < minF ){ + return minF; + }else if( f > maxF ){ + return maxF; + } + return (int)(f + 0.5); +} + +#define MINMAX(i, value, minValue, maxValue) {if(i == 0){ minValue = (value); maxValue = (value);} else { minValue = MIN(minValue, (value)); maxValue = MAX(maxValue, (value)); }} + +static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){ + int i; + u8 *dst; + float alpha, shift; + float minF = 0, maxF = 0; + + float *srcF32; + double *srcF64; + u8 *src1Bit; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pTo->type == VECTOR_TYPE_FLOAT8 ); + + dst = pTo->data; + if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + srcF32 = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, srcF32[i], minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < 
pFrom->dims; i++){ + dst[i] = clip((srcF32[i] - shift) / alpha, 0, 255); + } + }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ + srcF64 = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, srcF64[i], minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < pFrom->dims; i++){ + dst[i] = clip((srcF64[i] - shift) / alpha, 0, 255); + } + }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ + src1Bit = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, ((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1, minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < pFrom->dims; i++){ + dst[i] = clip(((((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1) - shift) / alpha, 0, 255); + } + }else{ + assert( 0 ); + } + vectorF8SetParameters(pTo->data, pTo->dims, alpha, shift); +} + + void vectorConvert(const Vector *pFrom, Vector *pTo){ assert( pFrom->dims == pTo->dims ); @@ -652,12 +789,16 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){ return; } - if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + if( pTo->type == VECTOR_TYPE_FLOAT8 ){ + vectorConvertToF8(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ vectorConvertFromF32(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ vectorConvertFromF64(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ vectorConvertFrom1Bit(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT8 ){ + vectorConvertFromF8(pFrom, pTo); }else{ assert( 0 ); } @@ -734,6 +875,14 @@ static void vector64Func( vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT64); } +static void vector8Func( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT8); +} + static void vector1BitFunc( sqlite3_context *context, int argc, @@ -873,6 +1022,7 @@ void sqlite3RegisterVectorFunctions(void){ FUNCTION(vector32, 1, 0, 0, vector32Func), FUNCTION(vector64, 1, 0, 0, vector64Func), FUNCTION(vector1bit, 1, 0, 0, 
vector1BitFunc), + FUNCTION(vector8, 1, 0, 0, vector8Func), FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc), FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc), diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index a253d0d12b..92d3c8c83e 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -383,6 +383,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, { "FLOAT1BIT", VECTOR_TYPE_FLOAT1BIT }, { "F1BIT_BLOB", VECTOR_TYPE_FLOAT1BIT }, + { "FLOAT8", VECTOR_TYPE_FLOAT8 }, + { "F8_BLOB", VECTOR_TYPE_FLOAT8 }, }; /* diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index a17ff1d59a..39352d6990 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -30,12 +30,12 @@ typedef u32 VectorDims; * - last 'type'-byte is mandatory for float64 vectors * * 3. float1bit - * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? [leftover as u8] [3 as u8] + * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? 
[trailing_bits as u8] [3 as u8] * - every data byte (except for the last) represents exactly 8 components of the vector * - last data byte represents [1..8] components of the vector - * - optional padding byte ensures that leftover byte will be written at the odd blob position (0-based) - * - leftover byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted - * (so, vector dimensions are equal to 8 * (blob_size - 1) - leftover) + * - optional padding byte ensures that "trailing_bits" byte will be written at the odd blob position (0-based) + * - "trailing_bits" byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted + * (so, vector dimensions are equal to 8 * (blob_size - 1) - trailing_bits) * - last 'type'-byte is mandatory for float1bit vectors */ @@ -45,9 +45,12 @@ typedef u32 VectorDims; #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 #define VECTOR_TYPE_FLOAT1BIT 3 +#define VECTOR_TYPE_FLOAT8 4 #define VECTOR_FLAGS_STATIC 1 +#define ALIGN(n, size) (((n + size - 1) / size) * size) + /* * Object which represents a vector * data points to the memory which must be interpreted according to the vector type @@ -68,11 +71,15 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); /* * Dumps vector on the console (used only for debugging) */ -void vectorDump (const Vector *v); +void vectorDump (const Vector *v); +void vectorF8Dump (const Vector *v); void vectorF32Dump (const Vector *v); void vectorF64Dump (const Vector *v); void vector1BitDump(const Vector *v); +void vectorF8GetParameters(const u8 *, int, float *, float *); +void vectorF8SetParameters(u8 *, int, float, float); + /* * Converts vector to the text representation and write the result to the sqlite3_context */ @@ -83,15 +90,17 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t 
vectorSerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); +void vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +void vectorF8SerializeToBlob (const Vector *, unsigned char *, size_t); +void vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +void vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +void vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) */ float vectorDistanceCos (const Vector *, const Vector *); +float vectorF8DistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); @@ -119,6 +128,7 @@ void vectorSerializeWithMeta(sqlite3_context *, const Vector *); */ int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); +void vectorF8DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t); @@ -131,6 +141,24 @@ void vectorConvert(const Vector *, Vector *); /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); +static inline unsigned serializeF32(unsigned char *pBuf, float value){ + u32 *p = (u32 *)&value; + pBuf[0] = *p & 0xFF; + pBuf[1] = (*p >> 8) & 0xFF; + pBuf[2] = (*p >> 16) & 0xFF; + pBuf[3] = (*p >> 24) & 0xFF; + return sizeof(float); +} + +static inline float deserializeF32(const unsigned 
char *pBuf){ + u32 value = 0; + value |= (u32)pBuf[0]; + value |= (u32)pBuf[1] << 8; + value |= (u32)pBuf[2] << 16; + value |= (u32)pBuf[3] << 24; + return *(float *)&value; +} + #ifdef __cplusplus } /* end of the 'extern "C"' block */ #endif diff --git a/libsql-sqlite3/src/vectorfloat1bit.c b/libsql-sqlite3/src/vectorfloat1bit.c index 86e367c2de..a9c4d45f26 100644 --- a/libsql-sqlite3/src/vectorfloat1bit.c +++ b/libsql-sqlite3/src/vectorfloat1bit.c @@ -52,7 +52,7 @@ void vector1BitDump(const Vector *pVec){ ** Utility routines for vector serialization and deserialization **************************************************************************/ -size_t vector1BitSerializeToBlob( +void vector1BitSerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -63,12 +63,11 @@ size_t vector1BitSerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < (pVector->dims + 7) / 8; i++){ pPtr[i] = elems[i]; } - return (pVector->dims + 7) / 8; } // [sum(map(int, bin(i)[2:])) for i in range(256)] @@ -133,7 +132,7 @@ void vector1BitDeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); memcpy(elems, pBlob, (pVector->dims + 7) / 8); } diff --git a/libsql-sqlite3/src/vectorfloat32.c b/libsql-sqlite3/src/vectorfloat32.c index 9749a84835..56d022ae9c 100644 --- a/libsql-sqlite3/src/vectorfloat32.c +++ b/libsql-sqlite3/src/vectorfloat32.c @@ -57,25 +57,7 @@ static inline unsigned formatF32(float value, char *pBuf, int nBufSize){ return strlen(pBuf); } -static inline unsigned serializeF32(unsigned char *pBuf, float value){ - u32 *p = (u32 *)&value; - pBuf[0] = *p & 0xFF; - pBuf[1] = 
(*p >> 8) & 0xFF; - pBuf[2] = (*p >> 16) & 0xFF; - pBuf[3] = (*p >> 24) & 0xFF; - return sizeof(float); -} - -static inline float deserializeF32(const unsigned char *pBuf){ - u32 value = 0; - value |= (u32)pBuf[0]; - value |= (u32)pBuf[1] << 8; - value |= (u32)pBuf[2] << 16; - value |= (u32)pBuf[3] << 24; - return *(float *)&value; -} - -size_t vectorF32SerializeToBlob( +void vectorF32SerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -87,12 +69,11 @@ size_t vectorF32SerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(float) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ pPtr += serializeF32(pPtr, elems[i]); } - return sizeof(float) * pVector->dims; } #define SINGLE_FLOAT_CHAR_LIMIT 32 @@ -178,7 +159,7 @@ void vectorF32DeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(float) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); diff --git a/libsql-sqlite3/src/vectorfloat64.c b/libsql-sqlite3/src/vectorfloat64.c index 9f854793ab..dca6bda773 100644 --- a/libsql-sqlite3/src/vectorfloat64.c +++ b/libsql-sqlite3/src/vectorfloat64.c @@ -83,7 +83,7 @@ static inline double deserializeF64(const unsigned char *pBuf){ return *(double *)&value; } -size_t vectorF64SerializeToBlob( +void vectorF64SerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -94,12 +94,11 @@ size_t vectorF64SerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(double) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for (i = 0; i < pVector->dims; 
i++) { pPtr += serializeF64(pPtr, elems[i]); } - return sizeof(double) * pVector->dims; } #define SINGLE_DOUBLE_CHAR_LIMIT 32 @@ -185,7 +184,7 @@ void vectorF64DeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(double) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); diff --git a/libsql-sqlite3/src/vectorfloat8.c b/libsql-sqlite3/src/vectorfloat8.c new file mode 100644 index 0000000000..3438349d9e --- /dev/null +++ b/libsql-sqlite3/src/vectorfloat8.c @@ -0,0 +1,150 @@ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +** +****************************************************************************** +** +** 8-bit (INT8) floating point vector format utilities. 
+** +** The idea is to replace vector [f_0, f_1, ... f_k] with quantized uint8 values [q_0, q_1, ..., q_k] in such a way that +** f_i = alpha * q_i + shift, when alpha and shift determined from all f_i values like that: +** alpha = (max(f) - min(f)) / 255, shift = min(f) +** +** This differs from uint8 quantization in neural-network as it usually take form of f_i = alpha * (q_i - z) conversion instead +** But, neural-network uint8 quantization is less generic and works better for distributions centered around zero (symmetric or not) +** In our implementation we want to handle more generic cases - so profits from neural-network-style quantization are not clear +*/ +#ifndef SQLITE_OMIT_VECTOR +#include "sqliteInt.h" + +#include "vectorInt.h" + +#include + +/************************************************************************** +** Utility routines for vector serialization and deserialization +**************************************************************************/ + +void vectorF8GetParameters(const u8 *pData, int dims, float *pAlpha, float *pShift){ + pData = pData + ALIGN(dims, sizeof(float)); + *pAlpha = deserializeF32(pData); + *pShift = deserializeF32(pData + sizeof(*pAlpha)); +} + +void vectorF8SetParameters(u8 *pData, int dims, float alpha, float shift){ + pData = pData + ALIGN(dims, sizeof(float)); + serializeF32(pData, alpha); + serializeF32(pData + sizeof(alpha), shift); +} + +void vectorF8Dump(const Vector *pVec){ + u8 *elems = pVec->data; + float alpha, shift; + unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(pVec->data, pVec->dims, &alpha, &shift); + + printf("f8: ["); + for(i = 0; i < pVec->dims; i++){ + printf("%s%f", i == 0 ? 
"" : ", ", (float)elems[i] * alpha + shift); + } + printf("]\n"); +} + +void vectorF8SerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + float alpha, shift; + + assert( pVector->type == VECTOR_TYPE_FLOAT8 ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); + + memcpy(pBlob, pVector->data, pVector->dims); + + vectorF8GetParameters(pVector->data, pVector->dims, &alpha, &shift); + vectorF8SetParameters(pBlob, pVector->dims, alpha, shift); +} + +float vectorF8DistanceCos(const Vector *v1, const Vector *v2){ + int i; + float alpha1, shift1, alpha2, shift2; + u32 sum1 = 0, sum2 = 0, sumsq1 = 0, sumsq2 = 0, doti = 0; + float dot = 0, norm1 = 0, norm2 = 0; + u8 *data1 = v1->data, *data2 = v2->data; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_FLOAT8 ); + assert( v2->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1); + vectorF8GetParameters(v2->data, v1->dims, &alpha2, &shift2); + + /* + * (Ax + S)^2 = A^2 x^2 + S^2 + 2AS x -> we need to maintain 'sumsq' and 'sum' + * (A1x + S1) * (A2y + S2) = A1A2 xy + A1 S2 x + A2 S1 y + S1 S2 -> we need to maintain 'dot' and 'sum' again + */ + + for(i = 0; i < v1->dims; i++){ + sum1 += data1[i]; + sum2 += data2[i]; + sumsq1 += data1[i]*data1[i]; + sumsq2 += data2[i]*data2[i]; + doti += data1[i] * data2[i]; + } + + dot = alpha1 * alpha2 * (float)doti + alpha1 * shift2 * (float)sum1 + alpha2 * shift1 * (float)sum2 + shift1 * shift2; + norm1 = alpha1 * alpha1 * (float)sumsq1 + 2 * alpha1 * shift1 * (float)sum1 + shift1 * shift1; + norm2 = alpha2 * alpha2 * (float)sumsq2 + 2 * alpha2 * shift2 * (float)sum2 + shift2 * shift2; + + return 1.0 - (dot / sqrt(norm1 * norm2)); +} + +float vectorF8DistanceL2(const Vector *v1, const Vector *v2){ + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_FLOAT8 ); + assert( v2->type == VECTOR_TYPE_FLOAT8 ); + + assert( 0 
); +} + +void vectorF8DeserializeFromBlob( + Vector *pVector, + const unsigned char *pBlob, + size_t nBlobSize +){ + float alpha, shift; + + assert( pVector->type == VECTOR_TYPE_FLOAT8 ); + assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); + + memcpy((u8*)pVector->data, (u8*)pBlob, ALIGN(pVector->dims, sizeof(float))); + + vectorF8GetParameters(pBlob, pVector->dims, &alpha, &shift); + vectorF8SetParameters(pVector->data, pVector->dims, alpha, shift); +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index be2edc9397..bb1e063325 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -53,6 +53,15 @@ do_execsql_test vector-1-func-valid { SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[-5,4]')); SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[20,4]')); SELECT vector_distance_cos(vector1bit('[10,-10]'), vector1bit('[20,-2]')); + + SELECT vector_distance_cos(vector8('[10,-10]'), vector8('[10,-10]')); + SELECT vector_distance_cos(vector32('[10,-10]'), vector32('[10,-10]')); + + SELECT vector_distance_cos(vector8('[-21,-31,0,2,2.1,2.2,105]'), vector8('[-20,-30,0,1,1.1,1.2,100]')); + SELECT vector_distance_cos(vector32('[-21,-31,0,2,2.1,2.2,105]'), vector32('[-20,-30,0,1,1.1,1.2,100]')); + + SELECT vector_distance_cos(vector8('[-20,-30,0,1,1.1,1.2,100]'), vector8('[-20,-30,0,1,1.1,1.2,10000]')); + SELECT vector_distance_cos(vector32('[-20,-30,0,1,1.1,1.2,100]'), vector32('[-20,-30,0,1,1.1,1.2,10000]')); } { {[]} {[]} @@ -71,6 +80,9 @@ do_execsql_test vector-1-func-valid { {2.0} {1.0} {0.0} + {-1.22070709096533e-08} {0.0} + {1.54134213516954e-05} {0.000117244853754528} + {-0.297326117753983} {0.0582110174000263} } do_execsql_test vector-1-conversion { @@ -88,6 +100,15 @@ do_execsql_test vector-1-conversion { SELECT 
vector_extract(vector1bit(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector1bit('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); SELECT vector_extract(vector1bit(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector32('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); SELECT vector_extract(vector1bit(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))), hex(vector1bit(vector64('[-0.000001,1e-100,1e100,-1e10,1e-10,0,1.5]'))); + + SELECT vector_extract(vector8(vector1bit('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector1bit('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))); + SELECT vector_extract(vector8(vector32('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector32('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))); + SELECT vector_extract(vector8(vector64('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector64('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))); + SELECT vector_extract(vector8(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector8(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))); + + SELECT vector_extract(vector1bit(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector1bit(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))); + SELECT vector_extract(vector32(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector32(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))); + SELECT vector_extract(vector64(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))), hex(vector64(vector8('[-20,-35.44,1,1.5,2,3,10,100,105,110]'))); } { {} 02 @@ -103,6 +124,15 @@ do_execsql_test vector-1-conversion { {[-1,-1,1,-1,1,-1,1]} 540903 {[-1,-1,1,-1,1,-1,1]} 540903 {[-1,1,1,-1,1,-1,1]} 560903 + + {[-1,-1,1,1,1,1,1,1,1,1]} 0000FFFFFFFFFFFFFFFF00008180003C000080BF000204 + {[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} 1B004041424350EDF6FF0000A702123F8FC20DC2000204 + {[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} 
1B004041424350EDF6FF0000A702123F8FC20DC2000204 + {[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} 1B004041424350EDF6FF0000A702123F8FC20DC2000204 + + {[-1,-1,1,1,1,1,1,1,1,1]} FC03001603 + {[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} E152A0C18FC20DC20003883F6004D13FD0020D408083314008032341A277C742D0BBD1420000DC42 + {[-20.0405,-35.44,1.06259,1.63295,2.2033,2.77365,10.1882,99.7337,104.867,110]} 000000205C0A34C0000000E051B841C0000000006000F13F000000008C20FA3F000000005AA001400000000070300640000000006160244000000040F4EE5840000000007A375A400000000000805B4002 } proc error_messages {sql} { diff --git a/libsql-sqlite3/tool/mksqlite3c.tcl b/libsql-sqlite3/tool/mksqlite3c.tcl index fe76abc0a8..560992a60b 100644 --- a/libsql-sqlite3/tool/mksqlite3c.tcl +++ b/libsql-sqlite3/tool/mksqlite3c.tcl @@ -472,6 +472,7 @@ set flist { vectorfloat1bit.c vectorfloat32.c vectorfloat64.c + vectorfloat8.c vectorIndex.c vectorvtab.c rtree.c From 244a34a16f2ad4c4ef4c2502f574fe0e26c37e2c Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 17:21:38 +0400 Subject: [PATCH 092/121] fix bug in cosine distance --- libsql-sqlite3/src/vector.c | 20 ++++---------------- libsql-sqlite3/src/vectorfloat8.c | 12 ++++++------ libsql-sqlite3/test/libsql_vector.test | 7 ++++--- 3 files changed, 14 insertions(+), 25 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index cf102aeb43..6d5191f33d 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -538,22 +538,7 @@ void vectorSerializeWithMeta( return; } - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - vectorF32SerializeToBlob(pVector, pBlob, nDataSize); - break; - case VECTOR_TYPE_FLOAT64: - vectorF64SerializeToBlob(pVector, pBlob, nDataSize); - break; - case VECTOR_TYPE_FLOAT1BIT: - vector1BitSerializeToBlob(pVector, pBlob, nDataSize); - break; - case VECTOR_TYPE_FLOAT8: - vectorF8SerializeToBlob(pVector, pBlob, 
nDataSize); - break; - default: - assert(0); - } + vectorSerializeToBlob(pVector, pBlob, nDataSize); vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize); sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } @@ -569,6 +554,9 @@ void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t n case VECTOR_TYPE_FLOAT1BIT: vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); break; + case VECTOR_TYPE_FLOAT8: + vectorF8SerializeToBlob(pVector, pBlob, nBlobSize); + break; default: assert(0); } diff --git a/libsql-sqlite3/src/vectorfloat8.c b/libsql-sqlite3/src/vectorfloat8.c index 3438349d9e..3e84e50844 100644 --- a/libsql-sqlite3/src/vectorfloat8.c +++ b/libsql-sqlite3/src/vectorfloat8.c @@ -100,10 +100,10 @@ float vectorF8DistanceCos(const Vector *v1, const Vector *v2){ assert( v2->type == VECTOR_TYPE_FLOAT8 ); vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1); - vectorF8GetParameters(v2->data, v1->dims, &alpha2, &shift2); + vectorF8GetParameters(v2->data, v2->dims, &alpha2, &shift2); /* - * (Ax + S)^2 = A^2 x^2 + S^2 + 2AS x -> we need to maintain 'sumsq' and 'sum' + * (Ax + S)^2 = A^2 x^2 + 2AS x + S^2 -> we need to maintain 'sumsq' and 'sum' * (A1x + S1) * (A2y + S2) = A1A2 xy + A1 S2 x + A2 S1 y + S1 S2 -> we need to maintain 'dot' and 'sum' again */ @@ -112,12 +112,12 @@ float vectorF8DistanceCos(const Vector *v1, const Vector *v2){ sum2 += data2[i]; sumsq1 += data1[i]*data1[i]; sumsq2 += data2[i]*data2[i]; - doti += data1[i] * data2[i]; + doti += data1[i]*data2[i]; } - dot = alpha1 * alpha2 * (float)doti + alpha1 * shift2 * (float)sum1 + alpha2 * shift1 * (float)sum2 + shift1 * shift2; - norm1 = alpha1 * alpha1 * (float)sumsq1 + 2 * alpha1 * shift1 * (float)sum1 + shift1 * shift1; - norm2 = alpha2 * alpha2 * (float)sumsq2 + 2 * alpha2 * shift2 * (float)sum2 + shift2 * shift2; + dot = alpha1 * alpha2 * (float)doti + alpha1 * shift2 * (float)sum1 + alpha2 * shift1 * (float)sum2 + shift1 * shift2 * v1->dims; + norm1 
= alpha1 * alpha1 * (float)sumsq1 + 2 * alpha1 * shift1 * (float)sum1 + shift1 * shift1 * v1->dims; + norm2 = alpha2 * alpha2 * (float)sumsq2 + 2 * alpha2 * shift2 * (float)sum2 + shift2 * shift2 * v1->dims; return 1.0 - (dot / sqrt(norm1 * norm2)); } diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index bb1e063325..36b5b3619c 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -80,9 +80,10 @@ do_execsql_test vector-1-func-valid { {2.0} {1.0} {0.0} - {-1.22070709096533e-08} {0.0} - {1.54134213516954e-05} {0.000117244853754528} - {-0.297326117753983} {0.0582110174000263} + + {-6.10352568486405e-09} {0.0} + {0.000111237335659098} {0.000117244853754528} + {0.0576796568930149} {0.0582110174000263} } do_execsql_test vector-1-conversion { From 34c297451e752ce69e5d7f2f61ef9ea28ba66f80 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 17:41:11 +0400 Subject: [PATCH 093/121] support float8 neighbors compression --- libsql-sqlite3/src/vectorIndex.c | 1 + libsql-sqlite3/test/libsql_vector_index.test | 29 ++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 92d3c8c83e..96f5b450c4 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -404,6 +404,7 @@ static struct VectorParamName VECTOR_PARAM_NAMES[] = { { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float8", VECTOR_TYPE_FLOAT8 }, { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 }, { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 
}, diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index f066275833..1de85f0493 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -345,6 +345,33 @@ do_execsql_test vector-f64-compress-f32 { SELECT * FROM vector_top_k('t_f64_f32_idx', vector64('[10,-10,-20,20]'), 4); } {3 1 2} +do_execsql_test vector-f32-compress-f8 { + CREATE TABLE t_f32_f8( v FLOAT32(4) ); + CREATE INDEX t_f32_f8_idx ON t_f32_f8( libsql_vector_idx(v, 'compress_neighbors=float8') ); + INSERT INTO t_f32_f8 VALUES ( vector('[1,-1,1,-1]') ); + INSERT INTO t_f32_f8 VALUES ( vector('[-1,1,1,-1]') ); + INSERT INTO t_f32_f8 VALUES ( vector('[1,-1,-1,1]') ); + SELECT * FROM vector_top_k('t_f32_f8_idx', vector('[10,-10,-20,20]'), 4); +} {3 1 2} + +do_execsql_test vector-f8 { + CREATE TABLE t_f8( v FLOAT8(4) ); + CREATE INDEX t_f8_idx ON t_f8( libsql_vector_idx(v) ); + INSERT INTO t_f8 VALUES ( vector8('[1,-1,1,-1]') ); + INSERT INTO t_f8 VALUES ( vector8('[-1,1,1,-1]') ); + INSERT INTO t_f8 VALUES ( vector8('[1,-1,-1,1]') ); + SELECT * FROM vector_top_k('t_f8_idx', vector8('[10,-10,-20,20]'), 4); +} {3 1 2} + +do_execsql_test vector-f8-compress-1bit { + CREATE TABLE t_f8_1bit( v FLOAT8(4) ); + CREATE INDEX t_f8_1bit_idx ON t_f8_1bit( libsql_vector_idx(v, 'compress_neighbors=float1bit') ); + INSERT INTO t_f8_1bit VALUES ( vector8('[1,-1,1,-1]') ); + INSERT INTO t_f8_1bit VALUES ( vector8('[-1,1,1,-1]') ); + INSERT INTO t_f8_1bit VALUES ( vector8('[1,-1,-1,1]') ); + SELECT * FROM vector_top_k('t_f8_1bit_idx', vector8('[10,-10,-20,20]'), 4); +} {3 1 2} + proc error_messages {sql} { set ret "" catch { @@ -355,6 +382,8 @@ proc error_messages {sql} { set ret [sqlite3_errmsg db] } +reset_db + do_test vector-errors { set ret [list] lappend ret [error_messages {CREATE INDEX t_no_idx ON t_no( libsql_vector_idx(v) )}] From 562680c18d3418e6280fa511e283ad71c53e040c Mon Sep 17 00:00:00 2001 From: Nikita 
Sivukhin Date: Tue, 13 Aug 2024 17:47:26 +0400 Subject: [PATCH 094/121] add simple description of the float8 layout --- libsql-sqlite3/src/vectorInt.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index 39352d6990..f8806a143e 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -37,6 +37,13 @@ typedef u32 VectorDims; * - "trailing_bits" byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted * (so, vector dimensions are equal to 8 * (blob_size - 1) - trailing_bits) * - last 'type'-byte is mandatory for float1bit vectors + * + * 4. float8 + * [data[0] as u8] [data[1] as u8] ... [data[dims - 1] as u8] [_ as u8; alignment_padding]* [alpha as f32] [shift as f32] [padding as u8] [trailing_bytes as u8] [4 as u8] + * - every data byte represents single quantized vector component + * - "alignment_padding" has size from 0 to 3 bytes in order to pad content to multiple of 4 = sizeof(float) + * - "trailing_bytes" byte specify amount of bytes in the "alignment_padding" + * - last 'type'-byte is mandatory for float8 vectors */ /* From 2e0647fd201ca4f4a6dfc8b514ae700fa1503e1d Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 18:02:25 +0400 Subject: [PATCH 095/121] expose vector_distance_l2 func - we had it before but it's harder to add tests for l2 metric without it --- libsql-sqlite3/src/vector.c | 27 +++++++++++++++++++++------ libsql-sqlite3/src/vectorInt.h | 1 + libsql-sqlite3/src/vectorfloat8.c | 14 +++++++++++++- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 6d5191f33d..b1ac6d6fb2 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -133,6 +133,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceL2(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return 
vectorF64DistanceL2(pVector1, pVector2); + case VECTOR_TYPE_FLOAT8: + return vectorF8DistanceL2(pVector1, pVector2); default: assert(0); } @@ -928,13 +930,11 @@ static void vectorExtractFunc( } } -/* -** Implementation of vector_distance_cos(X, Y) function. -*/ -static void vectorDistanceCosFunc( +static void vectorDistanceFunc( sqlite3_context *context, int argc, - sqlite3_value **argv + sqlite3_value **argv, + float (*vectorDistance)(const Vector *pVector1, const Vector *pVector2) ){ char *pzErrMsg = NULL; Vector *pVector1 = NULL, *pVector2 = NULL; @@ -983,7 +983,7 @@ static void vectorDistanceCosFunc( sqlite3_free(pzErrMsg); goto out_free; } - sqlite3_result_double(context, vectorDistanceCos(pVector1, pVector2)); + sqlite3_result_double(context, vectorDistance(pVector1, pVector2)); out_free: if( pVector2 ){ vectorFree(pVector2); @@ -993,6 +993,20 @@ static void vectorDistanceCosFunc( } } +/* +** Implementation of vector_distance_cos(X, Y) function. +*/ +static void vectorDistanceCosFunc(sqlite3_context *context, int argc, sqlite3_value **argv){ + vectorDistanceFunc(context, argc, argv, vectorDistanceCos); +} + +/* +** Implementation of vector_distance_l2(X, Y) function. 
+*/ +static void vectorDistanceL2Func(sqlite3_context *context, int argc, sqlite3_value **argv){ + vectorDistanceFunc(context, argc, argv, vectorDistanceL2); +} + /* * Marker function which is used in index creation syntax: CREATE INDEX idx ON t(libsql_vector_idx(emb)); */ @@ -1013,6 +1027,7 @@ void sqlite3RegisterVectorFunctions(void){ FUNCTION(vector8, 1, 0, 0, vector8Func), FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc), FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc), + FUNCTION(vector_distance_l2, 2, 0, 0, vectorDistanceL2Func), FUNCTION(libsql_vector_idx, -1, 0, 0, libsqlVectorIdx), }; diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index f8806a143e..1c72857326 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -120,6 +120,7 @@ int vector1BitDistanceHamming(const Vector *, const Vector *); * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ float vectorDistanceL2 (const Vector *, const Vector *); +float vectorF8DistanceL2 (const Vector *, const Vector *); float vectorF32DistanceL2 (const Vector *, const Vector *); double vectorF64DistanceL2(const Vector *, const Vector *); diff --git a/libsql-sqlite3/src/vectorfloat8.c b/libsql-sqlite3/src/vectorfloat8.c index 3e84e50844..dd02d839b4 100644 --- a/libsql-sqlite3/src/vectorfloat8.c +++ b/libsql-sqlite3/src/vectorfloat8.c @@ -123,11 +123,23 @@ float vectorF8DistanceCos(const Vector *v1, const Vector *v2){ } float vectorF8DistanceL2(const Vector *v1, const Vector *v2){ + int i; + float alpha1, shift1, alpha2, shift2; + float sum = 0; + u8 *data1 = v1->data, *data2 = v2->data; + assert( v1->dims == v2->dims ); assert( v1->type == VECTOR_TYPE_FLOAT8 ); assert( v2->type == VECTOR_TYPE_FLOAT8 ); - assert( 0 ); + vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1); + vectorF8GetParameters(v2->data, v2->dims, &alpha2, &shift2); + + for(i = 0; i < v1->dims; i++){ + float d = (alpha1 * 
data1[i] + shift1) - (alpha2 * data2[i] + shift2); + sum += d*d; + } + return sqrt(sum); } void vectorF8DeserializeFromBlob( From bd253445737c576ef10a272ba48d9f49a310fdcd Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 18:08:08 +0400 Subject: [PATCH 096/121] add tests and refine error messages --- libsql-sqlite3/src/vector.c | 10 ++++++++-- libsql-sqlite3/test/libsql_vector.test | 11 +++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index b1ac6d6fb2..8ec485c472 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -954,13 +954,19 @@ static void vectorDistanceFunc( goto out_free; } if( type1 != type2 ){ - pzErrMsg = sqlite3_mprintf("vector_distance_cos: vectors must have the same type: %d != %d", type1, type2); + pzErrMsg = sqlite3_mprintf("vector_distance: vectors must have the same type: %d != %d", type1, type2); sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; } if( dims1 != dims2 ){ - pzErrMsg = sqlite3_mprintf("vector_distance_cos: vectors must have the same length: %d != %d", dims1, dims2); + pzErrMsg = sqlite3_mprintf("vector_distance: vectors must have the same length: %d != %d", dims1, dims2); + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + if( vectorDistance == vectorDistanceL2 && type1 == VECTOR_TYPE_FLOAT1BIT ){ + pzErrMsg = sqlite3_mprintf("vector_distance: l2 distance is not supported for float1bit vectors", dims1, dims2); sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; diff --git a/libsql-sqlite3/test/libsql_vector.test b/libsql-sqlite3/test/libsql_vector.test index 36b5b3619c..31d1edc843 100644 --- a/libsql-sqlite3/test/libsql_vector.test +++ b/libsql-sqlite3/test/libsql_vector.test @@ -62,6 +62,9 @@ do_execsql_test vector-1-func-valid { SELECT vector_distance_cos(vector8('[-20,-30,0,1,1.1,1.2,100]'), 
vector8('[-20,-30,0,1,1.1,1.2,10000]')); SELECT vector_distance_cos(vector32('[-20,-30,0,1,1.1,1.2,100]'), vector32('[-20,-30,0,1,1.1,1.2,10000]')); + + SELECT vector_distance_l2(vector('[1,2,2,3,4,1,5]'), vector('[2,3,1,-1,2,4,5]')); + SELECT vector_distance_l2(vector8('[1,2,2,3,4,1,5]'), vector8('[2,3,1,-1,2,4,5]')); } { {[]} {[]} @@ -84,6 +87,8 @@ do_execsql_test vector-1-func-valid { {-6.10352568486405e-09} {0.0} {0.000111237335659098} {0.000117244853754528} {0.0576796568930149} {0.0582110174000263} + + {5.65685415267944} {5.65413522720337} } do_execsql_test vector-1-conversion { @@ -158,6 +163,7 @@ do_test vector-1-func-errors { lappend ret [error_messages {SELECT vector(x'0000000000')}] lappend ret [error_messages {SELECT vector_distance_cos('[1,2,3]', '[1,2]')}] lappend ret [error_messages {SELECT vector_distance_cos(vector32('[1,2,3]'), vector64('[1,2,3]'))}] + lappend ret [error_messages {SELECT vector_distance_l2(vector1bit('[1,2,2,3,4,1,5]'), vector1bit('[2,3,1,-1,2,4,5]'))}] } [list {*}{ {vector: unexpected value type: got FLOAT, expected TEXT or BLOB} {vector: unexpected value type: got INTEGER, expected TEXT or BLOB} @@ -169,6 +175,7 @@ do_test vector-1-func-errors { {vector: invalid float at position 2: '1.1.1'} {vector: must end with ']'} {vector: unexpected binary type: 0} - {vector_distance_cos: vectors must have the same length: 3 != 2} - {vector_distance_cos: vectors must have the same type: 1 != 2} + {vector_distance: vectors must have the same length: 3 != 2} + {vector_distance: vectors must have the same type: 1 != 2} + {vector_distance: l2 distance is not supported for float1bit vectors} }] From 4303c65d32660b7ab227132b5e33bf7f783e4fc4 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 13 Aug 2024 18:09:23 +0400 Subject: [PATCH 097/121] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 500 +++++++++++++++--- libsql-ffi/bundled/src/sqlite3.c | 500 +++++++++++++++--- 2 files changed, 842 insertions(+), 158 deletions(-) diff 
--git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index e48ad61292..493cfd263b 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -85249,13 +85249,20 @@ typedef u32 VectorDims; * - last 'type'-byte is mandatory for float64 vectors * * 3. float1bit - * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? [leftover as u8] [3 as u8] + * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? [trailing_bits as u8] [3 as u8] * - every data byte (except for the last) represents exactly 8 components of the vector * - last data byte represents [1..8] components of the vector - * - optional padding byte ensures that leftover byte will be written at the odd blob position (0-based) - * - leftover byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted - * (so, vector dimensions are equal to 8 * (blob_size - 1) - leftover) + * - optional padding byte ensures that "trailing_bits" byte will be written at the odd blob position (0-based) + * - "trailing_bits" byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted + * (so, vector dimensions are equal to 8 * (blob_size - 1) - trailing_bits) * - last 'type'-byte is mandatory for float1bit vectors + * + * 4. float8 + * [data[0] as u8] [data[1] as u8] ... 
[data[dims - 1] as u8] [_ as u8; alignment_padding]* [alpha as f32] [shift as f32] [padding as u8] [trailing_bytes as u8] [4 as u8] + * - every data byte represents single quantized vector component + * - "alignment_padding" has size from 0 to 3 bytes in order to pad content to multiple of 4 = sizeof(float) + * - "trailing_bytes" byte specify amount of bytes in the "alignment_padding" + * - last 'type'-byte is mandatory for float8 vectors */ /* @@ -85264,9 +85271,12 @@ typedef u32 VectorDims; #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 #define VECTOR_TYPE_FLOAT1BIT 3 +#define VECTOR_TYPE_FLOAT8 4 #define VECTOR_FLAGS_STATIC 1 +#define ALIGN(n, size) (((n + size - 1) / size) * size) + /* * Object which represents a vector * data points to the memory which must be interpreted according to the vector type @@ -85287,11 +85297,15 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); /* * Dumps vector on the console (used only for debugging) */ -void vectorDump (const Vector *v); +void vectorDump (const Vector *v); +void vectorF8Dump (const Vector *v); void vectorF32Dump (const Vector *v); void vectorF64Dump (const Vector *v); void vector1BitDump(const Vector *v); +void vectorF8GetParameters(const u8 *, int, float *, float *); +void vectorF8SetParameters(u8 *, int, float, float); + /* * Converts vector to the text representation and write the result to the sqlite3_context */ @@ -85302,15 +85316,17 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); +void vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +void 
vectorF8SerializeToBlob (const Vector *, unsigned char *, size_t); +void vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +void vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +void vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) */ float vectorDistanceCos (const Vector *, const Vector *); +float vectorF8DistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); @@ -85323,6 +85339,7 @@ int vector1BitDistanceHamming(const Vector *, const Vector *); * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ float vectorDistanceL2 (const Vector *, const Vector *); +float vectorF8DistanceL2 (const Vector *, const Vector *); float vectorF32DistanceL2 (const Vector *, const Vector *); double vectorF64DistanceL2(const Vector *, const Vector *); @@ -85338,6 +85355,7 @@ void vectorSerializeWithMeta(sqlite3_context *, const Vector *); */ int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); +void vectorF8DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t); @@ -85350,6 +85368,24 @@ void vectorConvert(const Vector *, Vector *); /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); +static inline unsigned serializeF32(unsigned char *pBuf, float value){ + u32 *p = (u32 *)&value; + pBuf[0] = *p & 0xFF; + pBuf[1] = (*p >> 8) & 0xFF; + pBuf[2] = (*p >> 16) & 0xFF; + pBuf[3] = (*p >> 24) & 0xFF; + return 
sizeof(float); +} + +static inline float deserializeF32(const unsigned char *pBuf){ + u32 value = 0; + value |= (u32)pBuf[0]; + value |= (u32)pBuf[1] << 8; + value |= (u32)pBuf[2] << 16; + value |= (u32)pBuf[3] << 24; + return *(float *)&value; +} + #if 0 } /* end of the 'extern "C"' block */ #endif @@ -211001,6 +211037,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(double); case VECTOR_TYPE_FLOAT1BIT: return (dims + 7) / 8; + case VECTOR_TYPE_FLOAT8: + return ALIGN(dims, sizeof(float)) + sizeof(float) /* alpha */ + sizeof(float) /* shift */; default: assert(0); } @@ -211074,6 +211112,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF64DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT1BIT: return vector1BitDistanceHamming(pVector1, pVector2); + case VECTOR_TYPE_FLOAT8: + return vectorF8DistanceCos(pVector1, pVector2); default: assert(0); } @@ -211087,6 +211127,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceL2(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceL2(pVector1, pVector2); + case VECTOR_TYPE_FLOAT8: + return vectorF8DistanceL2(pVector1, pVector2); default: assert(0); } @@ -211211,7 +211253,8 @@ static int vectorParseSqliteText( } static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pType, int *pDims, size_t *pDataSize, char **pzErrMsg){ - int nLeftoverBits; + int nTrailingBits; + int nTrailingBytes; if( nBlobSize % 2 == 0 ){ *pType = VECTOR_TYPE_FLOAT32; @@ -211224,26 +211267,34 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT if( *pType == VECTOR_TYPE_FLOAT32 ){ if( nBlobSize % 4 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float32 vector blob length must be divisible by 4 (excluding optional 
'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(float); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT64 ){ if( nBlobSize % 8 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } - nLeftoverBits = pBlob[nBlobSize - 1]; - *pDims = nBlobSize * 8 - nLeftoverBits; + nTrailingBits = pBlob[nBlobSize - 1]; + *pDims = nBlobSize * 8 - nTrailingBits; *pDataSize = (*pDims + 7) / 8; + }else if( *pType == VECTOR_TYPE_FLOAT8 ){ + if( nBlobSize < 2 || nBlobSize % 2 != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector: float8 vector blob length must be divisible by 2 and has at least 2 bytes (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + nTrailingBytes = pBlob[nBlobSize - 1]; + *pDims = (nBlobSize - 2) - sizeof(float) - sizeof(float) - nTrailingBytes; + *pDataSize = nBlobSize - 2; }else{ *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType); return SQLITE_ERROR; @@ -211289,6 +211340,9 @@ int vectorParseSqliteBlobWithType( case VECTOR_TYPE_FLOAT1BIT: vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; + case VECTOR_TYPE_FLOAT8: + vectorF8DeserializeFromBlob(pVector, pBlob, nDataSize); + return 0; default: assert(0); } @@ -211387,6 +211441,9 @@ void 
vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT1BIT: vector1BitDump(pVector); break; + case VECTOR_TYPE_FLOAT8: + vectorF8Dump(pVector); + break; default: assert(0); } @@ -211409,7 +211466,6 @@ void vectorMarshalToText( } static int vectorMetaSize(VectorType type, VectorDims dims){ - int nMetaSize = 0; int nDataSize; if( type == VECTOR_TYPE_FLOAT32 ){ return 0; @@ -211417,12 +211473,13 @@ static int vectorMetaSize(VectorType type, VectorDims dims){ return 1; }else if( type == VECTOR_TYPE_FLOAT1BIT ){ nDataSize = vectorDataSize(type, dims); - nMetaSize++; // one byte which specify amount of leftover bits - if( nDataSize % 2 == 0 ){ - nMetaSize++; // pad "leftover-bits" byte to the even length - } - nMetaSize++; // one byte for vector type - return nMetaSize; + // optional padding byte + "trailing-bits" byte + "vector-type" byte + return (nDataSize % 2 == 0 ? 1 : 0) + 1 + 1; + }else if( type == VECTOR_TYPE_FLOAT8 ){ + nDataSize = vectorDataSize(type, dims); + assert( nDataSize % 2 == 0 ); + /* padding byte + "trailing-bytes" byte + "vector-type" byte */ + return 1 + 1 + 1; }else{ assert( 0 ); } @@ -211440,6 +211497,15 @@ static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigne assert( nBlobSize >= 3 ); pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT1BIT; pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; + if( vectorMetaSize(pVector->type, pVector->dims) == 3 ){ + pBlob[nBlobSize - 3] = 0; + } + }else if( pVector->type == VECTOR_TYPE_FLOAT8 ){ + assert( nBlobSize % 2 == 1 ); + assert( nDataSize % 2 == 0 ); + assert( nBlobSize == nDataSize + 3 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT8; + pBlob[nBlobSize - 2] = ALIGN(pVector->dims, sizeof(float)) - pVector->dims; }else{ assert( 0 ); } @@ -211468,35 +211534,28 @@ void vectorSerializeWithMeta( return; } - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - vectorF32SerializeToBlob(pVector, pBlob, nDataSize); - break; - case VECTOR_TYPE_FLOAT64: - 
vectorF64SerializeToBlob(pVector, pBlob, nDataSize); - break; - case VECTOR_TYPE_FLOAT1BIT: - vector1BitSerializeToBlob(pVector, pBlob, nDataSize); - break; - default: - assert(0); - } + vectorSerializeToBlob(pVector, pBlob, nDataSize); vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize); sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } -size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ +void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); + vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); + break; case VECTOR_TYPE_FLOAT64: - return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); + vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); + break; case VECTOR_TYPE_FLOAT1BIT: - return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); + vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); + break; + case VECTOR_TYPE_FLOAT8: + vectorF8SerializeToBlob(pVector, pBlob, nBlobSize); + break; default: assert(0); } - return 0; } void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ @@ -211602,6 +211661,110 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ } } +static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){ + int i; + u8 *src; + float alpha, shift; + + float *dstF32; + double *dstF64; + u8 *dst1Bit; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(pFrom->data, pFrom->dims, &alpha, &shift); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF32[i] = alpha * src[i] + shift; + } + }else if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF64[i] = alpha 
* src[i] + shift; + } + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ + dst1Bit = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + dst1Bit[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( (alpha * src[i] + shift) > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert( 0 ); + } +} + +static inline int clip(float f, int minF, int maxF){ + if( f < minF ){ + return minF; + }else if( f > maxF ){ + return maxF; + } + return (int)(f + 0.5); +} + +#define MINMAX(i, value, minValue, maxValue) {if(i == 0){ minValue = (value); maxValue = (value);} else { minValue = MIN(minValue, (value)); maxValue = MAX(maxValue, (value)); }} + +static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){ + int i; + u8 *dst; + float alpha, shift; + float minF = 0, maxF = 0; + + float *srcF32; + double *srcF64; + u8 *src1Bit; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pTo->type == VECTOR_TYPE_FLOAT8 ); + + dst = pTo->data; + if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + srcF32 = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, srcF32[i], minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < pFrom->dims; i++){ + dst[i] = clip((srcF32[i] - shift) / alpha, 0, 255); + } + }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ + srcF64 = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, srcF64[i], minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < pFrom->dims; i++){ + dst[i] = clip((srcF64[i] - shift) / alpha, 0, 255); + } + }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ + src1Bit = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, ((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1, minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < pFrom->dims; i++){ + dst[i] = clip(((((src1Bit[i / 8] >> (i & 7)) & 1) ? 
+1 : -1) - shift) / alpha, 0, 255); + } + }else{ + assert( 0 ); + } + vectorF8SetParameters(pTo->data, pTo->dims, alpha, shift); +} + + void vectorConvert(const Vector *pFrom, Vector *pTo){ assert( pFrom->dims == pTo->dims ); @@ -211610,12 +211773,16 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){ return; } - if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + if( pTo->type == VECTOR_TYPE_FLOAT8 ){ + vectorConvertToF8(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ vectorConvertFromF32(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ vectorConvertFromF64(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ vectorConvertFrom1Bit(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT8 ){ + vectorConvertFromF8(pFrom, pTo); }else{ assert( 0 ); } @@ -211692,6 +211859,14 @@ static void vector64Func( vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT64); } +static void vector8Func( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT8); +} + static void vector1BitFunc( sqlite3_context *context, int argc, @@ -211749,13 +211924,11 @@ static void vectorExtractFunc( } } -/* -** Implementation of vector_distance_cos(X, Y) function. 
-*/ -static void vectorDistanceCosFunc( +static void vectorDistanceFunc( sqlite3_context *context, int argc, - sqlite3_value **argv + sqlite3_value **argv, + float (*vectorDistance)(const Vector *pVector1, const Vector *pVector2) ){ char *pzErrMsg = NULL; Vector *pVector1 = NULL, *pVector2 = NULL; @@ -211775,13 +211948,19 @@ static void vectorDistanceCosFunc( goto out_free; } if( type1 != type2 ){ - pzErrMsg = sqlite3_mprintf("vector_distance_cos: vectors must have the same type: %d != %d", type1, type2); + pzErrMsg = sqlite3_mprintf("vector_distance: vectors must have the same type: %d != %d", type1, type2); sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; } if( dims1 != dims2 ){ - pzErrMsg = sqlite3_mprintf("vector_distance_cos: vectors must have the same length: %d != %d", dims1, dims2); + pzErrMsg = sqlite3_mprintf("vector_distance: vectors must have the same length: %d != %d", dims1, dims2); + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + if( vectorDistance == vectorDistanceL2 && type1 == VECTOR_TYPE_FLOAT1BIT ){ + pzErrMsg = sqlite3_mprintf("vector_distance: l2 distance is not supported for float1bit vectors", dims1, dims2); sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211804,7 +211983,7 @@ static void vectorDistanceCosFunc( sqlite3_free(pzErrMsg); goto out_free; } - sqlite3_result_double(context, vectorDistanceCos(pVector1, pVector2)); + sqlite3_result_double(context, vectorDistance(pVector1, pVector2)); out_free: if( pVector2 ){ vectorFree(pVector2); @@ -211814,6 +211993,20 @@ static void vectorDistanceCosFunc( } } +/* +** Implementation of vector_distance_cos(X, Y) function. +*/ +static void vectorDistanceCosFunc(sqlite3_context *context, int argc, sqlite3_value **argv){ + vectorDistanceFunc(context, argc, argv, vectorDistanceCos); +} + +/* +** Implementation of vector_distance_l2(X, Y) function. 
+*/ +static void vectorDistanceL2Func(sqlite3_context *context, int argc, sqlite3_value **argv){ + vectorDistanceFunc(context, argc, argv, vectorDistanceL2); +} + /* * Marker function which is used in index creation syntax: CREATE INDEX idx ON t(libsql_vector_idx(emb)); */ @@ -211831,8 +212024,10 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ FUNCTION(vector32, 1, 0, 0, vector32Func), FUNCTION(vector64, 1, 0, 0, vector64Func), FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc), + FUNCTION(vector8, 1, 0, 0, vector8Func), FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc), FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc), + FUNCTION(vector_distance_l2, 2, 0, 0, vectorDistanceL2Func), FUNCTION(libsql_vector_idx, -1, 0, 0, libsqlVectorIdx), }; @@ -213674,7 +213869,7 @@ void vector1BitDump(const Vector *pVec){ ** Utility routines for vector serialization and deserialization **************************************************************************/ -size_t vector1BitSerializeToBlob( +void vector1BitSerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -213685,12 +213880,11 @@ size_t vector1BitSerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < (pVector->dims + 7) / 8; i++){ pPtr[i] = elems[i]; } - return (pVector->dims + 7) / 8; } // [sum(map(int, bin(i)[2:])) for i in range(256)] @@ -213755,7 +213949,7 @@ void vector1BitDeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); memcpy(elems, pBlob, (pVector->dims + 7) / 8); } @@ -213823,25 +214017,7 @@ static inline unsigned formatF32(float value, char *pBuf, int nBufSize){ return 
strlen(pBuf); } -static inline unsigned serializeF32(unsigned char *pBuf, float value){ - u32 *p = (u32 *)&value; - pBuf[0] = *p & 0xFF; - pBuf[1] = (*p >> 8) & 0xFF; - pBuf[2] = (*p >> 16) & 0xFF; - pBuf[3] = (*p >> 24) & 0xFF; - return sizeof(float); -} - -static inline float deserializeF32(const unsigned char *pBuf){ - u32 value = 0; - value |= (u32)pBuf[0]; - value |= (u32)pBuf[1] << 8; - value |= (u32)pBuf[2] << 16; - value |= (u32)pBuf[3] << 24; - return *(float *)&value; -} - -size_t vectorF32SerializeToBlob( +void vectorF32SerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -213853,12 +214029,11 @@ size_t vectorF32SerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(float) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ pPtr += serializeF32(pPtr, elems[i]); } - return sizeof(float) * pVector->dims; } #define SINGLE_FLOAT_CHAR_LIMIT 32 @@ -213944,7 +214119,7 @@ void vectorF32DeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(float) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); @@ -214041,7 +214216,7 @@ static inline double deserializeF64(const unsigned char *pBuf){ return *(double *)&value; } -size_t vectorF64SerializeToBlob( +void vectorF64SerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -214052,12 +214227,11 @@ size_t vectorF64SerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(double) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for (i = 0; i < pVector->dims; i++) { pPtr += 
serializeF64(pPtr, elems[i]); } - return sizeof(double) * pVector->dims; } #define SINGLE_DOUBLE_CHAR_LIMIT 32 @@ -214143,7 +214317,7 @@ void vectorF64DeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(double) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); @@ -214154,6 +214328,171 @@ void vectorF64DeserializeFromBlob( #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vectorfloat64.c ***************************************/ +/************** Begin file vectorfloat8.c ************************************/ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+** +****************************************************************************** +** +** 8-bit (INT8) floating point vector format utilities. +** +** The idea is to replace vector [f_0, f_1, ... f_k] with quantized uint8 values [q_0, q_1, ..., q_k] in such a way that +** f_i = alpha * q_i + shift, when alpha and shift determined from all f_i values like that: +** alpha = (max(f) - min(f)) / 255, shift = min(f) +** +** This differs from uint8 quantization in neural-network as it usually take form of f_i = alpha * (q_i - z) conversion instead +** But, neural-network uint8 quantization is less generic and works better for distributions centered around zero (symmetric or not) +** In our implementation we want to handle more generic cases - so profits from neural-network-style quantization are not clear +*/ +#ifndef SQLITE_OMIT_VECTOR +/* #include "sqliteInt.h" */ + +/* #include "vectorInt.h" */ + +/* #include */ + +/************************************************************************** +** Utility routines for vector serialization and deserialization +**************************************************************************/ + +void vectorF8GetParameters(const u8 *pData, int dims, float *pAlpha, float *pShift){ + pData = pData + ALIGN(dims, sizeof(float)); + *pAlpha = deserializeF32(pData); + *pShift = deserializeF32(pData + sizeof(*pAlpha)); +} + +void vectorF8SetParameters(u8 *pData, int dims, float alpha, float shift){ + pData = pData + ALIGN(dims, sizeof(float)); + serializeF32(pData, alpha); + serializeF32(pData + sizeof(alpha), shift); +} + +void vectorF8Dump(const Vector *pVec){ + u8 *elems = pVec->data; + float alpha, shift; + unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(pVec->data, pVec->dims, &alpha, &shift); + + printf("f8: ["); + for(i = 0; i < pVec->dims; i++){ + printf("%s%f", i == 0 ? 
"" : ", ", (float)elems[i] * alpha + shift); + } + printf("]\n"); +} + +void vectorF8SerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + float alpha, shift; + + assert( pVector->type == VECTOR_TYPE_FLOAT8 ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); + + memcpy(pBlob, pVector->data, pVector->dims); + + vectorF8GetParameters(pVector->data, pVector->dims, &alpha, &shift); + vectorF8SetParameters(pBlob, pVector->dims, alpha, shift); +} + +float vectorF8DistanceCos(const Vector *v1, const Vector *v2){ + int i; + float alpha1, shift1, alpha2, shift2; + u32 sum1 = 0, sum2 = 0, sumsq1 = 0, sumsq2 = 0, doti = 0; + float dot = 0, norm1 = 0, norm2 = 0; + u8 *data1 = v1->data, *data2 = v2->data; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_FLOAT8 ); + assert( v2->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1); + vectorF8GetParameters(v2->data, v2->dims, &alpha2, &shift2); + + /* + * (Ax + S)^2 = A^2 x^2 + 2AS x + S^2 -> we need to maintain 'sumsq' and 'sum' + * (A1x + S1) * (A2y + S2) = A1A2 xy + A1 S2 x + A2 S1 y + S1 S2 -> we need to maintain 'dot' and 'sum' again + */ + + for(i = 0; i < v1->dims; i++){ + sum1 += data1[i]; + sum2 += data2[i]; + sumsq1 += data1[i]*data1[i]; + sumsq2 += data2[i]*data2[i]; + doti += data1[i]*data2[i]; + } + + dot = alpha1 * alpha2 * (float)doti + alpha1 * shift2 * (float)sum1 + alpha2 * shift1 * (float)sum2 + shift1 * shift2 * v1->dims; + norm1 = alpha1 * alpha1 * (float)sumsq1 + 2 * alpha1 * shift1 * (float)sum1 + shift1 * shift1 * v1->dims; + norm2 = alpha2 * alpha2 * (float)sumsq2 + 2 * alpha2 * shift2 * (float)sum2 + shift2 * shift2 * v1->dims; + + return 1.0 - (dot / sqrt(norm1 * norm2)); +} + +float vectorF8DistanceL2(const Vector *v1, const Vector *v2){ + int i; + float alpha1, shift1, alpha2, shift2; + float sum = 0; + u8 *data1 = v1->data, *data2 = 
v2->data; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_FLOAT8 ); + assert( v2->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1); + vectorF8GetParameters(v2->data, v2->dims, &alpha2, &shift2); + + for(i = 0; i < v1->dims; i++){ + float d = (alpha1 * data1[i] + shift1) - (alpha2 * data2[i] + shift2); + sum += d*d; + } + return sqrt(sum); +} + +void vectorF8DeserializeFromBlob( + Vector *pVector, + const unsigned char *pBlob, + size_t nBlobSize +){ + float alpha, shift; + + assert( pVector->type == VECTOR_TYPE_FLOAT8 ); + assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); + + memcpy((u8*)pVector->data, (u8*)pBlob, ALIGN(pVector->dims, sizeof(float))); + + vectorF8GetParameters(pBlob, pVector->dims, &alpha, &shift); + vectorF8SetParameters(pVector->data, pVector->dims, alpha, shift); +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ + +/************** End of vectorfloat8.c ****************************************/ /************** Begin file vectorIndex.c *************************************/ /* ** 2024-03-18 @@ -214540,6 +214879,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, { "FLOAT1BIT", VECTOR_TYPE_FLOAT1BIT }, { "F1BIT_BLOB", VECTOR_TYPE_FLOAT1BIT }, + { "FLOAT8", VECTOR_TYPE_FLOAT8 }, + { "F8_BLOB", VECTOR_TYPE_FLOAT8 }, }; /* @@ -214559,6 +214900,7 @@ static struct VectorParamName VECTOR_PARAM_NAMES[] = { { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float8", VECTOR_TYPE_FLOAT8 }, { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 }, { "alpha", 
VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index e48ad61292..493cfd263b 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -85249,13 +85249,20 @@ typedef u32 VectorDims; * - last 'type'-byte is mandatory for float64 vectors * * 3. float1bit - * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? [leftover as u8] [3 as u8] + * [data[0] as u8] [data[1] as u8] ... [data[(dims + 7) / 8] as u8] [_ as u8; padding]? [trailing_bits as u8] [3 as u8] * - every data byte (except for the last) represents exactly 8 components of the vector * - last data byte represents [1..8] components of the vector - * - optional padding byte ensures that leftover byte will be written at the odd blob position (0-based) - * - leftover byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted - * (so, vector dimensions are equal to 8 * (blob_size - 1) - leftover) + * - optional padding byte ensures that "trailing_bits" byte will be written at the odd blob position (0-based) + * - "trailing_bits" byte specify amount of trailing *bits* in the blob without last 'type'-byte which must be omitted + * (so, vector dimensions are equal to 8 * (blob_size - 1) - trailing_bits) * - last 'type'-byte is mandatory for float1bit vectors + * + * 4. float8 + * [data[0] as u8] [data[1] as u8] ... 
[data[dims - 1] as u8] [_ as u8; alignment_padding]* [alpha as f32] [shift as f32] [padding as u8] [trailing_bytes as u8] [4 as u8] + * - every data byte represents single quantized vector component + * - "alignment_padding" has size from 0 to 3 bytes in order to pad content to multiple of 4 = sizeof(float) + * - "trailing_bytes" byte specify amount of bytes in the "alignment_padding" + * - last 'type'-byte is mandatory for float8 vectors */ /* @@ -85264,9 +85271,12 @@ typedef u32 VectorDims; #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 #define VECTOR_TYPE_FLOAT1BIT 3 +#define VECTOR_TYPE_FLOAT8 4 #define VECTOR_FLAGS_STATIC 1 +#define ALIGN(n, size) (((n + size - 1) / size) * size) + /* * Object which represents a vector * data points to the memory which must be interpreted according to the vector type @@ -85287,11 +85297,15 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); /* * Dumps vector on the console (used only for debugging) */ -void vectorDump (const Vector *v); +void vectorDump (const Vector *v); +void vectorF8Dump (const Vector *v); void vectorF32Dump (const Vector *v); void vectorF64Dump (const Vector *v); void vector1BitDump(const Vector *v); +void vectorF8GetParameters(const u8 *, int, float *, float *); +void vectorF8SetParameters(u8 *, int, float, float); + /* * Converts vector to the text representation and write the result to the sqlite3_context */ @@ -85302,15 +85316,17 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); +void vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +void 
vectorF8SerializeToBlob (const Vector *, unsigned char *, size_t); +void vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +void vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +void vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) */ float vectorDistanceCos (const Vector *, const Vector *); +float vectorF8DistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); @@ -85323,6 +85339,7 @@ int vector1BitDistanceHamming(const Vector *, const Vector *); * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ float vectorDistanceL2 (const Vector *, const Vector *); +float vectorF8DistanceL2 (const Vector *, const Vector *); float vectorF32DistanceL2 (const Vector *, const Vector *); double vectorF64DistanceL2(const Vector *, const Vector *); @@ -85338,6 +85355,7 @@ void vectorSerializeWithMeta(sqlite3_context *, const Vector *); */ int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); +void vectorF8DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t); void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t); @@ -85350,6 +85368,24 @@ void vectorConvert(const Vector *, Vector *); /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); +static inline unsigned serializeF32(unsigned char *pBuf, float value){ + u32 *p = (u32 *)&value; + pBuf[0] = *p & 0xFF; + pBuf[1] = (*p >> 8) & 0xFF; + pBuf[2] = (*p >> 16) & 0xFF; + pBuf[3] = (*p >> 24) & 0xFF; + return 
sizeof(float); +} + +static inline float deserializeF32(const unsigned char *pBuf){ + u32 value = 0; + value |= (u32)pBuf[0]; + value |= (u32)pBuf[1] << 8; + value |= (u32)pBuf[2] << 16; + value |= (u32)pBuf[3] << 24; + return *(float *)&value; +} + #if 0 } /* end of the 'extern "C"' block */ #endif @@ -211001,6 +211037,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(double); case VECTOR_TYPE_FLOAT1BIT: return (dims + 7) / 8; + case VECTOR_TYPE_FLOAT8: + return ALIGN(dims, sizeof(float)) + sizeof(float) /* alpha */ + sizeof(float) /* shift */; default: assert(0); } @@ -211074,6 +211112,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF64DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT1BIT: return vector1BitDistanceHamming(pVector1, pVector2); + case VECTOR_TYPE_FLOAT8: + return vectorF8DistanceCos(pVector1, pVector2); default: assert(0); } @@ -211087,6 +211127,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceL2(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceL2(pVector1, pVector2); + case VECTOR_TYPE_FLOAT8: + return vectorF8DistanceL2(pVector1, pVector2); default: assert(0); } @@ -211211,7 +211253,8 @@ static int vectorParseSqliteText( } static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pType, int *pDims, size_t *pDataSize, char **pzErrMsg){ - int nLeftoverBits; + int nTrailingBits; + int nTrailingBytes; if( nBlobSize % 2 == 0 ){ *pType = VECTOR_TYPE_FLOAT32; @@ -211224,26 +211267,34 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT if( *pType == VECTOR_TYPE_FLOAT32 ){ if( nBlobSize % 4 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float32 vector blob length must be divisible by 4 (excluding optional 
'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(float); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT64 ){ if( nBlobSize % 8 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } *pDims = nBlobSize / sizeof(double); *pDataSize = nBlobSize; }else if( *pType == VECTOR_TYPE_FLOAT1BIT ){ if( nBlobSize == 0 || nBlobSize % 2 != 0 ){ - *pzErrMsg = sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); + *pzErrMsg = sqlite3_mprintf("vector: float1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize); return SQLITE_ERROR; } - nLeftoverBits = pBlob[nBlobSize - 1]; - *pDims = nBlobSize * 8 - nLeftoverBits; + nTrailingBits = pBlob[nBlobSize - 1]; + *pDims = nBlobSize * 8 - nTrailingBits; *pDataSize = (*pDims + 7) / 8; + }else if( *pType == VECTOR_TYPE_FLOAT8 ){ + if( nBlobSize < 2 || nBlobSize % 2 != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector: float8 vector blob length must be divisible by 2 and has at least 2 bytes (excluding 'type'-byte): length=%d", nBlobSize); + return SQLITE_ERROR; + } + nTrailingBytes = pBlob[nBlobSize - 1]; + *pDims = (nBlobSize - 2) - sizeof(float) - sizeof(float) - nTrailingBytes; + *pDataSize = nBlobSize - 2; }else{ *pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType); return SQLITE_ERROR; @@ -211289,6 +211340,9 @@ int vectorParseSqliteBlobWithType( case VECTOR_TYPE_FLOAT1BIT: vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize); return 0; + case VECTOR_TYPE_FLOAT8: + vectorF8DeserializeFromBlob(pVector, pBlob, nDataSize); + return 0; default: assert(0); } @@ -211387,6 +211441,9 @@ void 
vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT1BIT: vector1BitDump(pVector); break; + case VECTOR_TYPE_FLOAT8: + vectorF8Dump(pVector); + break; default: assert(0); } @@ -211409,7 +211466,6 @@ void vectorMarshalToText( } static int vectorMetaSize(VectorType type, VectorDims dims){ - int nMetaSize = 0; int nDataSize; if( type == VECTOR_TYPE_FLOAT32 ){ return 0; @@ -211417,12 +211473,13 @@ static int vectorMetaSize(VectorType type, VectorDims dims){ return 1; }else if( type == VECTOR_TYPE_FLOAT1BIT ){ nDataSize = vectorDataSize(type, dims); - nMetaSize++; // one byte which specify amount of leftover bits - if( nDataSize % 2 == 0 ){ - nMetaSize++; // pad "leftover-bits" byte to the even length - } - nMetaSize++; // one byte for vector type - return nMetaSize; + // optional padding byte + "trailing-bits" byte + "vector-type" byte + return (nDataSize % 2 == 0 ? 1 : 0) + 1 + 1; + }else if( type == VECTOR_TYPE_FLOAT8 ){ + nDataSize = vectorDataSize(type, dims); + assert( nDataSize % 2 == 0 ); + /* padding byte + "trailing-bytes" byte + "vector-type" byte */ + return 1 + 1 + 1; }else{ assert( 0 ); } @@ -211440,6 +211497,15 @@ static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigne assert( nBlobSize >= 3 ); pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT1BIT; pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims; + if( vectorMetaSize(pVector->type, pVector->dims) == 3 ){ + pBlob[nBlobSize - 3] = 0; + } + }else if( pVector->type == VECTOR_TYPE_FLOAT8 ){ + assert( nBlobSize % 2 == 1 ); + assert( nDataSize % 2 == 0 ); + assert( nBlobSize == nDataSize + 3 ); + pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT8; + pBlob[nBlobSize - 2] = ALIGN(pVector->dims, sizeof(float)) - pVector->dims; }else{ assert( 0 ); } @@ -211468,35 +211534,28 @@ void vectorSerializeWithMeta( return; } - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - vectorF32SerializeToBlob(pVector, pBlob, nDataSize); - break; - case VECTOR_TYPE_FLOAT64: - 
vectorF64SerializeToBlob(pVector, pBlob, nDataSize); - break; - case VECTOR_TYPE_FLOAT1BIT: - vector1BitSerializeToBlob(pVector, pBlob, nDataSize); - break; - default: - assert(0); - } + vectorSerializeToBlob(pVector, pBlob, nDataSize); vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize); sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } -size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ +void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); + vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); + break; case VECTOR_TYPE_FLOAT64: - return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); + vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); + break; case VECTOR_TYPE_FLOAT1BIT: - return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); + vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); + break; + case VECTOR_TYPE_FLOAT8: + vectorF8SerializeToBlob(pVector, pBlob, nBlobSize); + break; default: assert(0); } - return 0; } void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ @@ -211602,6 +211661,110 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){ } } +static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){ + int i; + u8 *src; + float alpha, shift; + + float *dstF32; + double *dstF64; + u8 *dst1Bit; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pFrom->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(pFrom->data, pFrom->dims, &alpha, &shift); + + src = pFrom->data; + if( pTo->type == VECTOR_TYPE_FLOAT32 ){ + dstF32 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF32[i] = alpha * src[i] + shift; + } + }else if( pTo->type == VECTOR_TYPE_FLOAT64 ){ + dstF64 = pTo->data; + for(i = 0; i < pFrom->dims; i++){ + dstF64[i] = alpha 
* src[i] + shift; + } + }else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){ + dst1Bit = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + dst1Bit[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( (alpha * src[i] + shift) > 0 ){ + dst1Bit[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert( 0 ); + } +} + +static inline int clip(float f, int minF, int maxF){ + if( f < minF ){ + return minF; + }else if( f > maxF ){ + return maxF; + } + return (int)(f + 0.5); +} + +#define MINMAX(i, value, minValue, maxValue) {if(i == 0){ minValue = (value); maxValue = (value);} else { minValue = MIN(minValue, (value)); maxValue = MAX(maxValue, (value)); }} + +static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){ + int i; + u8 *dst; + float alpha, shift; + float minF = 0, maxF = 0; + + float *srcF32; + double *srcF64; + u8 *src1Bit; + + assert( pFrom->dims == pTo->dims ); + assert( pFrom->type != pTo->type ); + assert( pTo->type == VECTOR_TYPE_FLOAT8 ); + + dst = pTo->data; + if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + srcF32 = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, srcF32[i], minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < pFrom->dims; i++){ + dst[i] = clip((srcF32[i] - shift) / alpha, 0, 255); + } + }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ + srcF64 = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, srcF64[i], minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < pFrom->dims; i++){ + dst[i] = clip((srcF64[i] - shift) / alpha, 0, 255); + } + }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ + src1Bit = pFrom->data; + for(i = 0; i < pFrom->dims; i++){ + MINMAX(i, ((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1, minF, maxF); + } + shift = minF; + alpha = (maxF - minF) / 255; + for(i = 0; i < pFrom->dims; i++){ + dst[i] = clip(((((src1Bit[i / 8] >> (i & 7)) & 1) ? 
+1 : -1) - shift) / alpha, 0, 255); + } + }else{ + assert( 0 ); + } + vectorF8SetParameters(pTo->data, pTo->dims, alpha, shift); +} + + void vectorConvert(const Vector *pFrom, Vector *pTo){ assert( pFrom->dims == pTo->dims ); @@ -211610,12 +211773,16 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){ return; } - if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ + if( pTo->type == VECTOR_TYPE_FLOAT8 ){ + vectorConvertToF8(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT32 ){ vectorConvertFromF32(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){ vectorConvertFromF64(pFrom, pTo); }else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){ vectorConvertFrom1Bit(pFrom, pTo); + }else if( pFrom->type == VECTOR_TYPE_FLOAT8 ){ + vectorConvertFromF8(pFrom, pTo); }else{ assert( 0 ); } @@ -211692,6 +211859,14 @@ static void vector64Func( vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT64); } +static void vector8Func( + sqlite3_context *context, + int argc, + sqlite3_value **argv +){ + vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT8); +} + static void vector1BitFunc( sqlite3_context *context, int argc, @@ -211749,13 +211924,11 @@ static void vectorExtractFunc( } } -/* -** Implementation of vector_distance_cos(X, Y) function. 
-*/ -static void vectorDistanceCosFunc( +static void vectorDistanceFunc( sqlite3_context *context, int argc, - sqlite3_value **argv + sqlite3_value **argv, + float (*vectorDistance)(const Vector *pVector1, const Vector *pVector2) ){ char *pzErrMsg = NULL; Vector *pVector1 = NULL, *pVector2 = NULL; @@ -211775,13 +211948,19 @@ static void vectorDistanceCosFunc( goto out_free; } if( type1 != type2 ){ - pzErrMsg = sqlite3_mprintf("vector_distance_cos: vectors must have the same type: %d != %d", type1, type2); + pzErrMsg = sqlite3_mprintf("vector_distance: vectors must have the same type: %d != %d", type1, type2); sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; } if( dims1 != dims2 ){ - pzErrMsg = sqlite3_mprintf("vector_distance_cos: vectors must have the same length: %d != %d", dims1, dims2); + pzErrMsg = sqlite3_mprintf("vector_distance: vectors must have the same length: %d != %d", dims1, dims2); + sqlite3_result_error(context, pzErrMsg, -1); + sqlite3_free(pzErrMsg); + goto out_free; + } + if( vectorDistance == vectorDistanceL2 && type1 == VECTOR_TYPE_FLOAT1BIT ){ + pzErrMsg = sqlite3_mprintf("vector_distance: l2 distance is not supported for float1bit vectors", dims1, dims2); sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211804,7 +211983,7 @@ static void vectorDistanceCosFunc( sqlite3_free(pzErrMsg); goto out_free; } - sqlite3_result_double(context, vectorDistanceCos(pVector1, pVector2)); + sqlite3_result_double(context, vectorDistance(pVector1, pVector2)); out_free: if( pVector2 ){ vectorFree(pVector2); @@ -211814,6 +211993,20 @@ static void vectorDistanceCosFunc( } } +/* +** Implementation of vector_distance_cos(X, Y) function. +*/ +static void vectorDistanceCosFunc(sqlite3_context *context, int argc, sqlite3_value **argv){ + vectorDistanceFunc(context, argc, argv, vectorDistanceCos); +} + +/* +** Implementation of vector_distance_l2(X, Y) function. 
+*/ +static void vectorDistanceL2Func(sqlite3_context *context, int argc, sqlite3_value **argv){ + vectorDistanceFunc(context, argc, argv, vectorDistanceL2); +} + /* * Marker function which is used in index creation syntax: CREATE INDEX idx ON t(libsql_vector_idx(emb)); */ @@ -211831,8 +212024,10 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ FUNCTION(vector32, 1, 0, 0, vector32Func), FUNCTION(vector64, 1, 0, 0, vector64Func), FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc), + FUNCTION(vector8, 1, 0, 0, vector8Func), FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc), FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc), + FUNCTION(vector_distance_l2, 2, 0, 0, vectorDistanceL2Func), FUNCTION(libsql_vector_idx, -1, 0, 0, libsqlVectorIdx), }; @@ -213674,7 +213869,7 @@ void vector1BitDump(const Vector *pVec){ ** Utility routines for vector serialization and deserialization **************************************************************************/ -size_t vector1BitSerializeToBlob( +void vector1BitSerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -213685,12 +213880,11 @@ size_t vector1BitSerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < (pVector->dims + 7) / 8; i++){ pPtr[i] = elems[i]; } - return (pVector->dims + 7) / 8; } // [sum(map(int, bin(i)[2:])) for i in range(256)] @@ -213755,7 +213949,7 @@ void vector1BitDeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT1BIT ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= (pVector->dims + 7) / 8 ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); memcpy(elems, pBlob, (pVector->dims + 7) / 8); } @@ -213823,25 +214017,7 @@ static inline unsigned formatF32(float value, char *pBuf, int nBufSize){ return 
strlen(pBuf); } -static inline unsigned serializeF32(unsigned char *pBuf, float value){ - u32 *p = (u32 *)&value; - pBuf[0] = *p & 0xFF; - pBuf[1] = (*p >> 8) & 0xFF; - pBuf[2] = (*p >> 16) & 0xFF; - pBuf[3] = (*p >> 24) & 0xFF; - return sizeof(float); -} - -static inline float deserializeF32(const unsigned char *pBuf){ - u32 value = 0; - value |= (u32)pBuf[0]; - value |= (u32)pBuf[1] << 8; - value |= (u32)pBuf[2] << 16; - value |= (u32)pBuf[3] << 24; - return *(float *)&value; -} - -size_t vectorF32SerializeToBlob( +void vectorF32SerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -213853,12 +214029,11 @@ size_t vectorF32SerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(float) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ pPtr += serializeF32(pPtr, elems[i]); } - return sizeof(float) * pVector->dims; } #define SINGLE_FLOAT_CHAR_LIMIT 32 @@ -213944,7 +214119,7 @@ void vectorF32DeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(float) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); @@ -214041,7 +214216,7 @@ static inline double deserializeF64(const unsigned char *pBuf){ return *(double *)&value; } -size_t vectorF64SerializeToBlob( +void vectorF64SerializeToBlob( const Vector *pVector, unsigned char *pBlob, size_t nBlobSize @@ -214052,12 +214227,11 @@ size_t vectorF64SerializeToBlob( assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(double) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for (i = 0; i < pVector->dims; i++) { pPtr += 
serializeF64(pPtr, elems[i]); } - return sizeof(double) * pVector->dims; } #define SINGLE_DOUBLE_CHAR_LIMIT 32 @@ -214143,7 +214317,7 @@ void vectorF64DeserializeFromBlob( assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize >= pVector->dims * sizeof(double) ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); @@ -214154,6 +214328,171 @@ void vectorF64DeserializeFromBlob( #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vectorfloat64.c ***************************************/ +/************** Begin file vectorfloat8.c ************************************/ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+**
+******************************************************************************
+**
+** 8-bit (INT8) floating point vector format utilities.
+**
+** The idea is to replace a vector [f_0, f_1, ... f_k] with quantized uint8 values [q_0, q_1, ..., q_k] in such a way that
+** f_i = alpha * q_i + shift, where alpha and shift are determined from all f_i values as follows:
+** alpha = (max(f) - min(f)) / 255, shift = min(f)
+**
+** This differs from uint8 quantization in neural networks, which usually takes the form of an f_i = alpha * (q_i - z) conversion instead.
+** However, neural-network uint8 quantization is less generic and works better for distributions centered around zero (symmetric or not).
+** In our implementation we want to handle more generic cases, so the benefits of neural-network-style quantization are not clear.
+*/
+#ifndef SQLITE_OMIT_VECTOR
+/* #include "sqliteInt.h" */
+
+/* #include "vectorInt.h" */
+
+/* #include  */
+
+/**************************************************************************
+** Utility routines for vector serialization and deserialization
+**************************************************************************/
+
+void vectorF8GetParameters(const u8 *pData, int dims, float *pAlpha, float *pShift){
+  pData = pData + ALIGN(dims, sizeof(float));
+  *pAlpha = deserializeF32(pData);
+  *pShift = deserializeF32(pData + sizeof(*pAlpha));
+}
+
+void vectorF8SetParameters(u8 *pData, int dims, float alpha, float shift){
+  pData = pData + ALIGN(dims, sizeof(float));
+  serializeF32(pData, alpha);
+  serializeF32(pData + sizeof(alpha), shift);
+}
+
+void vectorF8Dump(const Vector *pVec){
+  u8 *elems = pVec->data;
+  float alpha, shift;
+  unsigned i;
+
+  assert( pVec->type == VECTOR_TYPE_FLOAT8 );
+
+  vectorF8GetParameters(pVec->data, pVec->dims, &alpha, &shift);
+
+  printf("f8: [");
+  for(i = 0; i < pVec->dims; i++){
+    printf("%s%f", i == 0 ? 
"" : ", ", (float)elems[i] * alpha + shift); + } + printf("]\n"); +} + +void vectorF8SerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + float alpha, shift; + + assert( pVector->type == VECTOR_TYPE_FLOAT8 ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); + + memcpy(pBlob, pVector->data, pVector->dims); + + vectorF8GetParameters(pVector->data, pVector->dims, &alpha, &shift); + vectorF8SetParameters(pBlob, pVector->dims, alpha, shift); +} + +float vectorF8DistanceCos(const Vector *v1, const Vector *v2){ + int i; + float alpha1, shift1, alpha2, shift2; + u32 sum1 = 0, sum2 = 0, sumsq1 = 0, sumsq2 = 0, doti = 0; + float dot = 0, norm1 = 0, norm2 = 0; + u8 *data1 = v1->data, *data2 = v2->data; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_FLOAT8 ); + assert( v2->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1); + vectorF8GetParameters(v2->data, v2->dims, &alpha2, &shift2); + + /* + * (Ax + S)^2 = A^2 x^2 + 2AS x + S^2 -> we need to maintain 'sumsq' and 'sum' + * (A1x + S1) * (A2y + S2) = A1A2 xy + A1 S2 x + A2 S1 y + S1 S2 -> we need to maintain 'dot' and 'sum' again + */ + + for(i = 0; i < v1->dims; i++){ + sum1 += data1[i]; + sum2 += data2[i]; + sumsq1 += data1[i]*data1[i]; + sumsq2 += data2[i]*data2[i]; + doti += data1[i]*data2[i]; + } + + dot = alpha1 * alpha2 * (float)doti + alpha1 * shift2 * (float)sum1 + alpha2 * shift1 * (float)sum2 + shift1 * shift2 * v1->dims; + norm1 = alpha1 * alpha1 * (float)sumsq1 + 2 * alpha1 * shift1 * (float)sum1 + shift1 * shift1 * v1->dims; + norm2 = alpha2 * alpha2 * (float)sumsq2 + 2 * alpha2 * shift2 * (float)sum2 + shift2 * shift2 * v1->dims; + + return 1.0 - (dot / sqrt(norm1 * norm2)); +} + +float vectorF8DistanceL2(const Vector *v1, const Vector *v2){ + int i; + float alpha1, shift1, alpha2, shift2; + float sum = 0; + u8 *data1 = v1->data, *data2 = 
v2->data; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_FLOAT8 ); + assert( v2->type == VECTOR_TYPE_FLOAT8 ); + + vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1); + vectorF8GetParameters(v2->data, v2->dims, &alpha2, &shift2); + + for(i = 0; i < v1->dims; i++){ + float d = (alpha1 * data1[i] + shift1) - (alpha2 * data2[i] + shift2); + sum += d*d; + } + return sqrt(sum); +} + +void vectorF8DeserializeFromBlob( + Vector *pVector, + const unsigned char *pBlob, + size_t nBlobSize +){ + float alpha, shift; + + assert( pVector->type == VECTOR_TYPE_FLOAT8 ); + assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) ); + + memcpy((u8*)pVector->data, (u8*)pBlob, ALIGN(pVector->dims, sizeof(float))); + + vectorF8GetParameters(pBlob, pVector->dims, &alpha, &shift); + vectorF8SetParameters(pVector->data, pVector->dims, alpha, shift); +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ + +/************** End of vectorfloat8.c ****************************************/ /************** Begin file vectorIndex.c *************************************/ /* ** 2024-03-18 @@ -214540,6 +214879,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = { { "F64_BLOB", VECTOR_TYPE_FLOAT64 }, { "FLOAT1BIT", VECTOR_TYPE_FLOAT1BIT }, { "F1BIT_BLOB", VECTOR_TYPE_FLOAT1BIT }, + { "FLOAT8", VECTOR_TYPE_FLOAT8 }, + { "F8_BLOB", VECTOR_TYPE_FLOAT8 }, }; /* @@ -214559,6 +214900,7 @@ static struct VectorParamName VECTOR_PARAM_NAMES[] = { { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float8", VECTOR_TYPE_FLOAT8 }, { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 }, { "alpha", 
VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, From 7f8a069721c7b110c24e240cd8253fda18761950 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Tue, 13 Aug 2024 16:27:37 -0400 Subject: [PATCH 098/121] libsql: release 0.5.1 --- Cargo.lock | 2 +- libsql/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7e19e03e9a..c3b9c6b02b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3430,7 +3430,7 @@ dependencies = [ [[package]] name = "libsql" -version = "0.5.0" +version = "0.5.1" dependencies = [ "anyhow", "async-stream", diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index 3d65f71c73..aa78a1bf0c 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql" -version = "0.5.0" +version = "0.5.1" edition = "2021" description = "libSQL library: the main gateway for interacting with the database" repository = "https://github.com/tursodatabase/libsql" From 3da2a8bd2b1bc5e93354611f2758165d7ab1a8a6 Mon Sep 17 00:00:00 2001 From: Piotr Jastrzebski Date: Tue, 13 Aug 2024 20:12:09 +0200 Subject: [PATCH 099/121] libsql: Add max_write_replication_index field to Database This field will store replication index of the latest write performed using any connection created with this Database object. This will allow a client to know what is a minimal replication index they have to use to see all the writes they performed so far. 
Signed-off-by: Piotr Jastrzebski --- libsql/src/database.rs | 17 ++++++++++++++--- libsql/src/database/builder.rs | 5 +++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/libsql/src/database.rs b/libsql/src/database.rs index e87def367d..9bfa727703 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -7,9 +7,10 @@ pub use builder::Builder; #[cfg(feature = "core")] pub use libsql_sys::{Cipher, EncryptionConfig}; -use std::fmt; - use crate::{Connection, Result}; +use std::fmt; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; cfg_core! { bitflags::bitflags! { @@ -76,6 +77,9 @@ impl fmt::Debug for DbType { /// not do much work until the [`Database::connect`] fn is called. pub struct Database { db_type: DbType, + /// The maximum replication index returned from a write performed using any connection created using this Database object. + #[allow(dead_code)] + max_write_replication_index: Arc, } cfg_core! { @@ -87,6 +91,7 @@ cfg_core! { Ok(Database { db_type: DbType::Memory { db }, + max_write_replication_index: Default::default(), }) } @@ -105,6 +110,7 @@ cfg_core! { flags, encryption_config: None, }, + max_write_replication_index: Default::default(), }) } } @@ -130,6 +136,7 @@ cfg_replication! { Ok(Database { db_type: DbType::Sync { db, encryption_config }, + max_write_replication_index: Default::default(), }) } @@ -191,6 +198,7 @@ cfg_replication! { Ok(Database { db_type: DbType::Sync { db, encryption_config }, + max_write_replication_index: Default::default(), }) } @@ -317,6 +325,7 @@ cfg_replication! { Ok(Database { db_type: DbType::Sync { db, encryption_config }, + max_write_replication_index: Default::default(), }) } @@ -372,7 +381,8 @@ cfg_replication! { DbType::Sync { db, .. 
} => { let path = db.path().to_string(); Ok(Database { - db_type: DbType::File { path, flags: OpenFlags::default(), encryption_config: None} + db_type: DbType::File { path, flags: OpenFlags::default(), encryption_config: None}, + max_write_replication_index: Default::default(), }) } t => Err(Error::FreezeNotSupported(format!("{:?}", t))) @@ -445,6 +455,7 @@ cfg_remote! { connector: crate::util::ConnectorService::new(svc), version, }, + max_write_replication_index: Default::default(), }) } } diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index 8749b6452b..35cd93f899 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -135,6 +135,7 @@ cfg_core! { let db = crate::local::Database::open(":memory:", crate::OpenFlags::default())?; Database { db_type: DbType::Memory { db } , + max_write_replication_index: Default::default(), } } else { let path = self @@ -150,6 +151,7 @@ cfg_core! { flags: self.inner.flags, encryption_config: self.inner.encryption_config, }, + max_write_replication_index: Default::default(), } }; @@ -291,6 +293,7 @@ cfg_replication! { Ok(Database { db_type: DbType::Sync { db, encryption_config }, + max_write_replication_index: Default::default(), }) } } @@ -360,6 +363,7 @@ cfg_replication! { Ok(Database { db_type: DbType::Sync { db, encryption_config }, + max_write_replication_index: Default::default(), }) } } @@ -414,6 +418,7 @@ cfg_remote! 
{ connector, version, }, + max_write_replication_index: Default::default(), }) } } From 8fb997e77c3c6fd66ab5e33f614dc09fce78a031 Mon Sep 17 00:00:00 2001 From: Piotr Jastrzebski Date: Tue, 13 Aug 2024 20:26:49 +0200 Subject: [PATCH 100/121] libsql: Add max_write_replication_index method to Database Signed-off-by: Piotr Jastrzebski --- libsql/src/database.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/libsql/src/database.rs b/libsql/src/database.rs index 9bfa727703..0109cab4bd 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -388,6 +388,18 @@ cfg_replication! { t => Err(Error::FreezeNotSupported(format!("{:?}", t))) } } + + /// Get the maximum replication index returned from a write performed using any connection created using this Database object. + pub fn max_write_replication_index(&self) -> Option { + let index = self + .max_write_replication_index + .load(std::sync::atomic::Ordering::SeqCst); + if index == 0 { + None + } else { + Some(index) + } + } } } From 6c5193db6465069c9d4366259206335f2504cb2c Mon Sep 17 00:00:00 2001 From: Piotr Jastrzebski Date: Tue, 13 Aug 2024 20:32:43 +0200 Subject: [PATCH 101/121] libsql: Add max_write_replication_index field to RemoteConnection Signed-off-by: Piotr Jastrzebski --- libsql/src/database.rs | 6 +++++- libsql/src/replication/connection.rs | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/libsql/src/database.rs b/libsql/src/database.rs index 0109cab4bd..44c606001e 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -575,7 +575,11 @@ impl Database { let local = LibsqlConnection { conn }; let writer = local.conn.new_connection_writer(); - let remote = crate::replication::RemoteConnection::new(local, writer); + let remote = crate::replication::RemoteConnection::new( + local, + writer, + self.max_write_replication_index.clone(), + ); let conn = std::sync::Arc::new(remote); Ok(Connection { conn }) diff --git a/libsql/src/replication/connection.rs 
b/libsql/src/replication/connection.rs index c720838798..78b41cb8e8 100644 --- a/libsql/src/replication/connection.rs +++ b/libsql/src/replication/connection.rs @@ -2,7 +2,7 @@ use std::str::FromStr; use std::sync::Arc; - +use std::sync::atomic::AtomicU64; use libsql_replication::rpc::proxy::{ describe_result, query_result::RowResult, Cond, DescribeResult, ExecuteResults, NotCond, OkCond, Positional, Query, ResultRows, State as RemoteState, Step, @@ -28,6 +28,8 @@ pub struct RemoteConnection { pub(self) local: LibsqlConnection, writer: Option, inner: Arc>, + #[allow(dead_code)] + max_write_replication_index: Arc, } #[derive(Default, Debug)] @@ -166,12 +168,13 @@ impl From for State { } impl RemoteConnection { - pub(crate) fn new(local: LibsqlConnection, writer: Option) -> Self { + pub(crate) fn new(local: LibsqlConnection, writer: Option, max_write_replication_index: Arc) -> Self { let state = Arc::new(Mutex::new(Inner::default())); Self { local, writer, inner: state, + max_write_replication_index, } } From 710ae01402fc39da93c02ba53c6d5a8293715cb3 Mon Sep 17 00:00:00 2001 From: Piotr Jastrzebski Date: Tue, 13 Aug 2024 20:57:22 +0200 Subject: [PATCH 102/121] libsql: Update max_write_replication_index after every write Signed-off-by: Piotr Jastrzebski --- libsql/src/replication/connection.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/libsql/src/replication/connection.rs b/libsql/src/replication/connection.rs index 78b41cb8e8..593bd634a1 100644 --- a/libsql/src/replication/connection.rs +++ b/libsql/src/replication/connection.rs @@ -28,7 +28,6 @@ pub struct RemoteConnection { pub(self) local: LibsqlConnection, writer: Option, inner: Arc>, - #[allow(dead_code)] max_write_replication_index: Arc, } @@ -178,6 +177,18 @@ impl RemoteConnection { } } + fn update_max_write_replication_index(&self, index: Option) { + if let Some(index) = index { + let mut current = 
self.max_write_replication_index.load(std::sync::atomic::Ordering::SeqCst); + while index > current { + match self.max_write_replication_index.compare_exchange(current, index, std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst) { + Ok(_) => break, + Err(new_current) => current = new_current, + } + } + } + } + fn is_state_init(&self) -> bool { matches!(self.inner.lock().state, State::Init) } @@ -204,6 +215,8 @@ impl RemoteConnection { .into(); } + self.update_max_write_replication_index(res.current_frame_no); + if let Some(replicator) = writer.replicator() { replicator.sync_oneshot().await?; } @@ -229,6 +242,8 @@ impl RemoteConnection { .into(); } + self.update_max_write_replication_index(res.current_frame_no); + if let Some(replicator) = writer.replicator() { replicator.sync_oneshot().await?; } From 75d04cf61cb2225086720d202d1155bb66162a91 Mon Sep 17 00:00:00 2001 From: Piotr Jastrzebski Date: Tue, 13 Aug 2024 21:23:50 +0200 Subject: [PATCH 103/121] libsql: Add Database::sync_until Signed-off-by: Piotr Jastrzebski --- libsql/src/database.rs | 10 ++++++++++ libsql/src/local/database.rs | 23 +++++++++++++++++++++++ libsql/src/replication/mod.rs | 4 ++-- 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/libsql/src/database.rs b/libsql/src/database.rs index 44c606001e..d14cf2e42c 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -340,6 +340,16 @@ cfg_replication! { } } + /// Sync database from remote until it gets to a given replication_index or further, + /// and returns the committed frame_no after syncing, if applicable. + pub async fn sync_until(&self, replication_index: FrameNo) -> Result { + if let DbType::Sync { db, encryption_config: _ } = &self.db_type { + db.sync_until(replication_index).await + } else { + Err(Error::SyncNotSupported(format!("{:?}", self.db_type))) + } + } + /// Apply a set of frames to the database and returns the committed frame_no after syncing, if /// applicable. 
pub async fn sync_frames(&self, frames: crate::replication::Frames) -> Result> { diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 2892d809cc..3453a777c9 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -277,6 +277,29 @@ impl Database { Ok(self.sync_oneshot().await?) } + #[cfg(feature = "replication")] + /// Sync with primary at least to a given replication index + pub async fn sync_until(&self, replication_index: FrameNo) -> Result { + if let Some(ctx) = &self.replication_ctx { + let mut frame_no: Option = ctx.replicator.committed_frame_no().await; + let mut frames_synced: usize = 0; + while frame_no.unwrap_or(0) < replication_index { + let res = ctx.replicator.sync_oneshot().await?; + frame_no = res.frame_no(); + frames_synced += res.frames_synced(); + } + Ok(crate::replication::Replicated { + frame_no, + frames_synced, + }) + } else { + Err(crate::errors::Error::Misuse( + "No replicator available. Use Database::with_replicator() to enable replication" + .to_string(), + )) + } + } + #[cfg(feature = "replication")] pub async fn sync_frames(&self, frames: Frames) -> Result> { if let Some(ref ctx) = self.replication_ctx { diff --git a/libsql/src/replication/mod.rs b/libsql/src/replication/mod.rs index 2f4e9b49c0..116839a54f 100644 --- a/libsql/src/replication/mod.rs +++ b/libsql/src/replication/mod.rs @@ -36,8 +36,8 @@ pub(crate) mod remote_client; #[derive(Debug)] pub struct Replicated { - frame_no: Option, - frames_synced: usize, + pub(crate) frame_no: Option, + pub(crate) frames_synced: usize, } impl Replicated { From 2cf34946e775e1e448dd0c302ae841848d2bd870 Mon Sep 17 00:00:00 2001 From: Piotr Jastrzebski Date: Tue, 13 Aug 2024 22:11:14 +0200 Subject: [PATCH 104/121] tests: Add checks for max_write_replication_index Signed-off-by: Piotr Jastrzebski --- libsql-server/tests/embedded_replica/mod.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git 
a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index 78a46ad996..2d5f8c0de0 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -179,6 +179,8 @@ fn execute_batch() { conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ()) .await?; + assert_eq!(db.max_write_replication_index(), Some(1)); + let n = db.sync().await?.frame_no(); assert_eq!(n, Some(1)); @@ -231,6 +233,7 @@ fn stream() { conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ()) .await?; + assert_eq!(db.max_write_replication_index(), Some(1)); let n = db.sync().await?.frame_no(); assert_eq!(n, Some(1)); @@ -244,8 +247,10 @@ fn stream() { ", ) .await?; + let replication_index = db.max_write_replication_index(); - db.sync().await.unwrap(); + let synced_replication_index = db.sync().await.unwrap().frame_no(); + assert_eq!(synced_replication_index, replication_index); let rows = conn.query("select * from user", ()).await.unwrap(); From 045bef9941dcc68d8dea6d381057f4e4286f2152 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 15 Aug 2024 13:00:47 +0400 Subject: [PATCH 105/121] replace parking_lot with tokio::sync::Mutex in some places --- libsql-server/src/connection/program.rs | 14 +++++++----- libsql-server/src/http/admin/mod.rs | 2 +- libsql-server/src/namespace/meta_store.rs | 26 +++++++++++------------ libsql-server/src/namespace/store.rs | 22 +++++++++---------- libsql-server/src/schema/db.rs | 3 +++ 5 files changed, 37 insertions(+), 30 deletions(-) diff --git a/libsql-server/src/connection/program.rs b/libsql-server/src/connection/program.rs index f128c7538f..febe15a9ed 100644 --- a/libsql-server/src/connection/program.rs +++ b/libsql-server/src/connection/program.rs @@ -363,11 +363,15 @@ pub fn check_program_auth( } StmtKind::Attach(ref ns) => { ctx.auth.has_right(ns, Permission::AttachRead)?; - if !ctx.meta_store.handle(ns.clone()).get().allow_attach { - return 
Err(Error::NotAuthorized(format!( - "Namespace `{ns}` doesn't allow attach" - ))); - } + return tokio::runtime::Handle::current().block_on(async { + if !ctx.meta_store.handle(ns.clone()).await.get().allow_attach { + return Err(Error::NotAuthorized(format!( + "Namespace `{ns}` doesn't allow attach" + ))); + } else { + Ok(()) + } + }); } StmtKind::Detach => (), } diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 683d67995e..37144aa196 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -331,7 +331,7 @@ async fn handle_create_namespace( )); } // TODO: move this check into meta store - if !app_state.namespaces.exists(&ns) { + if !app_state.namespaces.exists(&ns).await { return Err(Error::NamespaceDoesntExist(ns.to_string())); } diff --git a/libsql-server/src/namespace/meta_store.rs b/libsql-server/src/namespace/meta_store.rs index 599dab9360..f66575dde4 100644 --- a/libsql-server/src/namespace/meta_store.rs +++ b/libsql-server/src/namespace/meta_store.rs @@ -70,8 +70,8 @@ struct MetaStoreInner { // TODO(lucio): Use a concurrent hashmap so we don't block connection creation // when we are updating the config. The config si already synced via the watch // channel. 
- configs: Mutex>>, - conn: Mutex, + configs: tokio::sync::Mutex>>, + conn: tokio::sync::Mutex, wal_manager: MetaStoreWalManager, } @@ -313,7 +313,7 @@ fn process(msg: ChangeMsg, inner: Arc) { } else { Ok(()) }; - let mut configs = inner.configs.lock(); + let mut configs = inner.configs.blocking_lock(); if let Some(config_watch) = configs.get_mut(&namespace) { let new_version = config_watch.borrow().version.wrapping_add(1); @@ -330,7 +330,7 @@ fn process(msg: ChangeMsg, inner: Arc) { let _ = ret_chan.send(ret); } else { let ret = if flush { - let mut configs = inner.configs.lock(); + let mut configs = inner.configs.blocking_lock(); if let Some(config_watch) = configs.get_mut(&namespace) { let config = config_watch.subscribe().borrow().clone(); try_process(&inner, &namespace, &config.config) @@ -351,7 +351,7 @@ fn try_process( ) -> Result<()> { let config_encoded = metadata::DatabaseConfig::from(&*config).encode_to_vec(); - let mut conn = inner.conn.lock(); + let mut conn = inner.conn.blocking_lock(); if let Some(schema) = config.shared_schema_name.as_ref() { let tx = conn.transaction()?; if let Some(ref schema) = config.shared_schema_name { @@ -470,11 +470,11 @@ impl MetaStore { Ok(Self { changes_tx, inner }) } - pub fn handle(&self, namespace: NamespaceName) -> MetaStoreHandle { + pub async fn handle(&self, namespace: NamespaceName) -> MetaStoreHandle { tracing::debug!("getting meta store handle"); let change_tx = self.changes_tx.clone(); - let mut configs = self.inner.configs.lock(); + let mut configs = self.inner.configs.lock().await; let sender = configs.entry(namespace.clone()).or_insert_with(|| { // TODO(lucio): if no entry exists we need to ensure we send the update to // the bg channel. 
@@ -495,11 +495,11 @@ impl MetaStore { pub fn remove(&self, namespace: NamespaceName) -> Result>> { tracing::debug!("removing namespace `{}` from meta store", namespace); - let mut configs = self.inner.configs.lock(); + let mut configs = self.inner.configs.blocking_lock(); let r = if let Some(sender) = configs.get(&namespace) { tracing::debug!("removed namespace `{}` from meta store", namespace); let config = sender.borrow().clone(); - let mut conn = self.inner.conn.lock(); + let mut conn = self.inner.conn.blocking_lock(); let tx = conn.transaction()?; if config.config.is_shared_schema { if crate::schema::db::schema_has_linked_dbs(&tx, &namespace)? { @@ -535,8 +535,8 @@ impl MetaStore { // TODO: we need to either make sure that the metastore is restored // before we start accepting connections or we need to contact bottomless // here to check if a namespace exists. Preferably the former. - pub fn exists(&self, namespace: &NamespaceName) -> bool { - self.inner.configs.lock().contains_key(namespace) + pub async fn exists(&self, namespace: &NamespaceName) -> bool { + self.inner.configs.lock().await.contains_key(namespace) } pub(crate) async fn shutdown(&self) -> crate::Result<()> { @@ -559,7 +559,7 @@ impl MetaStore { ) -> crate::Result { let inner = self.inner.clone(); let summary = tokio::task::spawn_blocking(move || { - let mut conn = inner.conn.lock(); + let mut conn = inner.conn.blocking_lock(); crate::schema::get_migrations_summary(&mut conn, schema) }) .await @@ -574,7 +574,7 @@ impl MetaStore { ) -> crate::Result> { let inner = self.inner.clone(); let details = tokio::task::spawn_blocking(move || { - let mut conn = inner.conn.lock(); + let mut conn = inner.conn.blocking_lock(); crate::schema::get_migration_details(&mut conn, schema, job_id) }) .await diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index a78e4f59b0..f9f614fc77 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ 
-99,8 +99,8 @@ impl NamespaceStore { }) } - pub fn exists(&self, namespace: &NamespaceName) -> bool { - self.inner.metadata.exists(namespace) + pub async fn exists(&self, namespace: &NamespaceName) -> bool { + self.inner.metadata.exists(namespace).await } pub async fn destroy(&self, namespace: NamespaceName, prune_all: bool) -> crate::Result<()> { @@ -173,7 +173,7 @@ impl NamespaceStore { ns.destroy().await?; } - let db_config = self.inner.metadata.handle(namespace.clone()); + let db_config = self.inner.metadata.handle(namespace.clone()).await; // destroy on-disk database self.cleanup( &namespace, @@ -226,7 +226,7 @@ impl NamespaceStore { } // check that the source namespace exists - if !self.inner.metadata.exists(&from) { + if !self.inner.metadata.exists(&from).await { return Err(crate::error::Error::NamespaceDoesntExist(from.to_string())); } @@ -241,11 +241,11 @@ impl NamespaceStore { } // FIXME: we could potentially delete the namespace while trying to fork it - if !self.inner.metadata.exists(&from) { + if !self.inner.metadata.exists(&from).await { return Err(crate::Error::NamespaceDoesntExist(from.to_string())); } - let from_config = self.inner.metadata.handle(from.clone()); + let from_config = self.inner.metadata.handle(from.clone()).await; let from_entry = self .load_namespace(&from, from_config.clone(), RestoreOption::Latest) .await?; @@ -280,7 +280,7 @@ impl NamespaceStore { should_delete: true, }; - let handle = self.inner.metadata.handle(to.clone()); + let handle = self.inner.metadata.handle(to.clone()).await; handle .store_and_maybe_flush(Some(to_config.into()), false) .await?; @@ -328,7 +328,7 @@ impl NamespaceStore { Fun: FnOnce(&Namespace) -> R, { if namespace != NamespaceName::default() - && !self.inner.metadata.exists(&namespace) + && !self.inner.metadata.exists(&namespace).await && !self.inner.allow_lazy_creation { return Err(Error::NamespaceDoesntExist(namespace.to_string())); @@ -346,7 +346,7 @@ impl NamespaceStore { } }; - let handle = 
self.inner.metadata.handle(namespace.to_owned()); + let handle = self.inner.metadata.handle(namespace.to_owned()).await; f(self .load_namespace(&namespace, handle, RestoreOption::Latest) .await?) @@ -440,12 +440,12 @@ impl NamespaceStore { // FIXME: move the default namespace check out of this function. if self.inner.allow_lazy_creation || namespace == NamespaceName::default() { tracing::trace!("auto-creating the namespace"); - } else if self.inner.metadata.exists(&namespace) { + } else if self.inner.metadata.exists(&namespace).await { return Err(Error::NamespaceAlreadyExist(namespace.to_string())); } let db_config = Arc::new(db_config); - let handle = self.inner.metadata.handle(namespace.clone()); + let handle = self.inner.metadata.handle(namespace.clone()).await; tracing::debug!("storing db config"); handle.store(db_config).await?; tracing::debug!("completed storing db config, loading namespace"); diff --git a/libsql-server/src/schema/db.rs b/libsql-server/src/schema/db.rs index 7efb2efaa8..f644b2420f 100644 --- a/libsql-server/src/schema/db.rs +++ b/libsql-server/src/schema/db.rs @@ -482,6 +482,7 @@ mod test { async fn register_schema(meta_store: &MetaStore, schema: &'static str) { meta_store .handle(schema.into()) + .await .store(DatabaseConfig { is_shared_schema: true, ..Default::default() @@ -497,6 +498,7 @@ mod test { ) -> crate::Result<()> { meta_store .handle(name.into()) + .await .store(DatabaseConfig { shared_schema_name: Some(schema.into()), ..Default::default() @@ -561,6 +563,7 @@ mod test { // necessary checks beforehand, and return a nice error message. 
assert!(meta_store .handle("ns1".into()) + .await .store(DatabaseConfig { shared_schema_name: Some("schema1".into()), ..Default::default() From 768998fee29286ff7a6e02a37ec6a48288c7991c Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 15 Aug 2024 15:03:12 +0400 Subject: [PATCH 106/121] make check_program_auth async too --- libsql-server/src/connection/libsql.rs | 3 ++- libsql-server/src/connection/program.rs | 16 ++++++---------- libsql-server/src/database/schema.rs | 2 +- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index 9896164e55..2e9183567a 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -415,7 +415,8 @@ where PROGRAM_EXEC_COUNT.increment(1); - check_program_auth(&ctx, &pgm, &self.inner.lock().config_store.get())?; + let config = self.inner.lock().config_store.get(); + check_program_auth(&ctx, &pgm, &config).await?; // create the bomb right before spawning the blocking task. 
let mut bomb = Bomb { diff --git a/libsql-server/src/connection/program.rs b/libsql-server/src/connection/program.rs index febe15a9ed..29ef408b68 100644 --- a/libsql-server/src/connection/program.rs +++ b/libsql-server/src/connection/program.rs @@ -341,7 +341,7 @@ fn value_size(val: &rusqlite::types::ValueRef) -> usize { } } -pub fn check_program_auth( +pub async fn check_program_auth( ctx: &RequestContext, pgm: &Program, config: &DatabaseConfig, @@ -363,15 +363,11 @@ pub fn check_program_auth( } StmtKind::Attach(ref ns) => { ctx.auth.has_right(ns, Permission::AttachRead)?; - return tokio::runtime::Handle::current().block_on(async { - if !ctx.meta_store.handle(ns.clone()).await.get().allow_attach { - return Err(Error::NotAuthorized(format!( - "Namespace `{ns}` doesn't allow attach" - ))); - } else { - Ok(()) - } - }); + if !ctx.meta_store.handle(ns.clone()).await.get().allow_attach { + return Err(Error::NotAuthorized(format!( + "Namespace `{ns}` doesn't allow attach" + ))); + } } StmtKind::Detach => (), } diff --git a/libsql-server/src/database/schema.rs b/libsql-server/src/database/schema.rs index 0b9674bd60..4195f4603e 100644 --- a/libsql-server/src/database/schema.rs +++ b/libsql-server/src/database/schema.rs @@ -50,7 +50,7 @@ impl crate::connection::Connection for SchemaConnection { res } else { - check_program_auth(&ctx, &migration, &self.config.get())?; + check_program_auth(&ctx, &migration, &self.config.get()).await?; let connection = self.connection.clone(); validate_migration(&mut migration)?; let migration = Arc::new(migration); From 47bf462c89966df18d87b56e680b8273c05e6a67 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 15 Aug 2024 15:03:26 +0400 Subject: [PATCH 107/121] take conn lock before configs lock --- libsql-server/src/namespace/meta_store.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/libsql-server/src/namespace/meta_store.rs b/libsql-server/src/namespace/meta_store.rs index f66575dde4..283eaf36eb 100644 
--- a/libsql-server/src/namespace/meta_store.rs +++ b/libsql-server/src/namespace/meta_store.rs @@ -495,11 +495,18 @@ impl MetaStore { pub fn remove(&self, namespace: NamespaceName) -> Result>> { tracing::debug!("removing namespace `{}` from meta store", namespace); + // "configs" lock can be used in both async and sync contexts while "conn" lock always used + // in blocking context + // + // so, we better to acquire "conn" lock first in order to prevent situation when "configs" + // lock is taken but "conn" lock is not free (so, we potentially will block async tasks for + // indefinite amount of time while "conn" lock will be acquired by other thread) + let mut conn = self.inner.conn.blocking_lock(); + let mut configs = self.inner.configs.blocking_lock(); let r = if let Some(sender) = configs.get(&namespace) { tracing::debug!("removed namespace `{}` from meta store", namespace); let config = sender.borrow().clone(); - let mut conn = self.inner.conn.blocking_lock(); let tx = conn.transaction()?; if config.config.is_shared_schema { if crate::schema::db::schema_has_linked_dbs(&tx, &namespace)? 
{ From 119dc57073785767d05313aea27c52a18cad9399 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 15 Aug 2024 16:09:12 +0400 Subject: [PATCH 108/121] expose basic tokio runtime metrics --- libsql-server/src/http/admin/mod.rs | 25 +++++++++++++ libsql-server/src/metrics.rs | 54 +++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 683d67995e..24c516ec60 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -93,9 +93,34 @@ where tokio::task::spawn(async move { loop { + let runtime = tokio::runtime::Handle::current(); + let metrics = runtime.metrics(); + crate::metrics::TOKIO_RUNTIME_BLOCKING_QUEUE_DEPTH + .set(metrics.blocking_queue_depth() as f64); + crate::metrics::TOKIO_RUNTIME_INJECTION_QUEUE_DEPTH + .set(metrics.injection_queue_depth() as f64); + crate::metrics::TOKIO_RUNTIME_NUM_BLOCKING_THREADS + .set(metrics.num_blocking_threads() as f64); + crate::metrics::TOKIO_RUNTIME_NUM_IDLE_BLOCKING_THREADS + .set(metrics.num_idle_blocking_threads() as f64); + crate::metrics::TOKIO_RUNTIME_NUM_WORKERS.set(metrics.num_workers() as f64); + + crate::metrics::TOKIO_RUNTIME_IO_DRIVER_FD_DEREGISTERED_COUNT + .absolute(metrics.io_driver_fd_deregistered_count() as u64); + crate::metrics::TOKIO_RUNTIME_IO_DRIVER_FD_REGISTERED_COUNT + .absolute(metrics.io_driver_fd_registered_count() as u64); + crate::metrics::TOKIO_RUNTIME_IO_DRIVER_READY_COUNT + .absolute(metrics.io_driver_ready_count() as u64); + crate::metrics::TOKIO_RUNTIME_REMOTE_SCHEDULE_COUNT + .absolute(metrics.remote_schedule_count() as u64); tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } + }); + tokio::task::spawn(async move { + loop { crate::metrics::SERVER_COUNT.set(1.0); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; } }); diff --git a/libsql-server/src/metrics.rs b/libsql-server/src/metrics.rs index 1ac97435b3..bb9c049fa7 100644 --- 
a/libsql-server/src/metrics.rs +++ b/libsql-server/src/metrics.rs @@ -158,3 +158,57 @@ pub static QUERY_CANCELED: Lazy = Lazy::new(|| { describe_counter!(NAME, "Number of canceled queries"); register_counter!(NAME) }); + +pub static TOKIO_RUNTIME_BLOCKING_QUEUE_DEPTH: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_blocking_queue_depth"; + describe_gauge!(NAME, "tokio runtime blocking_queue_depth"); + register_gauge!(NAME) +}); + +pub static TOKIO_RUNTIME_INJECTION_QUEUE_DEPTH: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_injection_queue_depth"; + describe_gauge!(NAME, "tokio runtime injection_queue_depth"); + register_gauge!(NAME) +}); + +pub static TOKIO_RUNTIME_NUM_BLOCKING_THREADS: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_num_blocking_threads"; + describe_gauge!(NAME, "tokio runtime num_blocking_threads"); + register_gauge!(NAME) +}); + +pub static TOKIO_RUNTIME_NUM_IDLE_BLOCKING_THREADS: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_num_idle_blocking_threads"; + describe_gauge!(NAME, "tokio runtime num_idle_blocking_threads"); + register_gauge!(NAME) +}); + +pub static TOKIO_RUNTIME_NUM_WORKERS: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_num_workers"; + describe_gauge!(NAME, "tokio runtime num_workers"); + register_gauge!(NAME) +}); + +pub static TOKIO_RUNTIME_IO_DRIVER_FD_DEREGISTERED_COUNT: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_io_driver_fd_deregistered_count"; + describe_counter!(NAME, "tokio runtime io_driver_fd_deregistered_count"); + register_counter!(NAME) +}); + +pub static TOKIO_RUNTIME_IO_DRIVER_FD_REGISTERED_COUNT: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_io_driver_fd_registered_count"; + describe_counter!(NAME, "tokio runtime io_driver_fd_registered_count"); + register_counter!(NAME) +}); + +pub static TOKIO_RUNTIME_IO_DRIVER_READY_COUNT: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_io_driver_ready_count"; + describe_counter!(NAME, 
"tokio runtime io_driver_ready_count"); + register_counter!(NAME) +}); + +pub static TOKIO_RUNTIME_REMOTE_SCHEDULE_COUNT: Lazy = Lazy::new(|| { + const NAME: &str = "tokio_runtime_remote_schedule_count"; + describe_gauge!(NAME, "tokio runtime remote_schedule_count"); + register_counter!(NAME) +}); From 0dcdb176f467d9ac3936294477927238f099b079 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 15 Aug 2024 21:41:55 +0400 Subject: [PATCH 109/121] add tokio_unstable cfg in Cargo.toml and GH actions --- .github/workflows/nemesis.yml | 2 +- .github/workflows/rust.yml | 10 +++++----- libsql-server/Cargo.toml | 4 ++++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/nemesis.yml b/.github/workflows/nemesis.yml index 9cfdbc39b0..090f2626b8 100644 --- a/.github/workflows/nemesis.yml +++ b/.github/workflows/nemesis.yml @@ -18,7 +18,7 @@ jobs: if: github.repository == 'tursodatabase/libsql' name: Run Nemesis Tests env: - RUSTFLAGS: -D warnings + RUSTFLAGS: -D warnings --cfg tokio_unstable steps: - uses: hecrj/setup-rust-action@v2 diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 1903c0baff..7036c6cf2f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest name: Run Checks env: - RUSTFLAGS: -D warnings + RUSTFLAGS: -D warnings --cfg tokio_unstable steps: - uses: hecrj/setup-rust-action@v2 @@ -80,15 +80,15 @@ jobs: - uses: taiki-e/install-action@cargo-udeps - uses: Swatinem/rust-cache@v2 - run: cargo +nightly hack udeps -p libsql --each-feature - - run: RUSTFLAGS="-D warnings" cargo check -p libsql --no-default-features --features core - - run: RUSTFLAGS="-D warnings" cargo check -p libsql --no-default-features --features replication - - run: RUSTFLAGS="-D warnings" cargo check -p libsql --no-default-features --features remote + - run: RUSTFLAGS="-D warnings --cfg tokio_unstable" cargo check -p libsql --no-default-features --features core + - run: 
RUSTFLAGS="-D warnings --cfg tokio_unstable" cargo check -p libsql --no-default-features --features replication + - run: RUSTFLAGS="-D warnings --cfg tokio_unstable" cargo check -p libsql --no-default-features --features remote test: runs-on: ubuntu-latest name: Run Tests env: - RUSTFLAGS: -D warnings + RUSTFLAGS: -D warnings --cfg tokio_unstable steps: - uses: hecrj/setup-rust-action@v2 diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 934a400786..b22b9f1458 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -8,6 +8,9 @@ default-run = "sqld" name = "sqld" path = "src/main.rs" +[build] +rustflags = ["--cfg", "tokio_unstable"] + [dependencies] anyhow = "1.0.66" async-lock = "2.6.0" @@ -116,6 +119,7 @@ vergen = { version = "8", features = ["build", "git", "gitcl"] } [features] default = [] +tokio-metrics = [] debug-tools = ["console-subscriber", "rusqlite/trace", "tokio/tracing"] wasm-udfs = ["rusqlite/libsql-wasm-experimental"] unix-excl-vfs = ["libsql-sys/unix-excl-vfs"] From 4ebff7fdd5b0dd9e6fa250c5c35cd9b27778536f Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 15 Aug 2024 21:58:47 +0400 Subject: [PATCH 110/121] remove feature --- libsql-server/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index b22b9f1458..2db9330fa6 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -119,7 +119,6 @@ vergen = { version = "8", features = ["build", "git", "gitcl"] } [features] default = [] -tokio-metrics = [] debug-tools = ["console-subscriber", "rusqlite/trace", "tokio/tracing"] wasm-udfs = ["rusqlite/libsql-wasm-experimental"] unix-excl-vfs = ["libsql-sys/unix-excl-vfs"] From f6fe40df5171614d492369ec69e16a58e88e15f0 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 16 Aug 2024 00:53:54 +0400 Subject: [PATCH 111/121] use single task to report bunch of metrics --- libsql-server/src/http/admin/mod.rs | 5 ----- 1 file changed, 5 
deletions(-) diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 24c516ec60..7f34b77359 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -113,12 +113,7 @@ where .absolute(metrics.io_driver_ready_count() as u64); crate::metrics::TOKIO_RUNTIME_REMOTE_SCHEDULE_COUNT .absolute(metrics.remote_schedule_count() as u64); - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - } - }); - tokio::task::spawn(async move { - loop { crate::metrics::SERVER_COUNT.set(1.0); tokio::time::sleep(std::time::Duration::from_secs(1)).await; } From baadddd9bc107283d916e582d06282acf27e6799 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 16 Aug 2024 18:44:39 +0400 Subject: [PATCH 112/121] remove unused section from cargo.toml - it should be in the config.toml - but we already have it in the workspace config --- libsql-server/Cargo.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 2db9330fa6..934a400786 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -8,9 +8,6 @@ default-run = "sqld" name = "sqld" path = "src/main.rs" -[build] -rustflags = ["--cfg", "tokio_unstable"] - [dependencies] anyhow = "1.0.66" async-lock = "2.6.0" From 9a7dd75bb3883e45186693b189935c312a6f2b60 Mon Sep 17 00:00:00 2001 From: Athos Couto Date: Fri, 16 Aug 2024 16:50:19 -0300 Subject: [PATCH 113/121] Bump sqld to v0.24.19 --- Cargo.lock | 2 +- libsql-server/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3b9c6b02b..483f703ed6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3532,7 +3532,7 @@ dependencies = [ [[package]] name = "libsql-server" -version = "0.24.18" +version = "0.24.19" dependencies = [ "aes", "anyhow", diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 934a400786..bba843ce4c 100644 --- a/libsql-server/Cargo.toml +++ 
b/libsql-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql-server" -version = "0.24.18" +version = "0.24.19" edition = "2021" default-run = "sqld" From 1bbfed668073eb9a61e13b244b0c2bba0f5e1eb6 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sat, 17 Aug 2024 01:08:24 +0400 Subject: [PATCH 114/121] set RUSTFLAGS for cargo dist in libsql-server-release GA --- .github/workflows/libsql-server-release.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/libsql-server-release.yml b/.github/workflows/libsql-server-release.yml index 919d6fa179..780eb7e21a 100644 --- a/.github/workflows/libsql-server-release.yml +++ b/.github/workflows/libsql-server-release.yml @@ -54,6 +54,7 @@ jobs: publishing: ${{ !github.event.pull_request }} env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + RUSTFLAGS: --cfg tokio_unstable steps: - uses: actions/checkout@v4 with: @@ -104,6 +105,7 @@ jobs: env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} BUILD_MANIFEST_NAME: target/distrib/${{ join(matrix.targets, '-') }}-dist-manifest.json + RUSTFLAGS: --cfg tokio_unstable steps: - name: enable windows longpaths run: | @@ -161,6 +163,7 @@ jobs: env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} BUILD_MANIFEST_NAME: target/distrib/global-dist-manifest.json + RUSTFLAGS: --cfg tokio_unstable steps: - uses: actions/checkout@v4 with: @@ -204,6 +207,7 @@ jobs: if: ${{ always() && needs.plan.outputs.publishing == 'true' && (needs.build-global-artifacts.result == 'skipped' || needs.build-global-artifacts.result == 'success') && (needs.build-local-artifacts.result == 'skipped' || needs.build-local-artifacts.result == 'success') }} env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + RUSTFLAGS: --cfg tokio_unstable runs-on: "ubuntu-20.04" outputs: val: ${{ steps.host.outputs.manifest }} From 3abd43a36b8166f847be9e0017b501dff36ab596 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 18 Aug 2024 02:10:22 +0400 Subject: [PATCH 115/121] add missed repository key for libsql-server package --- 
libsql-server/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index bba843ce4c..7dd5c965a2 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -3,6 +3,7 @@ name = "libsql-server" version = "0.24.19" edition = "2021" default-run = "sqld" +repository = "https://github.com/tursodatabase/libsql" [[bin]] name = "sqld" From 8a37813d675466566dd6a5e8fdf3a616045a2db7 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 18 Aug 2024 02:10:45 +0400 Subject: [PATCH 116/121] upgrade cargo dist 0.14.1 -> 0.21.0 in order to use "github-build-setup" setting --- .../libsql-server-release-build-setup.yml | 3 + .github/workflows/libsql-server-release.yml | 83 ++++++++++--------- Cargo.toml | 8 +- 3 files changed, 55 insertions(+), 39 deletions(-) create mode 100644 .github/workflows/libsql-server-release-build-setup.yml diff --git a/.github/workflows/libsql-server-release-build-setup.yml b/.github/workflows/libsql-server-release-build-setup.yml new file mode 100644 index 0000000000..dbe70309b9 --- /dev/null +++ b/.github/workflows/libsql-server-release-build-setup.yml @@ -0,0 +1,3 @@ +- name: Prepare env vars + env: + RUSTFLAGS: --cfg tokio_unstable diff --git a/.github/workflows/libsql-server-release.yml b/.github/workflows/libsql-server-release.yml index 780eb7e21a..56a2dbf0a0 100644 --- a/.github/workflows/libsql-server-release.yml +++ b/.github/workflows/libsql-server-release.yml @@ -12,9 +12,8 @@ # title/body based on your changelogs. name: Release - permissions: - contents: write + "contents": "write" # This task will run whenever you push a git tag that looks like a version # like "1.0.0", "v0.1.0-prerelease.1", "my-app/0.1.0", "releases/v1.0.0", etc. @@ -38,15 +37,15 @@ permissions: # If there's a prerelease-style suffix to the version, then the release(s) # will be marked as a prerelease. 
on: + pull_request: push: tags: - 'libsql-server**[0-9]+.[0-9]+.[0-9]+*' - pull_request: jobs: # Run 'cargo dist plan' (or host) to determine what tasks we need to do plan: - runs-on: ubuntu-latest + runs-on: "ubuntu-20.04" outputs: val: ${{ steps.plan.outputs.manifest }} tag: ${{ !github.event.pull_request && github.ref_name || '' }} @@ -54,7 +53,6 @@ jobs: publishing: ${{ !github.event.pull_request }} env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - RUSTFLAGS: --cfg tokio_unstable steps: - uses: actions/checkout@v4 with: @@ -63,7 +61,12 @@ jobs: # we specify bash to get pipefail; it guards against the `curl` command # failing. otherwise `sh` won't catch that `curl` returned non-0 shell: bash - run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.14.1/cargo-dist-installer.sh | sh" + run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.21.0/cargo-dist-installer.sh | sh" + - name: Cache cargo-dist + uses: actions/upload-artifact@v4 + with: + name: cargo-dist-cache + path: ~/.cargo/bin/cargo-dist # sure would be cool if github gave us proper conditionals... # so here's a doubly-nested ternary-via-truthiness to try to provide the best possible # functionality based on whether this is a pull_request, and whether it's from a fork. 
@@ -105,7 +108,6 @@ jobs: env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} BUILD_MANIFEST_NAME: target/distrib/${{ join(matrix.targets, '-') }}-dist-manifest.json - RUSTFLAGS: --cfg tokio_unstable steps: - name: enable windows longpaths run: | @@ -113,9 +115,9 @@ jobs: - uses: actions/checkout@v4 with: submodules: recursive - - uses: swatinem/rust-cache@v2 - with: - key: ${{ join(matrix.targets, '-') }} + - name: "Prepare env vars" + env: + "RUSTFLAGS": "--cfg tokio_unstable" - name: Install cargo-dist run: ${{ matrix.install_dist }} # Get the dist-manifest @@ -163,14 +165,16 @@ jobs: env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} BUILD_MANIFEST_NAME: target/distrib/global-dist-manifest.json - RUSTFLAGS: --cfg tokio_unstable steps: - uses: actions/checkout@v4 with: submodules: recursive - - name: Install cargo-dist - shell: bash - run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.14.1/cargo-dist-installer.sh | sh" + - name: Install cached cargo-dist + uses: actions/download-artifact@v4 + with: + name: cargo-dist-cache + path: ~/.cargo/bin/ + - run: chmod +x ~/.cargo/bin/cargo-dist # Get all the local artifacts for the global tasks to use (for e.g. 
checksums) - name: Fetch local artifacts uses: actions/download-artifact@v4 @@ -207,7 +211,6 @@ jobs: if: ${{ always() && needs.plan.outputs.publishing == 'true' && (needs.build-global-artifacts.result == 'skipped' || needs.build-global-artifacts.result == 'success') && (needs.build-local-artifacts.result == 'skipped' || needs.build-local-artifacts.result == 'success') }} env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - RUSTFLAGS: --cfg tokio_unstable runs-on: "ubuntu-20.04" outputs: val: ${{ steps.host.outputs.manifest }} @@ -215,8 +218,12 @@ jobs: - uses: actions/checkout@v4 with: submodules: recursive - - name: Install cargo-dist - run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.14.1/cargo-dist-installer.sh | sh" + - name: Install cached cargo-dist + uses: actions/download-artifact@v4 + with: + name: cargo-dist-cache + path: ~/.cargo/bin/ + - run: chmod +x ~/.cargo/bin/cargo-dist # Fetch artifacts from scratch-storage - name: Fetch artifacts uses: actions/download-artifact@v4 @@ -224,7 +231,6 @@ jobs: pattern: artifacts-* path: target/distrib/ merge-multiple: true - # This is a harmless no-op for GitHub Releases, hosting for that happens in "announce" - id: host shell: bash run: | @@ -238,6 +244,28 @@ jobs: # Overwrite the previous copy name: artifacts-dist-manifest path: dist-manifest.json + # Create a GitHub Release while uploading all files to it + - name: "Download GitHub Artifacts" + uses: actions/download-artifact@v4 + with: + pattern: artifacts-* + path: artifacts + merge-multiple: true + - name: Cleanup + run: | + # Remove the granular manifests + rm -f artifacts/*-dist-manifest.json + - name: Create GitHub Release + env: + PRERELEASE_FLAG: "${{ fromJson(steps.host.outputs.manifest).announcement_is_prerelease && '--prerelease' || '' }}" + ANNOUNCEMENT_TITLE: "${{ fromJson(steps.host.outputs.manifest).announcement_title }}" + ANNOUNCEMENT_BODY: "${{ 
fromJson(steps.host.outputs.manifest).announcement_github_body }}" + RELEASE_COMMIT: "${{ github.sha }}" + run: | + # Write and read notes from a file to avoid quoting breaking things + echo "$ANNOUNCEMENT_BODY" > $RUNNER_TEMP/notes.txt + + gh release create "${{ needs.plan.outputs.tag }}" --target "$RELEASE_COMMIT" $PRERELEASE_FLAG --title "$ANNOUNCEMENT_TITLE" --notes-file "$RUNNER_TEMP/notes.txt" artifacts/* publish-homebrew-formula: needs: @@ -279,7 +307,6 @@ jobs: done git push - # Create a GitHub Release while uploading all files to it announce: needs: - plan @@ -296,21 +323,3 @@ jobs: - uses: actions/checkout@v4 with: submodules: recursive - - name: "Download GitHub Artifacts" - uses: actions/download-artifact@v4 - with: - pattern: artifacts-* - path: artifacts - merge-multiple: true - - name: Cleanup - run: | - # Remove the granular manifests - rm -f artifacts/*-dist-manifest.json - - name: Create GitHub Release - uses: ncipollo/release-action@v1 - with: - tag: ${{ needs.plan.outputs.tag }} - name: ${{ fromJson(needs.host.outputs.val).announcement_title }} - body: ${{ fromJson(needs.host.outputs.val).announcement_github_body }} - prerelease: ${{ fromJson(needs.host.outputs.val).announcement_is_prerelease }} - artifacts: "artifacts/*" diff --git a/Cargo.toml b/Cargo.toml index 685f14964f..b29e9ce49a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ zerocopy = { version = "0.7.32", features = ["derive", "alloc"] } # Config for 'cargo dist' [workspace.metadata.dist] # The preferred cargo-dist version to use in CI (Cargo.toml SemVer syntax) -cargo-dist-version = "0.14.1" +cargo-dist-version = "0.21.0" # CI backends to support ci = "github" # The installers to generate for each app @@ -65,12 +65,16 @@ targets = ["aarch64-apple-darwin", "aarch64-unknown-linux-gnu", "x86_64-apple-da publish-jobs = ["homebrew"] # Whether cargo-dist should create a Github Release or use an existing draft create-release = true -# Publish jobs to run in CI +# Which actions to 
run on pull requests pr-run-mode = "plan" # A prefix git tags must include for cargo-dist to care about them tag-namespace = "libsql-server" # Whether to install an updater program install-updater = false +# additional setup steps +github-build-setup = "libsql-server-release-build-setup.yml" +# Path that installers should place binaries in +install-path = "CARGO_HOME" [workspace.metadata.dist.github-custom-runners] aarch64-apple-darwin = "macos-14" From 6648908206ab210d54979cce318c0ae1b52fb0d8 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 18 Aug 2024 02:17:43 +0400 Subject: [PATCH 117/121] move build-setup steps to the templates dir --- .../libsql-server-release-build-setup.yml | 1 + Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) rename .github/{workflows => templates}/libsql-server-release-build-setup.yml (70%) diff --git a/.github/workflows/libsql-server-release-build-setup.yml b/.github/templates/libsql-server-release-build-setup.yml similarity index 70% rename from .github/workflows/libsql-server-release-build-setup.yml rename to .github/templates/libsql-server-release-build-setup.yml index dbe70309b9..8c3b8a57bb 100644 --- a/.github/workflows/libsql-server-release-build-setup.yml +++ b/.github/templates/libsql-server-release-build-setup.yml @@ -1,3 +1,4 @@ - name: Prepare env vars + run: "echo setup env vars" env: RUSTFLAGS: --cfg tokio_unstable diff --git a/Cargo.toml b/Cargo.toml index b29e9ce49a..94c851721f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,7 +72,7 @@ tag-namespace = "libsql-server" # Whether to install an updater program install-updater = false # additional setup steps -github-build-setup = "libsql-server-release-build-setup.yml" +github-build-setup = "../templates/libsql-server-release-build-setup.yml" # Path that installers should place binaries in install-path = "CARGO_HOME" From 076c9b11ff1972775ca2d4949d7745e1aab13513 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 18 Aug 2024 02:18:32 +0400 
Subject: [PATCH 118/121] cargo dist init --- .github/workflows/libsql-server-release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/libsql-server-release.yml b/.github/workflows/libsql-server-release.yml index 56a2dbf0a0..cb085decac 100644 --- a/.github/workflows/libsql-server-release.yml +++ b/.github/workflows/libsql-server-release.yml @@ -116,6 +116,7 @@ jobs: with: submodules: recursive - name: "Prepare env vars" + run: "echo setup env vars" env: "RUSTFLAGS": "--cfg tokio_unstable" - name: Install cargo-dist From bf16fd21bd814ccc0755d3448f6a9d47fb685010 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 18 Aug 2024 02:22:57 +0400 Subject: [PATCH 119/121] temporary enable build step in pr-run-mode --- .github/workflows/libsql-server-release.yml | 4 ++++ Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/libsql-server-release.yml b/.github/workflows/libsql-server-release.yml index cb085decac..d256a6a8e2 100644 --- a/.github/workflows/libsql-server-release.yml +++ b/.github/workflows/libsql-server-release.yml @@ -119,6 +119,10 @@ jobs: run: "echo setup env vars" env: "RUSTFLAGS": "--cfg tokio_unstable" + - uses: swatinem/rust-cache@v2 + with: + key: ${{ join(matrix.targets, '-') }} + cache-provider: ${{ matrix.cache_provider }} - name: Install cargo-dist run: ${{ matrix.install_dist }} # Get the dist-manifest diff --git a/Cargo.toml b/Cargo.toml index 94c851721f..5d61c2694a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,7 @@ publish-jobs = ["homebrew"] # Whether cargo-dist should create a Github Release or use an existing draft create-release = true # Which actions to run on pull requests -pr-run-mode = "plan" +pr-run-mode = "upload" # A prefix git tags must include for cargo-dist to care about them tag-namespace = "libsql-server" # Whether to install an updater program From 346a02a840ed3ab9e136db9ec073b48952af9903 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 18 Aug 
2024 02:36:02 +0400 Subject: [PATCH 120/121] fix additional set-env build step --- .github/templates/libsql-server-release-build-setup.yml | 4 +--- .github/workflows/libsql-server-release.yml | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/.github/templates/libsql-server-release-build-setup.yml b/.github/templates/libsql-server-release-build-setup.yml index 8c3b8a57bb..997e70066d 100644 --- a/.github/templates/libsql-server-release-build-setup.yml +++ b/.github/templates/libsql-server-release-build-setup.yml @@ -1,4 +1,2 @@ - name: Prepare env vars - run: "echo setup env vars" - env: - RUSTFLAGS: --cfg tokio_unstable + run: echo "RUSTFLAGS=--cfg tokio_unstable" >> $GITHUB_ENV diff --git a/.github/workflows/libsql-server-release.yml b/.github/workflows/libsql-server-release.yml index d256a6a8e2..683897f041 100644 --- a/.github/workflows/libsql-server-release.yml +++ b/.github/workflows/libsql-server-release.yml @@ -116,9 +116,7 @@ jobs: with: submodules: recursive - name: "Prepare env vars" - run: "echo setup env vars" - env: - "RUSTFLAGS": "--cfg tokio_unstable" + run: "echo \"RUSTFLAGS=--cfg tokio_unstable\" >> $GITHUB_ENV" - uses: swatinem/rust-cache@v2 with: key: ${{ join(matrix.targets, '-') }} From 355440640b865c1fccaf3434362c723792750dd5 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 18 Aug 2024 02:55:45 +0400 Subject: [PATCH 121/121] return back pr-mode to plan --- .github/workflows/libsql-server-release.yml | 4 ---- Cargo.toml | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/libsql-server-release.yml b/.github/workflows/libsql-server-release.yml index 683897f041..db15bfc641 100644 --- a/.github/workflows/libsql-server-release.yml +++ b/.github/workflows/libsql-server-release.yml @@ -117,10 +117,6 @@ jobs: submodules: recursive - name: "Prepare env vars" run: "echo \"RUSTFLAGS=--cfg tokio_unstable\" >> $GITHUB_ENV" - - uses: swatinem/rust-cache@v2 - with: - key: ${{ join(matrix.targets, 
'-') }} - cache-provider: ${{ matrix.cache_provider }} - name: Install cargo-dist run: ${{ matrix.install_dist }} # Get the dist-manifest diff --git a/Cargo.toml b/Cargo.toml index 5d61c2694a..94c851721f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,7 @@ publish-jobs = ["homebrew"] # Whether cargo-dist should create a Github Release or use an existing draft create-release = true # Which actions to run on pull requests -pr-run-mode = "upload" +pr-run-mode = "plan" # A prefix git tags must include for cargo-dist to care about them tag-namespace = "libsql-server" # Whether to install an updater program