From f510359793bd24d5427ef67ccaa5ab0caf6a95f2 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Thu, 1 Aug 2024 11:47:24 +0200 Subject: [PATCH 01/70] Breaking change: Make Rows and Row API more consistent. A few notes: I took the path of least resistance, also assuming it would break fewer folks: I made Row more like Rows, and thus changed usize -> i32. Arguably, an unsigned type might be more appropriate for both lengths and indexes. I understand that the i32 stems from the sqlite bindings, which return/accept c_ints. Yet Statement and Row made the jump to an unsigned type, probably drawing the same conclusion, whereas Rows preserved its proximity to the C bindings. This is probably an artifact. Also, going out on a limb: mapping c_int -> i32 is already a non-portable choice, since the precision of c_int is platform dependent.
--- libsql/src/de.rs | 2 +- libsql/src/hrana/mod.rs | 29 ++++++++++++++++++---------- libsql/src/local/impls.rs | 18 ++++++++++------- libsql/src/local/rows.rs | 18 ++++++++++------- libsql/src/local/statement.rs | 21 ++++++++++---------- libsql/src/replication/connection.rs | 18 ++++++++++------- libsql/src/rows.rs | 21 +++++++++----------- 7 files changed, 72 insertions(+), 55 deletions(-)
diff --git a/libsql/src/de.rs b/libsql/src/de.rs index 63ee71f598..44f231c134 100644 --- a/libsql/src/de.rs +++ b/libsql/src/de.rs @@ -68,7 +68,7 @@ impl<'de> Deserializer<'de> for RowDeserializer<'de> { visitor.visit_map(RowMapAccess { row: self.row, - idx: 0..self.row.inner.column_count(), + idx: 0..(self.row.inner.column_count() as usize), value: None, }) }
diff --git a/libsql/src/hrana/mod.rs b/libsql/src/hrana/mod.rs index 9befe549de..4a6fd0c63a 100644 --- a/libsql/src/hrana/mod.rs +++ b/libsql/src/hrana/mod.rs @@ -24,7 +24,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use super::rows::{RowInner, RowsInner}; +use super::rows::{ColumnsInner, RowInner, RowsInner}; pub(crate) type Result = std::result::Result; @@ -261,7 +261,12 @@ where async fn next(&mut self) -> crate::Result> { self.next().await } +} +impl ColumnsInner for HranaRows +where + S: Stream> + Send + Sync + Unpin, +{ fn column_count(&self) -> i32 { self.column_count() } @@ -303,13 +308,6 @@ impl RowInner for Row { Ok(into_value2(v)) } - fn column_name(&self, idx: i32) -> Option<&str> { - self.cols - .get(idx as usize) - .and_then(|c| c.name.as_ref()) - .map(|s| s.as_str()) - } - fn column_str(&self, idx: i32) -> crate::Result<&str> { if let Some(value) = self.inner.get(idx as usize) { if let proto::Value::Text { value } = value { @@ -321,6 +319,15 @@ impl RowInner for Row { Err(crate::Error::ColumnNotFound(idx)) } } +} + +impl ColumnsInner for Row { + fn column_name(&self, idx: i32) -> Option<&str> { + self.cols + .get(idx as usize) + .and_then(|c| c.name.as_ref()) + .map(|s| s.as_str()) + } fn column_type(&self, idx: i32) -> crate::Result { if let Some(value) = self.inner.get(idx as usize) { @@ -337,8 +344,8 @@ } } - fn column_count(&self) -> usize { - self.cols.len() + fn column_count(&self) -> i32 { + self.cols.len() as i32 } } @@ -417,7 +424,9 @@ impl RowsInner for StmtResultRows { inner: Box::new(row), })) } +} +impl ColumnsInner for StmtResultRows { fn column_count(&self) -> i32 { self.cols.len() as i32 }
diff --git a/libsql/src/local/impls.rs b/libsql/src/local/impls.rs index 8a9a5f440e..2338317a34 100644 --- a/libsql/src/local/impls.rs +++ b/libsql/src/local/impls.rs @@ -5,7 +5,7 @@ use crate::connection::BatchRows; use crate::{ connection::Conn, params::Params,
rows::{RowInner, RowsInner}, + rows::{ColumnsInner, RowInner, RowsInner}, statement::Stmt, transaction::Tx, Column, Connection, Result, Row, Rows, Statement, Transaction, TransactionBehavior, Value, @@ -159,7 +159,9 @@ impl RowsInner for LibsqlRows { Ok(row) } +} +impl ColumnsInner for LibsqlRows { fn column_count(&self) -> i32 { self.0.column_count() } @@ -180,20 +182,22 @@ impl RowInner for LibsqlRow { self.0.get_value(idx) } - fn column_name(&self, idx: i32) -> Option<&str> { - self.0.column_name(idx) - } - fn column_str(&self, idx: i32) -> Result<&str> { self.0.get::<&str>(idx) } +} + +impl ColumnsInner for LibsqlRow { + fn column_name(&self, idx: i32) -> Option<&str> { + self.0.column_name(idx) + } fn column_type(&self, idx: i32) -> Result { self.0.column_type(idx).map(ValueType::from) } - fn column_count(&self) -> usize { - self.0.stmt.column_count() + fn column_count(&self) -> i32 { + self.0.stmt.column_count() as i32 } } diff --git a/libsql/src/local/rows.rs b/libsql/src/local/rows.rs index 7eb52d461b..4d4e622c75 100644 --- a/libsql/src/local/rows.rs +++ b/libsql/src/local/rows.rs @@ -1,6 +1,6 @@ use crate::local::{Connection, Statement}; use crate::params::Params; -use crate::rows::{RowInner, RowsInner}; +use crate::rows::{ColumnsInner, RowInner, RowsInner}; use crate::{errors, Error, Result}; use crate::{Value, ValueRef}; use libsql_sys::ValueType; @@ -213,7 +213,9 @@ impl RowsInner for BatchedRows { Ok(None) } } +} +impl ColumnsInner for BatchedRows { fn column_count(&self) -> i32 { self.cols.len() as i32 } @@ -244,10 +246,6 @@ impl RowInner for BatchedRow { .ok_or(Error::InvalidColumnIndex) } - fn column_name(&self, idx: i32) -> Option<&str> { - self.cols.get(idx as usize).map(|c| c.0.as_str()) - } - fn column_str(&self, idx: i32) -> Result<&str> { self.row .get(idx as usize) @@ -258,9 +256,15 @@ impl RowInner for BatchedRow { .ok_or(Error::InvalidColumnType) }) } +} + +impl ColumnsInner for BatchedRow { + fn column_name(&self, idx: i32) -> Option<&str> { + self.cols.get(idx as usize).map(|c| c.0.as_str()) + } - fn column_count(&self) -> usize { - self.cols.len() + fn column_count(&self) -> i32 { + self.cols.len() as i32 } fn column_type(&self, idx: i32) -> Result { diff --git a/libsql/src/local/statement.rs b/libsql/src/local/statement.rs index 70116a152e..c28a66f18f 100644 --- a/libsql/src/local/statement.rs +++ b/libsql/src/local/statement.rs @@ -250,15 +250,15 @@ impl Statement { /// sure that current statement has already been stepped once before /// calling this method. pub fn column_names(&self) -> Vec<&str> { - let n = self.column_count(); - let mut cols = Vec::with_capacity(n); - for i in 0..n { - let s = self.column_name(i); - if let Some(s) = s { - cols.push(s); - } - } - cols + let n = self.column_count(); + let mut cols = Vec::with_capacity(n); + for i in 0..n { + let s = self.column_name(i); + if let Some(s) = s { + cols.push(s); + } + } + cols } /// Return the number of columns in the result set returned by the prepared @@ -314,12 +314,11 @@ impl Statement { /// the specified `name`. pub fn column_index(&self, name: &str) -> Result { let bytes = name.as_bytes(); - let n = self.column_count() as i32; + let n = self.column_count(); for i in 0..n { // Note: `column_name` is only fallible if `i` is out of bounds, // which we've already checked. 
let col_name = self - .inner .column_name(i) .ok_or_else(|| Error::InvalidColumnName(name.to_string()))?; if bytes.eq_ignore_ascii_case(col_name.as_bytes()) { diff --git a/libsql/src/replication/connection.rs b/libsql/src/replication/connection.rs index c82f523559..c720838798 100644 --- a/libsql/src/replication/connection.rs +++ b/libsql/src/replication/connection.rs @@ -11,7 +11,7 @@ use parking_lot::Mutex; use crate::parser; use crate::parser::StmtKind; -use crate::rows::{RowInner, RowsInner}; +use crate::rows::{ColumnsInner, RowInner, RowsInner}; use crate::statement::Stmt; use crate::transaction::Tx; use crate::{ @@ -780,7 +780,9 @@ impl RowsInner for RemoteRows { let row = RemoteRow(values, self.0.column_descriptions.clone()); Ok(Some(row).map(Box::new).map(|inner| Row { inner })) } +} +impl ColumnsInner for RemoteRows { fn column_count(&self) -> i32 { self.0.column_descriptions.len() as i32 } @@ -813,10 +815,6 @@ impl RowInner for RemoteRow { .ok_or(Error::InvalidColumnIndex) } - fn column_name(&self, idx: i32) -> Option<&str> { - self.1.get(idx as usize).map(|s| s.name.as_str()) - } - fn column_str(&self, idx: i32) -> Result<&str> { let value = self.0.get(idx as usize).ok_or(Error::InvalidColumnIndex)?; @@ -825,6 +823,12 @@ impl RowInner for RemoteRow { _ => Err(Error::InvalidColumnType), } } +} + +impl ColumnsInner for RemoteRow { + fn column_name(&self, idx: i32) -> Option<&str> { + self.1.get(idx as usize).map(|s| s.name.as_str()) + } fn column_type(&self, idx: i32) -> Result { let col = self.1.get(idx as usize).unwrap(); @@ -835,8 +839,8 @@ impl RowInner for RemoteRow { .ok_or(Error::InvalidColumnType) } - fn column_count(&self) -> usize { - self.1.len() + fn column_count(&self) -> i32 { + self.1.len() as i32 } } diff --git a/libsql/src/rows.rs b/libsql/src/rows.rs index b97aeac203..a10d82b827 100644 --- a/libsql/src/rows.rs +++ b/libsql/src/rows.rs @@ -38,14 +38,8 @@ impl Column<'_> { } #[async_trait::async_trait] -pub(crate) trait RowsInner { +pub(crate) trait RowsInner: ColumnsInner { async fn next(&mut self) -> Result>; - - fn column_count(&self) -> i32; - - fn column_name(&self, idx: i32) -> Option<&str>; - - fn column_type(&self, idx: i32) -> Result; } /// A set of rows returned from a connection. @@ -131,7 +125,7 @@ impl Row { } /// Get the count of columns in this set of rows. 
- pub fn column_count(&self) -> usize { + pub fn column_count(&self) -> i32 { self.inner.column_count() } @@ -284,12 +278,15 @@ where } impl Sealed for Option {} -pub(crate) trait RowInner: fmt::Debug { - fn column_value(&self, idx: i32) -> Result; - fn column_str(&self, idx: i32) -> Result<&str>; +pub(crate) trait ColumnsInner { fn column_name(&self, idx: i32) -> Option<&str>; fn column_type(&self, idx: i32) -> Result; - fn column_count(&self) -> usize; + fn column_count(&self) -> i32; +} + +pub(crate) trait RowInner: ColumnsInner + fmt::Debug { + fn column_value(&self, idx: i32) -> Result; + fn column_str(&self, idx: i32) -> Result<&str>; } mod sealed { From a68f042914cd637904123c7375e297a00ac5cecb Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 2 Aug 2024 09:26:28 -0700 Subject: [PATCH 02/70] libsql: release v0.5.0 --- Cargo.lock | 12 ++++++------ Cargo.toml | 2 +- libsql-ffi/Cargo.toml | 2 +- libsql-replication/Cargo.toml | 4 ++-- libsql-sys/Cargo.toml | 4 ++-- libsql/Cargo.toml | 8 ++++---- vendored/rusqlite/Cargo.toml | 4 ++-- vendored/sqlite3-parser/Cargo.toml | 2 +- 8 files changed, 19 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 04ae728065..17cfc0e090 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3430,7 +3430,7 @@ dependencies = [ [[package]] name = "libsql" -version = "0.5.0-alpha.2" +version = "0.5.0" dependencies = [ "anyhow", "async-stream", @@ -3488,7 +3488,7 @@ dependencies = [ [[package]] name = "libsql-ffi" -version = "0.3.0" +version = "0.4.0" dependencies = [ "bindgen 0.66.1", "cc", @@ -3508,7 +3508,7 @@ dependencies = [ [[package]] name = "libsql-rusqlite" -version = "0.31.0" +version = "0.32.0" dependencies = [ "bencher", "bitflags 2.6.0", @@ -3631,7 +3631,7 @@ dependencies = [ [[package]] name = "libsql-sqlite3-parser" -version = "0.12.0" +version = "0.13.0" dependencies = [ "bitflags 2.6.0", "cc", @@ -3687,7 +3687,7 @@ dependencies = [ [[package]] name = "libsql-sys" -version = "0.6.0" +version = "0.7.0" dependencies = [ "bytes", "libsql-ffi", @@ -3768,7 +3768,7 @@ dependencies = [ [[package]] name = "libsql_replication" -version = "0.4.0" +version = "0.5.0" dependencies = [ "aes", "arbitrary", diff --git a/Cargo.toml b/Cargo.toml index 92487ecdd0..9381fb83f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ codegen-units = 1 panic = "unwind" [workspace.dependencies] -rusqlite = { package = "libsql-rusqlite", path = "vendored/rusqlite", version = "0.31", default-features = false, features = [ +rusqlite = { package = "libsql-rusqlite", path = "vendored/rusqlite", version = "0.32", default-features = false, features = [ "libsql-experimental", "column_decltype", "load_extension", diff --git a/libsql-ffi/Cargo.toml b/libsql-ffi/Cargo.toml index 9b5cbced11..ef9ade1726 100644 --- a/libsql-ffi/Cargo.toml +++ b/libsql-ffi/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql-ffi" -version = "0.3.0" +version = "0.4.0" edition = "2021" build = "build.rs" license = "MIT" diff --git a/libsql-replication/Cargo.toml b/libsql-replication/Cargo.toml index 56f00d7a7d..d2a9431cba 100644 --- a/libsql-replication/Cargo.toml +++ b/libsql-replication/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql_replication" -version = "0.4.0" +version = "0.5.0" edition = "2021" description = "libSQL replication protocol" repository = "https://github.com/tursodatabase/libsql" @@ -11,7 +11,7 @@ license = "MIT" [dependencies] tonic = { version = "0.11", features = ["tls"] } prost = "0.12" -libsql-sys = { version = "0.6", path = "../libsql-sys", 
default-features = false, features = ["wal", "rusqlite", "api"] } +libsql-sys = { version = "0.7", path = "../libsql-sys", default-features = false, features = ["wal", "rusqlite", "api"] } rusqlite = { workspace = true } parking_lot = "0.12.1" bytes = { version = "1.5.0", features = ["serde"] } diff --git a/libsql-sys/Cargo.toml b/libsql-sys/Cargo.toml index 26dd091ea9..8351012d9f 100644 --- a/libsql-sys/Cargo.toml +++ b/libsql-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql-sys" -version = "0.6.0" +version = "0.7.0" edition = "2021" license = "MIT" description = "Native bindings to libSQL" @@ -12,7 +12,7 @@ categories = ["external-ffi-bindings"] [dependencies] bytes = "1.5.0" -libsql-ffi = { version = "0.3", path = "../libsql-ffi/" } +libsql-ffi = { version = "0.4", path = "../libsql-ffi/" } once_cell = "1.18.0" rusqlite = { workspace = true, features = ["trace"], optional = true } tracing = "0.1.37" diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index fa89cc68ad..efae2abea3 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql" -version = "0.5.0-alpha.2" +version = "0.5.0" edition = "2021" description = "libSQL library: the main gateway for interacting with the database" repository = "https://github.com/tursodatabase/libsql" @@ -11,7 +11,7 @@ tracing = { version = "0.1.37", default-features = false } thiserror = "1.0.40" futures = { version = "0.3.28", optional = true } -libsql-sys = { version = "0.6", path = "../libsql-sys", optional = true } +libsql-sys = { version = "0.7", path = "../libsql-sys", optional = true } libsql-hrana = { version = "0.2", path = "../libsql-hrana", optional = true } tokio = { version = "1.29.1", features = ["sync"], optional = true } tokio-util = { version = "0.7", features = ["io-util", "codec"], optional = true } @@ -37,10 +37,10 @@ tower-http = { version = "0.4.4", features = ["trace", "set-header", "util"], op http = { version = "0.2", optional = true } zerocopy = { version = "0.7.28", optional = true } -sqlite3-parser = { package = "libsql-sqlite3-parser", path = "../vendored/sqlite3-parser", version = "0.12", optional = true } +sqlite3-parser = { package = "libsql-sqlite3-parser", path = "../vendored/sqlite3-parser", version = "0.13", optional = true } fallible-iterator = { version = "0.3", optional = true } -libsql_replication = { version = "0.4", path = "../libsql-replication", optional = true } +libsql_replication = { version = "0.5", path = "../libsql-replication", optional = true } async-stream = { version = "0.3.5", optional = true } [dev-dependencies] diff --git a/vendored/rusqlite/Cargo.toml b/vendored/rusqlite/Cargo.toml index 2d332f3279..d9fbcc525e 100644 --- a/vendored/rusqlite/Cargo.toml +++ b/vendored/rusqlite/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "libsql-rusqlite" # Note: Update version in README.md when you change this. 
-version = "0.31.0" +version = "0.32.0" authors = ["The rusqlite developers"] edition = "2018" description = "Ergonomic wrapper for SQLite (libsql fork)" @@ -109,7 +109,7 @@ fallible-iterator = "0.2" fallible-streaming-iterator = "0.1" uuid = { version = "1.0", optional = true } smallvec = "1.6.1" -libsql-ffi = { version = "0.3", path = "../../libsql-ffi" } +libsql-ffi = { version = "0.4", path = "../../libsql-ffi" } [dev-dependencies] doc-comment = "0.3" diff --git a/vendored/sqlite3-parser/Cargo.toml b/vendored/sqlite3-parser/Cargo.toml index 5ed9e31f4d..0381ac1d99 100644 --- a/vendored/sqlite3-parser/Cargo.toml +++ b/vendored/sqlite3-parser/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql-sqlite3-parser" -version = "0.12.0" +version = "0.13.0" edition = "2021" authors = ["gwenn"] description = "SQL parser (as understood by SQLite) (libsql fork)" From 0917f84cde2d272141eaa79d933ed0736fb668e5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 2 Aug 2024 23:54:55 +0200 Subject: [PATCH 03/70] introduce NamespaceConfigurator --- .../src/namespace/configurator/mod.rs | 56 ++++++ .../src/namespace/configurator/primary.rs | 128 ++++++++++++ .../src/namespace/configurator/replica.rs | 190 ++++++++++++++++++ libsql-server/src/namespace/mod.rs | 17 +- 4 files changed, 383 insertions(+), 8 deletions(-) create mode 100644 libsql-server/src/namespace/configurator/mod.rs create mode 100644 libsql-server/src/namespace/configurator/primary.rs create mode 100644 libsql-server/src/namespace/configurator/replica.rs diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs new file mode 100644 index 0000000000..0caa1de149 --- /dev/null +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -0,0 +1,56 @@ +use std::pin::Pin; + +use futures::Future; + +use super::broadcasters::BroadcasterHandle; +use super::meta_store::MetaStoreHandle; +use super::{NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; + +mod replica; +mod primary; + +type DynConfigurator = Box; + +#[derive(Default)] +struct NamespaceConfigurators { + replica_configurator: Option, + primary_configurator: Option, + schema_configurator: Option, +} + +impl NamespaceConfigurators { + pub fn with_primary( + &mut self, + c: impl ConfigureNamespace + Send + Sync + 'static, + ) -> &mut Self { + self.primary_configurator = Some(Box::new(c)); + self + } + + pub fn with_replica( + &mut self, + c: impl ConfigureNamespace + Send + Sync + 'static, + ) -> &mut Self { + self.replica_configurator = Some(Box::new(c)); + self + } + + pub fn with_schema(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { + self.schema_configurator = Some(Box::new(c)); + self + } +} + +pub trait ConfigureNamespace { + fn setup<'a>( + &'a self, + ns_config: &'a NamespaceConfig, + db_config: MetaStoreHandle, + restore_option: RestoreOption, + name: &'a NamespaceName, + reset: ResetCb, + resolve_attach_path: ResolveNamespacePathFn, + store: NamespaceStore, + broadcaster: BroadcasterHandle, + ) -> Pin> + Send + 'a>>; +} diff --git a/libsql-server/src/namespace/configurator/primary.rs b/libsql-server/src/namespace/configurator/primary.rs new file mode 100644 index 0000000000..f28d288a97 --- /dev/null +++ b/libsql-server/src/namespace/configurator/primary.rs @@ -0,0 +1,128 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::{path::Path, pin::Pin, sync::Arc}; + +use futures::prelude::Future; +use tokio::task::JoinSet; + +use 
crate::connection::MakeConnection;
use crate::database::{Database, PrimaryDatabase};
use crate::namespace::{Namespace, NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption};
use crate::namespace::meta_store::MetaStoreHandle;
use crate::namespace::broadcasters::BroadcasterHandle;
use crate::run_periodic_checkpoint;
use crate::schema::{has_pending_migration_task, setup_migration_table};

use super::ConfigureNamespace;

pub struct PrimaryConfigurator;

impl ConfigureNamespace for PrimaryConfigurator {
    fn setup<'a>(
        &'a self,
        config: &'a NamespaceConfig,
        meta_store_handle: MetaStoreHandle,
        restore_option: RestoreOption,
        name: &'a NamespaceName,
        _reset: ResetCb,
        resolve_attach_path: ResolveNamespacePathFn,
        _store: NamespaceStore,
        broadcaster: BroadcasterHandle,
    ) -> Pin> + Send + 'a>>
    {
        Box::pin(async move {
            let db_path: Arc = config.base_path.join("dbs").join(name.as_str()).into();
            let fresh_namespace = !db_path.try_exists()?;
            // FIXME: make that truly atomic. Explore the idea of using temp directories, and its implications.
            match try_new_primary(
                config,
                name.clone(),
                meta_store_handle,
                restore_option,
                resolve_attach_path,
                db_path.clone(),
                broadcaster,
            )
            .await
            {
                Ok(this) => Ok(this),
                Err(e) if fresh_namespace => {
                    tracing::error!("an error occurred while creating the namespace, cleaning up...");
                    if let Err(e) = tokio::fs::remove_dir_all(&db_path).await {
                        tracing::error!("failed to remove dirty namespace directory: {e}")
                    }
                    Err(e)
                }
                Err(e) => Err(e),
            }
        })
    }
}

#[tracing::instrument(skip_all, fields(namespace))]
async fn try_new_primary(
    ns_config: &NamespaceConfig,
    namespace: NamespaceName,
    meta_store_handle: MetaStoreHandle,
    restore_option: RestoreOption,
    resolve_attach_path: ResolveNamespacePathFn,
    db_path: Arc,
    broadcaster: BroadcasterHandle,
) -> crate::Result {
    let mut join_set = JoinSet::new();

    tokio::fs::create_dir_all(&db_path).await?;

    let block_writes = Arc::new(AtomicBool::new(false));
    let (connection_maker, wal_wrapper, stats) = Namespace::make_primary_connection_maker(
        ns_config,
        &meta_store_handle,
        &db_path,
        &namespace,
        restore_option,
        block_writes.clone(),
        &mut join_set,
        resolve_attach_path,
        broadcaster,
    )
    .await?;
    let connection_maker = Arc::new(connection_maker);

    if meta_store_handle.get().shared_schema_name.is_some() {
        let block_writes = block_writes.clone();
        let conn = connection_maker.create().await?;
        tokio::task::spawn_blocking(move || {
            conn.with_raw(|conn| -> crate::Result<()> {
                setup_migration_table(conn)?;
                if has_pending_migration_task(conn)?
{ + block_writes.store(true, Ordering::SeqCst); + } + Ok(()) + }) + }) + .await + .unwrap()?; + } + + if let Some(checkpoint_interval) = ns_config.checkpoint_interval { + join_set.spawn(run_periodic_checkpoint( + connection_maker.clone(), + checkpoint_interval, + namespace.clone(), + )); + } + + tracing::debug!("Done making new primary"); + + Ok(Namespace { + tasks: join_set, + db: Database::Primary(PrimaryDatabase { + wal_wrapper, + connection_maker, + block_writes, + }), + name: namespace, + stats, + db_config_store: meta_store_handle, + path: db_path.into(), + }) +} diff --git a/libsql-server/src/namespace/configurator/replica.rs b/libsql-server/src/namespace/configurator/replica.rs new file mode 100644 index 0000000000..4d3ca1dadf --- /dev/null +++ b/libsql-server/src/namespace/configurator/replica.rs @@ -0,0 +1,190 @@ +use std::pin::Pin; +use std::sync::Arc; + +use futures::Future; +use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; +use tokio::task::JoinSet; + +use crate::connection::write_proxy::MakeWriteProxyConn; +use crate::connection::MakeConnection; +use crate::database::{Database, ReplicaDatabase}; +use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{Namespace, RestoreOption}; +use crate::namespace::{ + make_stats, NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResetOp, + ResolveNamespacePathFn, +}; +use crate::{DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT}; + +use super::ConfigureNamespace; + +pub struct ReplicaConfigurator; + +impl ConfigureNamespace for ReplicaConfigurator { + fn setup<'a>( + &'a self, + config: &'a NamespaceConfig, + meta_store_handle: MetaStoreHandle, + restore_option: RestoreOption, + name: &'a NamespaceName, + reset: ResetCb, + resolve_attach_path: ResolveNamespacePathFn, + store: NamespaceStore, + broadcaster: BroadcasterHandle, + ) -> Pin> + Send + 'a>> + { + Box::pin(async move { + tracing::debug!("creating replica namespace"); + let db_path = config.base_path.join("dbs").join(name.as_str()); + let channel = config.channel.clone().expect("bad replica config"); + let uri = config.uri.clone().expect("bad replica config"); + + let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); + let client = crate::replication::replicator_client::Client::new( + name.clone(), + rpc_client, + &db_path, + meta_store_handle.clone(), + store.clone(), + ) + .await?; + let applied_frame_no_receiver = client.current_frame_no_notifier.subscribe(); + let mut replicator = libsql_replication::replicator::Replicator::new( + client, + db_path.join("data"), + DEFAULT_AUTO_CHECKPOINT, + config.encryption_config.clone(), + ) + .await?; + + tracing::debug!("try perform handshake"); + // force a handshake now, to retrieve the primary's current replication index + match replicator.try_perform_handshake().await { + Err(libsql_replication::replicator::Error::Meta( + libsql_replication::meta::Error::LogIncompatible, + )) => { + tracing::error!( + "trying to replicate incompatible logs, reseting replica and nuking db dir" + ); + std::fs::remove_dir_all(&db_path).unwrap(); + return self.setup( + config, + meta_store_handle, + restore_option, + name, + reset, + resolve_attach_path, + store, + broadcaster, + ) + .await; + } + Err(e) => Err(e)?, + Ok(_) => (), + } + + tracing::debug!("done performing handshake"); + + let primary_current_replicatio_index = replicator.client_mut().primary_replication_index; + + let mut join_set = JoinSet::new(); + 
let namespace = name.clone(); + join_set.spawn(async move { + use libsql_replication::replicator::Error; + loop { + match replicator.run().await { + err @ Error::Fatal(_) => Err(err)?, + err @ Error::NamespaceDoesntExist => { + tracing::error!("namespace {namespace} doesn't exist, destroying..."); + (reset)(ResetOp::Destroy(namespace.clone())); + Err(err)?; + } + e @ Error::Injector(_) => { + tracing::error!("potential corruption detected while replicating, reseting replica: {e}"); + (reset)(ResetOp::Reset(namespace.clone())); + Err(e)?; + }, + Error::Meta(err) => { + use libsql_replication::meta::Error; + match err { + Error::LogIncompatible => { + tracing::error!("trying to replicate incompatible logs, reseting replica"); + (reset)(ResetOp::Reset(namespace.clone())); + Err(err)?; + } + Error::InvalidMetaFile + | Error::Io(_) + | Error::InvalidLogId + | Error::FailedToCommit(_) + | Error::InvalidReplicationPath + | Error::RequiresCleanDatabase => { + // We retry from last frame index? + tracing::warn!("non-fatal replication error, retrying from last commit index: {err}"); + }, + } + } + e @ (Error::Internal(_) + | Error::Client(_) + | Error::PrimaryHandshakeTimeout + | Error::NeedSnapshot) => { + tracing::warn!("non-fatal replication error, retrying from last commit index: {e}"); + }, + Error::NoHandshake => { + // not strictly necessary, but in case the handshake error goes uncaught, + // we reset the client state. + replicator.client_mut().reset_token(); + } + Error::SnapshotPending => unreachable!(), + } + } + }); + + let stats = make_stats( + &db_path, + &mut join_set, + meta_store_handle.clone(), + config.stats_sender.clone(), + name.clone(), + applied_frame_no_receiver.clone(), + config.encryption_config.clone(), + ) + .await?; + + let connection_maker = MakeWriteProxyConn::new( + db_path.clone(), + config.extensions.clone(), + channel.clone(), + uri.clone(), + stats.clone(), + broadcaster, + meta_store_handle.clone(), + applied_frame_no_receiver, + config.max_response_size, + config.max_total_response_size, + primary_current_replicatio_index, + config.encryption_config.clone(), + resolve_attach_path, + config.make_wal_manager.clone(), + ) + .await? 
+ .throttled(
    config.max_concurrent_connections.clone(),
    Some(DB_CREATE_TIMEOUT),
    config.max_total_response_size,
    config.max_concurrent_requests,
);

Ok(Namespace {
    tasks: join_set,
    db: Database::Replica(ReplicaDatabase {
        connection_maker: Arc::new(connection_maker),
    }),
    name: name.clone(),
    stats,
    db_config_store: meta_store_handle,
    path: db_path.into(),
}) }) } }
diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 6e48e7f1d8..6a04b11fb8 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -1,11 +1,3 @@ -pub mod broadcasters; -mod fork; -pub mod meta_store; -mod name; -pub mod replication_wal; -mod schema_lock; -mod store; - use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Weak}; @@ -57,6 +49,15 @@ pub use self::name::NamespaceName; use self::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; pub use self::store::NamespaceStore; +pub mod broadcasters; +mod fork; +pub mod meta_store; +mod name; +pub mod replication_wal; +mod schema_lock; +mod store; +mod configurator; + pub type ResetCb = Box; pub type ResolveNamespacePathFn = Arc crate::Result> + Sync + Send + 'static>;
From f9daa9e08f58efdebd52acb4d45435266bec34af Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 3 Aug 2024 00:00:45 +0200 Subject: [PATCH 04/70] add configurators to namespace store --- libsql-server/src/namespace/configurator/mod.rs | 2 +- libsql-server/src/namespace/store.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index 0caa1de149..a692f75652 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -12,7 +12,7 @@ mod primary; type DynConfigurator = Box; #[derive(Default)] -struct NamespaceConfigurators { +pub(crate) struct NamespaceConfigurators { replica_configurator: Option, primary_configurator: Option, schema_configurator: Option, }
diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index e0147fc2e8..984a520154 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -19,6 +19,7 @@ use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, Nam use crate::stats::Stats; use super::broadcasters::{BroadcasterHandle, BroadcasterRegistry}; +use super::configurator::NamespaceConfigurators; use super::meta_store::{MetaStore, MetaStoreHandle}; use super::schema_lock::SchemaLocksRegistry; use super::{Namespace, NamespaceConfig, ResetCb, ResetOp, ResolveNamespacePathFn, RestoreOption}; @@ -47,6 +48,7 @@ pub struct NamespaceStoreInner { pub config: NamespaceConfig, schema_locks: SchemaLocksRegistry, broadcasters: BroadcasterRegistry, + configurators: NamespaceConfigurators, } impl NamespaceStore { @@ -90,6 +92,7 @@ impl NamespaceStore { config, schema_locks: Default::default(), broadcasters: Default::default(), + configurators: NamespaceConfigurators::default(), }), }) }
From 8b377a6e06dfc051bedb68976577adb8de339f40 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 3 Aug 2024 21:18:51 +0200 Subject: [PATCH 05/70] add schema configurator --- .../src/namespace/configurator/schema.rs | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 libsql-server/src/namespace/configurator/schema.rs
diff --git a/libsql-server/src/namespace/configurator/schema.rs
b/libsql-server/src/namespace/configurator/schema.rs new file mode 100644 index 0000000000..864b75239f --- /dev/null +++ b/libsql-server/src/namespace/configurator/schema.rs @@ -0,0 +1,65 @@
use std::sync::{atomic::AtomicBool, Arc};

use futures::prelude::Future;
use tokio::task::JoinSet;

use crate::database::{Database, SchemaDatabase};
use crate::namespace::meta_store::MetaStoreHandle;
use crate::namespace::{
    Namespace, NamespaceConfig, NamespaceName, NamespaceStore,
    ResetCb, ResolveNamespacePathFn, RestoreOption,
};
use crate::namespace::broadcasters::BroadcasterHandle;

use super::ConfigureNamespace;

pub struct SchemaConfigurator;

impl ConfigureNamespace for SchemaConfigurator {
    fn setup<'a>(
        &'a self,
        ns_config: &'a NamespaceConfig,
        db_config: MetaStoreHandle,
        restore_option: RestoreOption,
        name: &'a NamespaceName,
        _reset: ResetCb,
        resolve_attach_path: ResolveNamespacePathFn,
        _store: NamespaceStore,
        broadcaster: BroadcasterHandle,
    ) -> std::pin::Pin> + Send + 'a>> {
        Box::pin(async move {
            let mut join_set = JoinSet::new();
            let db_path = ns_config.base_path.join("dbs").join(name.as_str());

            tokio::fs::create_dir_all(&db_path).await?;

            let (connection_maker, wal_manager, stats) = Namespace::make_primary_connection_maker(
                ns_config,
                &db_config,
                &db_path,
                &name,
                restore_option,
                Arc::new(AtomicBool::new(false)), // this is always false for schema
                &mut join_set,
                resolve_attach_path,
                broadcaster,
            )
            .await?;

            Ok(Namespace {
                db: Database::Schema(SchemaDatabase::new(
                    ns_config.migration_scheduler.clone(),
                    name.clone(),
                    connection_maker,
                    wal_manager,
                    db_config.clone(),
                )),
                name: name.clone(),
                tasks: join_set,
                stats,
                db_config_store: db_config.clone(),
                path: db_path.into(),
            })
        })
    }
}
From 978dd7147d0304397f261d918c70c2806eff82a5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 3 Aug 2024 21:19:00 +0200 Subject: [PATCH 06/70] instantiate namespaces from configurators --- .../src/namespace/configurator/mod.rs | 25 +- libsql-server/src/namespace/fork.rs | 26 +- libsql-server/src/namespace/mod.rs | 374 +----------------- libsql-server/src/namespace/store.rs | 76 ++-- 4 files changed, 74 insertions(+), 427 deletions(-)
diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index a692f75652..d3cd390b34 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -8,14 +8,19 @@ use super::{NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveName mod replica; mod primary; +mod schema; -type DynConfigurator = Box; +pub use replica::ReplicaConfigurator; +pub use primary::PrimaryConfigurator; +pub use schema::SchemaConfigurator; + +type DynConfigurator = dyn ConfigureNamespace + Send + Sync + 'static; #[derive(Default)] pub(crate) struct NamespaceConfigurators { - replica_configurator: Option, - primary_configurator: Option, - schema_configurator: Option, + replica_configurator: Option>, + primary_configurator: Option>, + schema_configurator: Option>, } impl NamespaceConfigurators { @@ -39,6 +44,18 @@ impl NamespaceConfigurators { self.schema_configurator = Some(Box::new(c)); self } + + pub fn configure_schema(&self) -> crate::Result<&DynConfigurator> { + self.schema_configurator.as_deref().ok_or_else(|| todo!()) + } + + pub fn configure_primary(&self) -> crate::Result<&DynConfigurator> { + self.primary_configurator.as_deref().ok_or_else(|| todo!()) + } + + pub fn
configure_replica(&self) -> crate::Result<&DynConfigurator> { + self.replica_configurator.as_deref().ok_or_else(|| todo!()) + } } pub trait ConfigureNamespace { diff --git a/libsql-server/src/namespace/fork.rs b/libsql-server/src/namespace/fork.rs index dfa053b43d..f25bf7a9a9 100644 --- a/libsql-server/src/namespace/fork.rs +++ b/libsql-server/src/namespace/fork.rs @@ -12,14 +12,12 @@ use tokio::io::{AsyncSeekExt, AsyncWriteExt}; use tokio::time::Duration; use tokio_stream::StreamExt; -use crate::namespace::ResolveNamespacePathFn; use crate::replication::primary::frame_stream::FrameStream; use crate::replication::{LogReadError, ReplicationLogger}; use crate::{BLOCKING_RT, LIBSQL_PAGE_SIZE}; -use super::broadcasters::BroadcasterHandle; use super::meta_store::MetaStoreHandle; -use super::{Namespace, NamespaceConfig, NamespaceName, NamespaceStore, RestoreOption}; +use super::{NamespaceName, NamespaceStore, RestoreOption}; type Result = crate::Result; @@ -54,16 +52,13 @@ async fn write_frame(frame: &FrameBorrowed, temp_file: &mut tokio::fs::File) -> Ok(()) } -pub struct ForkTask<'a> { +pub struct ForkTask { pub base_path: Arc, pub logger: Arc, pub to_namespace: NamespaceName, pub to_config: MetaStoreHandle, pub restore_to: Option, - pub ns_config: &'a NamespaceConfig, - pub resolve_attach: ResolveNamespacePathFn, pub store: NamespaceStore, - pub broadcaster: BroadcasterHandle, } pub struct PointInTimeRestore { @@ -71,7 +66,7 @@ pub struct PointInTimeRestore { pub replicator_options: bottomless::replicator::Options, } -impl<'a> ForkTask<'a> { +impl ForkTask { pub async fn fork(self) -> Result { let base_path = self.base_path.clone(); let dest_namespace = self.to_namespace.clone(); @@ -105,18 +100,9 @@ impl<'a> ForkTask<'a> { let dest_path = self.base_path.join("dbs").join(self.to_namespace.as_str()); tokio::fs::rename(temp_dir.path(), dest_path).await?; - Namespace::from_config( - self.ns_config, - self.to_config.clone(), - RestoreOption::Latest, - &self.to_namespace, - Box::new(|_op| {}), - self.resolve_attach.clone(), - self.store.clone(), - self.broadcaster, - ) - .await - .map_err(|e| ForkError::CreateNamespace(Box::new(e))) + self.store.make_namespace(&self.to_namespace, self.to_config, RestoreOption::Latest) + .await + .map_err(|e| ForkError::CreateNamespace(Box::new(e))) } /// Restores the database state from a local log file. 
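Taken together, these patches converge on one pattern: each database flavor (primary, replica, schema) gets its own ConfigureNamespace implementation, the implementations are collected in a NamespaceConfigurators registry, and namespace creation picks one based on DatabaseKind. The sketch below is a minimal, self-contained illustration of that dispatch shape; the simplified synchronous signature and String errors are stand-ins for the server's actual async setup() arguments (MetaStoreHandle, RestoreOption, ResetCb, and so on), not the real API:

    // Minimal sketch of the configurator-dispatch pattern (simplified,
    // assumed types; the real trait is async and returns a Namespace).
    trait ConfigureNamespace {
        fn setup(&self, name: &str) -> Result<String, String>;
    }

    struct PrimaryConfigurator;
    impl ConfigureNamespace for PrimaryConfigurator {
        fn setup(&self, name: &str) -> Result<String, String> {
            Ok(format!("configured primary namespace `{name}`"))
        }
    }

    struct ReplicaConfigurator;
    impl ConfigureNamespace for ReplicaConfigurator {
        fn setup(&self, name: &str) -> Result<String, String> {
            Ok(format!("configured replica namespace `{name}`"))
        }
    }

    #[derive(Clone, Copy)]
    enum DatabaseKind { Primary, Replica }

    #[derive(Default)]
    struct NamespaceConfigurators {
        primary: Option<Box<dyn ConfigureNamespace>>,
        replica: Option<Box<dyn ConfigureNamespace>>,
    }

    impl NamespaceConfigurators {
        // Builder-style registration, mirroring with_primary/with_replica above.
        fn with_primary(mut self, c: impl ConfigureNamespace + 'static) -> Self {
            self.primary = Some(Box::new(c));
            self
        }
        fn with_replica(mut self, c: impl ConfigureNamespace + 'static) -> Self {
            self.replica = Some(Box::new(c));
            self
        }
        // Dispatch on the database kind, as NamespaceStore::make_namespace does.
        fn make_namespace(&self, kind: DatabaseKind, name: &str) -> Result<String, String> {
            let configurator = match kind {
                DatabaseKind::Primary => self.primary.as_deref(),
                DatabaseKind::Replica => self.replica.as_deref(),
            };
            configurator.ok_or("no configurator registered for this kind")?.setup(name)
        }
    }

    fn main() {
        let configurators = NamespaceConfigurators::default()
            .with_primary(PrimaryConfigurator)
            .with_replica(ReplicaConfigurator);
        println!("{:?}", configurators.make_namespace(DatabaseKind::Primary, "ns1"));
    }

One consequence of this shape, visible in patch 07, is that the set of configurators becomes a constructor argument to NamespaceStore: a server wired as a replica registers only the configurators it needs, and asking for an unregistered kind fails at dispatch (in this sketch with an error value; in the patch above, the missing-configurator branch is still an ok_or_else(|| todo!())).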
diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 6a04b11fb8..41bb3ab9cc 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -1,5 +1,5 @@ use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::AtomicBool; use std::sync::{Arc, Weak}; use anyhow::{Context as _, Error}; @@ -10,7 +10,6 @@ use chrono::NaiveDateTime; use enclose::enclose; use futures_core::{Future, Stream}; use hyper::Uri; -use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use libsql_sys::wal::Sqlite3WalManager; use libsql_sys::EncryptionConfig; use tokio::io::AsyncBufReadExt; @@ -25,20 +24,17 @@ use crate::auth::parse_jwt_keys; use crate::connection::config::DatabaseConfig; use crate::connection::connection_manager::InnerWalManager; use crate::connection::libsql::{open_conn, MakeLibSqlConn}; -use crate::connection::write_proxy::MakeWriteProxyConn; -use crate::connection::Connection; -use crate::connection::MakeConnection; +use crate::connection::{Connection as _, MakeConnection}; use crate::database::{ - Database, DatabaseKind, PrimaryConnection, PrimaryConnectionMaker, PrimaryDatabase, - ReplicaDatabase, SchemaDatabase, + Database, DatabaseKind, PrimaryConnection, PrimaryConnectionMaker, }; use crate::error::LoadDumpError; use crate::replication::script_backup_manager::ScriptBackupManager; use crate::replication::{FrameNo, ReplicationLogger}; -use crate::schema::{has_pending_migration_task, setup_migration_table, SchedulerHandle}; +use crate::schema::SchedulerHandle; use crate::stats::Stats; use crate::{ - run_periodic_checkpoint, StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT, + StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT, }; pub use fork::ForkError; @@ -101,54 +97,6 @@ pub struct Namespace { } impl Namespace { - async fn from_config( - ns_config: &NamespaceConfig, - db_config: MetaStoreHandle, - restore_option: RestoreOption, - name: &NamespaceName, - reset: ResetCb, - resolve_attach_path: ResolveNamespacePathFn, - store: NamespaceStore, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - match ns_config.db_kind { - DatabaseKind::Primary if db_config.get().is_shared_schema => { - Self::new_schema( - ns_config, - name.clone(), - db_config, - restore_option, - resolve_attach_path, - broadcaster, - ) - .await - } - DatabaseKind::Primary => { - Self::new_primary( - ns_config, - name.clone(), - db_config, - restore_option, - resolve_attach_path, - broadcaster, - ) - .await - } - DatabaseKind::Replica => { - Self::new_replica( - ns_config, - name.clone(), - db_config, - reset, - resolve_attach_path, - store, - broadcaster, - ) - .await - } - } - } - pub(crate) fn name(&self) -> &NamespaceName { &self.name } @@ -248,40 +196,6 @@ impl Namespace { self.db_config_store.changed() } - async fn new_primary( - config: &NamespaceConfig, - name: NamespaceName, - meta_store_handle: MetaStoreHandle, - restore_option: RestoreOption, - resolve_attach_path: ResolveNamespacePathFn, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - let db_path: Arc = config.base_path.join("dbs").join(name.as_str()).into(); - let fresh_namespace = !db_path.try_exists()?; - // FIXME: make that truly atomic. 
explore the idea of using temp directories, and it's implications - match Self::try_new_primary( - config, - name.clone(), - meta_store_handle, - restore_option, - resolve_attach_path, - db_path.clone(), - broadcaster, - ) - .await - { - Ok(this) => Ok(this), - Err(e) if fresh_namespace => { - tracing::error!("an error occured while deleting creating namespace, cleaning..."); - if let Err(e) = tokio::fs::remove_dir_all(&db_path).await { - tracing::error!("failed to remove dirty namespace directory: {e}") - } - Err(e) - } - Err(e) => Err(e), - } - } - #[tracing::instrument(skip_all)] async fn make_primary_connection_maker( ns_config: &NamespaceConfig, @@ -417,237 +331,6 @@ impl Namespace { Ok((connection_maker, wal_wrapper, stats)) } - #[tracing::instrument(skip_all, fields(namespace))] - async fn try_new_primary( - ns_config: &NamespaceConfig, - namespace: NamespaceName, - meta_store_handle: MetaStoreHandle, - restore_option: RestoreOption, - resolve_attach_path: ResolveNamespacePathFn, - db_path: Arc, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - let mut join_set = JoinSet::new(); - - tokio::fs::create_dir_all(&db_path).await?; - - let block_writes = Arc::new(AtomicBool::new(false)); - let (connection_maker, wal_wrapper, stats) = Self::make_primary_connection_maker( - ns_config, - &meta_store_handle, - &db_path, - &namespace, - restore_option, - block_writes.clone(), - &mut join_set, - resolve_attach_path, - broadcaster, - ) - .await?; - let connection_maker = Arc::new(connection_maker); - - if meta_store_handle.get().shared_schema_name.is_some() { - let block_writes = block_writes.clone(); - let conn = connection_maker.create().await?; - tokio::task::spawn_blocking(move || { - conn.with_raw(|conn| -> crate::Result<()> { - setup_migration_table(conn)?; - if has_pending_migration_task(conn)? 
{ - block_writes.store(true, Ordering::SeqCst); - } - Ok(()) - }) - }) - .await - .unwrap()?; - } - - if let Some(checkpoint_interval) = ns_config.checkpoint_interval { - join_set.spawn(run_periodic_checkpoint( - connection_maker.clone(), - checkpoint_interval, - namespace.clone(), - )); - } - - tracing::debug!("Done making new primary"); - - Ok(Self { - tasks: join_set, - db: Database::Primary(PrimaryDatabase { - wal_wrapper, - connection_maker, - block_writes, - }), - name: namespace, - stats, - db_config_store: meta_store_handle, - path: db_path.into(), - }) - } - - #[tracing::instrument(skip_all, fields(name))] - #[async_recursion::async_recursion] - async fn new_replica( - config: &NamespaceConfig, - name: NamespaceName, - meta_store_handle: MetaStoreHandle, - reset: ResetCb, - resolve_attach_path: ResolveNamespacePathFn, - store: NamespaceStore, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - tracing::debug!("creating replica namespace"); - let db_path = config.base_path.join("dbs").join(name.as_str()); - let channel = config.channel.clone().expect("bad replica config"); - let uri = config.uri.clone().expect("bad replica config"); - - let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); - let client = crate::replication::replicator_client::Client::new( - name.clone(), - rpc_client, - &db_path, - meta_store_handle.clone(), - store.clone(), - ) - .await?; - let applied_frame_no_receiver = client.current_frame_no_notifier.subscribe(); - let mut replicator = libsql_replication::replicator::Replicator::new( - client, - db_path.join("data"), - DEFAULT_AUTO_CHECKPOINT, - config.encryption_config.clone(), - ) - .await?; - - tracing::debug!("try perform handshake"); - // force a handshake now, to retrieve the primary's current replication index - match replicator.try_perform_handshake().await { - Err(libsql_replication::replicator::Error::Meta( - libsql_replication::meta::Error::LogIncompatible, - )) => { - tracing::error!( - "trying to replicate incompatible logs, reseting replica and nuking db dir" - ); - std::fs::remove_dir_all(&db_path).unwrap(); - return Self::new_replica( - config, - name, - meta_store_handle, - reset, - resolve_attach_path, - store, - broadcaster, - ) - .await; - } - Err(e) => Err(e)?, - Ok(_) => (), - } - - tracing::debug!("done performing handshake"); - - let primary_current_replicatio_index = replicator.client_mut().primary_replication_index; - - let mut join_set = JoinSet::new(); - let namespace = name.clone(); - join_set.spawn(async move { - use libsql_replication::replicator::Error; - loop { - match replicator.run().await { - err @ Error::Fatal(_) => Err(err)?, - err @ Error::NamespaceDoesntExist => { - tracing::error!("namespace {namespace} doesn't exist, destroying..."); - (reset)(ResetOp::Destroy(namespace.clone())); - Err(err)?; - } - e @ Error::Injector(_) => { - tracing::error!("potential corruption detected while replicating, reseting replica: {e}"); - (reset)(ResetOp::Reset(namespace.clone())); - Err(e)?; - }, - Error::Meta(err) => { - use libsql_replication::meta::Error; - match err { - Error::LogIncompatible => { - tracing::error!("trying to replicate incompatible logs, reseting replica"); - (reset)(ResetOp::Reset(namespace.clone())); - Err(err)?; - } - Error::InvalidMetaFile - | Error::Io(_) - | Error::InvalidLogId - | Error::FailedToCommit(_) - | Error::InvalidReplicationPath - | Error::RequiresCleanDatabase => { - // We retry from last frame index? 
- tracing::warn!("non-fatal replication error, retrying from last commit index: {err}"); - }, - } - } - e @ (Error::Internal(_) - | Error::Client(_) - | Error::PrimaryHandshakeTimeout - | Error::NeedSnapshot) => { - tracing::warn!("non-fatal replication error, retrying from last commit index: {e}"); - }, - Error::NoHandshake => { - // not strictly necessary, but in case the handshake error goes uncaught, - // we reset the client state. - replicator.client_mut().reset_token(); - } - Error::SnapshotPending => unreachable!(), - } - } - }); - - let stats = make_stats( - &db_path, - &mut join_set, - meta_store_handle.clone(), - config.stats_sender.clone(), - name.clone(), - applied_frame_no_receiver.clone(), - config.encryption_config.clone(), - ) - .await?; - - let connection_maker = MakeWriteProxyConn::new( - db_path.clone(), - config.extensions.clone(), - channel.clone(), - uri.clone(), - stats.clone(), - broadcaster, - meta_store_handle.clone(), - applied_frame_no_receiver, - config.max_response_size, - config.max_total_response_size, - primary_current_replicatio_index, - config.encryption_config.clone(), - resolve_attach_path, - config.make_wal_manager.clone(), - ) - .await? - .throttled( - config.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - config.max_total_response_size, - config.max_concurrent_requests, - ); - - Ok(Self { - tasks: join_set, - db: Database::Replica(ReplicaDatabase { - connection_maker: Arc::new(connection_maker), - }), - name, - stats, - db_config_store: meta_store_handle, - path: db_path.into(), - }) - } - async fn fork( ns_config: &NamespaceConfig, from_ns: &Namespace, @@ -655,9 +338,7 @@ impl Namespace { to_ns: NamespaceName, to_config: MetaStoreHandle, timestamp: Option, - resolve_attach: ResolveNamespacePathFn, store: NamespaceStore, - broadcaster: BroadcasterHandle, ) -> crate::Result { let from_config = from_config.get(); match ns_config.db_kind { @@ -696,10 +377,7 @@ impl Namespace { logger, restore_to, to_config, - ns_config, - resolve_attach, store, - broadcaster: broadcaster.handle(to_ns), }; let ns = fork_task.fork().await?; @@ -708,48 +386,6 @@ impl Namespace { DatabaseKind::Replica => Err(ForkError::ForkReplica.into()), } } - - async fn new_schema( - ns_config: &NamespaceConfig, - name: NamespaceName, - meta_store_handle: MetaStoreHandle, - restore_option: RestoreOption, - resolve_attach_path: ResolveNamespacePathFn, - broadcaster: BroadcasterHandle, - ) -> crate::Result { - let mut join_set = JoinSet::new(); - let db_path = ns_config.base_path.join("dbs").join(name.as_str()); - - tokio::fs::create_dir_all(&db_path).await?; - - let (connection_maker, wal_manager, stats) = Self::make_primary_connection_maker( - ns_config, - &meta_store_handle, - &db_path, - &name, - restore_option, - Arc::new(AtomicBool::new(false)), // this is always false for schema - &mut join_set, - resolve_attach_path, - broadcaster, - ) - .await?; - - Ok(Namespace { - db: Database::Schema(SchemaDatabase::new( - ns_config.migration_scheduler.clone(), - name.clone(), - connection_maker, - wal_manager, - meta_store_handle.clone(), - )), - name, - tasks: join_set, - stats, - db_config_store: meta_store_handle, - path: db_path.into(), - }) - } } pub struct NamespaceConfig { diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index 984a520154..5a94a7f8eb 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -13,8 +13,10 @@ use tokio_stream::wrappers::BroadcastStream; use 
crate::auth::Authenticated; use crate::broadcaster::BroadcastMsg; use crate::connection::config::DatabaseConfig; +use crate::database::DatabaseKind; use crate::error::Error; use crate::metrics::NAMESPACE_LOAD_LATENCY; +use crate::namespace::configurator::{PrimaryConfigurator, ReplicaConfigurator, SchemaConfigurator}; use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName}; use crate::stats::Stats; @@ -82,6 +84,12 @@ impl NamespaceStore { .time_to_idle(Duration::from_secs(86400)) .build(); + let mut configurators = NamespaceConfigurators::default(); + configurators + .with_primary(PrimaryConfigurator) + .with_replica(ReplicaConfigurator) + .with_schema(SchemaConfigurator); + Ok(Self { inner: Arc::new(NamespaceStoreInner { store, @@ -92,7 +100,7 @@ impl NamespaceStore { config, schema_locks: Default::default(), broadcasters: Default::default(), - configurators: NamespaceConfigurators::default(), + configurators, }), }) } @@ -177,27 +185,17 @@ impl NamespaceStore { ns.destroy().await?; } - let handle = self.inner.metadata.handle(namespace.clone()); + let db_config = self.inner.metadata.handle(namespace.clone()); // destroy on-disk database Namespace::cleanup( &self.inner.config, &namespace, - &handle.get(), + &db_config.get(), false, NamespaceBottomlessDbIdInit::FetchFromConfig, ) .await?; - let ns = Namespace::from_config( - &self.inner.config, - handle, - restore_option, - &namespace, - self.make_reset_cb(), - self.resolve_attach_fn(), - self.clone(), - self.broadcaster(namespace.clone()), - ) - .await?; + let ns = self.make_namespace(&namespace, db_config, restore_option).await?; lock.replace(ns); @@ -304,9 +302,7 @@ impl NamespaceStore { to.clone(), handle.clone(), timestamp, - self.resolve_attach_fn(), self.clone(), - self.broadcaster(to), ) .await?; @@ -381,30 +377,42 @@ impl NamespaceStore { .clone() } + pub(crate) async fn make_namespace( + &self, + namespace: &NamespaceName, + config: MetaStoreHandle, + restore_option: RestoreOption, + ) -> crate::Result { + let configurator = match self.inner.config.db_kind { + DatabaseKind::Primary if config.get().is_shared_schema => { + self.inner.configurators.configure_schema()? 
+ } + DatabaseKind::Primary => self.inner.configurators.configure_primary()?, + DatabaseKind::Replica => self.inner.configurators.configure_replica()?, + }; + let ns = configurator.setup( + &self.inner.config, + config, + restore_option, + namespace, + self.make_reset_cb(), + self.resolve_attach_fn(), + self.clone(), + self.broadcaster(namespace.clone()), + ).await?; + + Ok(ns) + } + async fn load_namespace( &self, namespace: &NamespaceName, db_config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { - let init = { - let namespace = namespace.clone(); - async move { - let ns = Namespace::from_config( - &self.inner.config, - db_config, - restore_option, - &namespace, - self.make_reset_cb(), - self.resolve_attach_fn(), - self.clone(), - self.broadcaster(namespace.clone()), - ) - .await?; - tracing::info!("loaded namespace: `{namespace}`"); - - Ok(Some(ns)) - } + let init = async { + let ns = self.make_namespace(namespace, db_config, restore_option).await?; + Ok(Some(ns)) }; let before_load = Instant::now(); From 907f2f9381783b09254e605473455c72e97cd40b Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 5 Aug 2024 11:14:40 +0200 Subject: [PATCH 07/70] pass configurators to NamespaceStore::new --- libsql-server/src/lib.rs | 5 +++ .../src/namespace/configurator/mod.rs | 38 ++++++++++++------- libsql-server/src/namespace/mod.rs | 2 +- libsql-server/src/namespace/store.rs | 10 +---- libsql-server/src/schema/scheduler.rs | 38 ++++++++++++++----- 5 files changed, 62 insertions(+), 31 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 5404a11108..3d816d6bc3 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -60,6 +60,7 @@ use utils::services::idle_shutdown::IdleShutdownKicker; use self::config::MetaStoreConfig; use self::connection::connection_manager::InnerWalManager; +use self::namespace::configurator::NamespaceConfigurators; use self::namespace::NamespaceStore; use self::net::AddrIncoming; use self::replication::script_backup_manager::{CommandHandler, ScriptBackupManager}; @@ -488,12 +489,16 @@ where meta_store_wal_manager, ) .await?; + + let configurators = NamespaceConfigurators::default(); + let namespace_store: NamespaceStore = NamespaceStore::new( db_kind.is_replica(), self.db_config.snapshot_at_shutdown, self.max_active_namespaces, ns_config, meta_store, + configurators, ) .await?; diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index d3cd390b34..a240c3e410 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -4,43 +4,55 @@ use futures::Future; use super::broadcasters::BroadcasterHandle; use super::meta_store::MetaStoreHandle; -use super::{NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; +use super::{ + NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption, +}; -mod replica; mod primary; +mod replica; mod schema; -pub use replica::ReplicaConfigurator; pub use primary::PrimaryConfigurator; +pub use replica::ReplicaConfigurator; pub use schema::SchemaConfigurator; type DynConfigurator = dyn ConfigureNamespace + Send + Sync + 'static; -#[derive(Default)] pub(crate) struct NamespaceConfigurators { replica_configurator: Option>, primary_configurator: Option>, schema_configurator: Option>, } +impl Default for NamespaceConfigurators { + fn default() -> Self { + Self::empty() + .with_primary(PrimaryConfigurator) + 
.with_replica(ReplicaConfigurator) + .with_schema(SchemaConfigurator) + } +} + impl NamespaceConfigurators { - pub fn with_primary( - &mut self, - c: impl ConfigureNamespace + Send + Sync + 'static, - ) -> &mut Self { + pub fn empty() -> Self { + Self { + replica_configurator: None, + primary_configurator: None, + schema_configurator: None, + } + } + + pub fn with_primary(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { self.primary_configurator = Some(Box::new(c)); self } - pub fn with_replica( - &mut self, - c: impl ConfigureNamespace + Send + Sync + 'static, - ) -> &mut Self { + pub fn with_replica(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { self.replica_configurator = Some(Box::new(c)); self } - pub fn with_schema(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { + pub fn with_schema(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { self.schema_configurator = Some(Box::new(c)); self } diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 41bb3ab9cc..5ccda74c54 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -52,7 +52,7 @@ mod name; pub mod replication_wal; mod schema_lock; mod store; -mod configurator; +pub(crate) mod configurator; pub type ResetCb = Box; pub type ResolveNamespacePathFn = diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index 5a94a7f8eb..fbce8cd78b 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -16,7 +16,6 @@ use crate::connection::config::DatabaseConfig; use crate::database::DatabaseKind; use crate::error::Error; use crate::metrics::NAMESPACE_LOAD_LATENCY; -use crate::namespace::configurator::{PrimaryConfigurator, ReplicaConfigurator, SchemaConfigurator}; use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName}; use crate::stats::Stats; @@ -54,12 +53,13 @@ pub struct NamespaceStoreInner { } impl NamespaceStore { - pub async fn new( + pub(crate) async fn new( allow_lazy_creation: bool, snapshot_at_shutdown: bool, max_active_namespaces: usize, config: NamespaceConfig, metadata: MetaStore, + configurators: NamespaceConfigurators, ) -> crate::Result { tracing::trace!("Max active namespaces: {max_active_namespaces}"); let store = Cache::::builder() @@ -84,12 +84,6 @@ impl NamespaceStore { .time_to_idle(Duration::from_secs(86400)) .build(); - let mut configurators = NamespaceConfigurators::default(); - configurators - .with_primary(PrimaryConfigurator) - .with_replica(ReplicaConfigurator) - .with_schema(SchemaConfigurator); - Ok(Self { inner: Arc::new(NamespaceStoreInner { store, diff --git a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index 17fdfb3143..17ce655064 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -808,6 +808,9 @@ mod test { use crate::connection::config::DatabaseConfig; use crate::database::DatabaseKind; + use crate::namespace::configurator::{ + NamespaceConfigurators, PrimaryConfigurator, SchemaConfigurator, + }; use crate::namespace::meta_store::{metastore_connection_maker, MetaStore}; use crate::namespace::{NamespaceConfig, RestoreOption}; use crate::schema::SchedulerHandle; @@ -826,9 +829,16 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 
10, config, meta_store) - .await - .unwrap(); + let store = NamespaceStore::new( + false, + false, + 10, + config, + meta_store, + NamespaceConfigurators::default(), + ) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -936,9 +946,16 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) - .await - .unwrap(); + let store = NamespaceStore::new( + false, + false, + 10, + config, + meta_store, + NamespaceConfigurators::default(), + ) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -1012,7 +1029,7 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) + let store = NamespaceStore::new(false, false, 10, config, meta_store, NamespaceConfigurators::default()) .await .unwrap(); @@ -1039,7 +1056,10 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) + let configurators = NamespaceConfigurators::default() + .with_schema(SchemaConfigurator) + .with_primary(PrimaryConfigurator); + let store = NamespaceStore::new(false, false, 10, config, meta_store, configurators) .await .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) @@ -1112,7 +1132,7 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store) + let store = NamespaceStore::new(false, false, 10, config, meta_store, NamespaceConfigurators::default()) .await .unwrap(); let scheduler = Scheduler::new(store.clone(), maker().unwrap()) From fd03144bcc30e7b85f07abf4630ba9ea58a87d11 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 5 Aug 2024 15:26:34 +0200 Subject: [PATCH 08/70] decoupled namespace configurators --- libsql-server/src/error.rs | 2 +- libsql-server/src/lib.rs | 86 ++- .../src/namespace/{ => configurator}/fork.rs | 62 +- .../src/namespace/configurator/helpers.rs | 451 ++++++++++++++ .../src/namespace/configurator/mod.rs | 67 ++- .../src/namespace/configurator/primary.rs | 249 +++++--- .../src/namespace/configurator/replica.rs | 140 +++-- .../src/namespace/configurator/schema.rs | 71 ++- libsql-server/src/namespace/mod.rs | 551 +----------------- libsql-server/src/namespace/store.rs | 103 ++-- libsql-server/src/schema/scheduler.rs | 80 ++- 11 files changed, 1056 insertions(+), 806 deletions(-) rename libsql-server/src/namespace/{ => configurator}/fork.rs (77%) create mode 100644 libsql-server/src/namespace/configurator/helpers.rs diff --git a/libsql-server/src/error.rs b/libsql-server/src/error.rs index 371630abdf..9cd0b81485 100644 --- a/libsql-server/src/error.rs +++ b/libsql-server/src/error.rs @@ -4,7 +4,7 @@ use tonic::metadata::errors::InvalidMetadataValueBytes; use crate::{ auth::AuthError, - namespace::{ForkError, NamespaceName}, + namespace::{configurator::fork::ForkError, NamespaceName}, query_result_builder::QueryResultBuilderError, }; diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 3d816d6bc3..8bd3ea4fac 100644 --- a/libsql-server/src/lib.rs +++ 
b/libsql-server/src/lib.rs @@ -46,7 +46,7 @@ use libsql_wal::registry::WalRegistry; use libsql_wal::storage::NoStorage; use libsql_wal::wal::LibsqlWalManager; use namespace::meta_store::MetaStoreHandle; -use namespace::{NamespaceConfig, NamespaceName}; +use namespace::NamespaceName; use net::Connector; use once_cell::sync::Lazy; use rusqlite::ffi::SQLITE_CONFIG_MALLOC; @@ -60,7 +60,7 @@ use utils::services::idle_shutdown::IdleShutdownKicker; use self::config::MetaStoreConfig; use self::connection::connection_manager::InnerWalManager; -use self::namespace::configurator::NamespaceConfigurators; +use self::namespace::configurator::{BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, ReplicaConfigurator, SchemaConfigurator}; use self::namespace::NamespaceStore; use self::net::AddrIncoming; use self::replication::script_backup_manager::{CommandHandler, ScriptBackupManager}; @@ -425,11 +425,6 @@ where let user_auth_strategy = self.user_api_config.auth_strategy.clone(); let service_shutdown = Arc::new(Notify::new()); - let db_kind = if self.rpc_client_config.is_some() { - DatabaseKind::Replica - } else { - DatabaseKind::Primary - }; let scripted_backup = match self.db_config.snapshot_exec { Some(ref command) => { @@ -457,27 +452,6 @@ where // chose the wal backend let (make_wal_manager, registry_shutdown) = self.configure_wal_manager(&mut join_set)?; - let ns_config = NamespaceConfig { - db_kind, - base_path: self.path.clone(), - max_log_size: self.db_config.max_log_size, - max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), - bottomless_replication: self.db_config.bottomless_replication.clone(), - extensions, - stats_sender: stats_sender.clone(), - max_response_size: self.db_config.max_response_size, - max_total_response_size: self.db_config.max_total_response_size, - checkpoint_interval: self.db_config.checkpoint_interval, - encryption_config: self.db_config.encryption_config.clone(), - max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), - scripted_backup, - max_concurrent_requests: self.db_config.max_concurrent_requests, - channel: channel.clone(), - uri: uri.clone(), - migration_scheduler: scheduler_sender.into(), - make_wal_manager, - }; - let (metastore_conn_maker, meta_store_wal_manager) = metastore_connection_maker(self.meta_store_config.bottomless.clone(), &self.path) .await?; @@ -490,15 +464,67 @@ where ) .await?; - let configurators = NamespaceConfigurators::default(); + let base_config = BaseNamespaceConfig { + base_path: self.path.clone(), + extensions, + stats_sender, + max_response_size: self.db_config.max_response_size, + max_total_response_size: self.db_config.max_total_response_size, + max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), + max_concurrent_requests: self.db_config.max_concurrent_requests, + }; + + let mut configurators = NamespaceConfigurators::default(); + + let db_kind = match channel.clone().zip(uri.clone()) { + // replica mode + Some((channel, uri)) => { + let replica_configurator = ReplicaConfigurator::new( + base_config, + channel, + uri, + make_wal_manager, + ); + configurators.with_replica(replica_configurator); + DatabaseKind::Replica + } + // primary mode + None => { + let primary_config = PrimaryExtraConfig { + max_log_size: self.db_config.max_log_size, + max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), + bottomless_replication: self.db_config.bottomless_replication.clone(), + scripted_backup, 
+ checkpoint_interval: self.db_config.checkpoint_interval, + }; + + let primary_configurator = PrimaryConfigurator::new( + base_config.clone(), + primary_config.clone(), + make_wal_manager.clone(), + ); + + let schema_configurator = SchemaConfigurator::new( + base_config.clone(), + primary_config, + make_wal_manager.clone(), + scheduler_sender.into(), + ); + + configurators.with_schema(schema_configurator); + configurators.with_primary(primary_configurator); + + DatabaseKind::Primary + }, + }; let namespace_store: NamespaceStore = NamespaceStore::new( db_kind.is_replica(), self.db_config.snapshot_at_shutdown, self.max_active_namespaces, - ns_config, meta_store, configurators, + db_kind, ) .await?; diff --git a/libsql-server/src/namespace/fork.rs b/libsql-server/src/namespace/configurator/fork.rs similarity index 77% rename from libsql-server/src/namespace/fork.rs rename to libsql-server/src/namespace/configurator/fork.rs index f25bf7a9a9..26a0b99b61 100644 --- a/libsql-server/src/namespace/fork.rs +++ b/libsql-server/src/namespace/configurator/fork.rs @@ -12,15 +12,71 @@ use tokio::io::{AsyncSeekExt, AsyncWriteExt}; use tokio::time::Duration; use tokio_stream::StreamExt; +use crate::database::Database; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{Namespace, NamespaceBottomlessDbId}; use crate::replication::primary::frame_stream::FrameStream; use crate::replication::{LogReadError, ReplicationLogger}; use crate::{BLOCKING_RT, LIBSQL_PAGE_SIZE}; -use super::meta_store::MetaStoreHandle; -use super::{NamespaceName, NamespaceStore, RestoreOption}; +use super::helpers::make_bottomless_options; +use super::{NamespaceName, NamespaceStore, PrimaryExtraConfig, RestoreOption}; type Result = crate::Result; +pub(super) async fn fork( + from_ns: &Namespace, + from_config: MetaStoreHandle, + to_ns: NamespaceName, + to_config: MetaStoreHandle, + timestamp: Option, + store: NamespaceStore, + primary_config: &PrimaryExtraConfig, + base_path: Arc, +) -> crate::Result { + let from_config = from_config.get(); + let bottomless_db_id = NamespaceBottomlessDbId::from_config(&from_config); + let restore_to = if let Some(timestamp) = timestamp { + if let Some(ref options) = primary_config.bottomless_replication { + Some(PointInTimeRestore { + timestamp, + replicator_options: make_bottomless_options( + options, + bottomless_db_id.clone(), + from_ns.name().clone(), + ), + }) + } else { + return Err(crate::Error::Fork(ForkError::BackupServiceNotConfigured)); + } + } else { + None + }; + + let logger = match &from_ns.db { + Database::Primary(db) => db.wal_wrapper.wrapper().logger(), + Database::Schema(db) => db.wal_wrapper.wrapper().logger(), + _ => { + return Err(crate::Error::Fork(ForkError::Internal(anyhow::Error::msg( + "Invalid source database type for fork", + )))); + } + }; + + let fork_task = ForkTask { + base_path, + to_namespace: to_ns.clone(), + logger, + restore_to, + to_config, + store, + }; + + let ns = fork_task.fork().await?; + + Ok(ns) +} + #[derive(Debug, thiserror::Error)] pub enum ForkError { #[error("internal error: {0}")] @@ -58,7 +114,7 @@ pub struct ForkTask { pub to_namespace: NamespaceName, pub to_config: MetaStoreHandle, pub restore_to: Option, - pub store: NamespaceStore, + pub store: NamespaceStore } pub struct PointInTimeRestore { diff --git a/libsql-server/src/namespace/configurator/helpers.rs b/libsql-server/src/namespace/configurator/helpers.rs new file mode 100644 index 0000000000..f43fa8a192 --- /dev/null +++ 
b/libsql-server/src/namespace/configurator/helpers.rs
@@ -0,0 +1,451 @@
+use std::path::{Path, PathBuf};
+use std::sync::Weak;
+use std::sync::{atomic::AtomicBool, Arc};
+use std::time::Duration;
+
+use anyhow::Context as _;
+use bottomless::replicator::Options;
+use bytes::Bytes;
+use futures::Stream;
+use libsql_sys::wal::Sqlite3WalManager;
+use tokio::io::AsyncBufReadExt as _;
+use tokio::sync::watch;
+use tokio::task::JoinSet;
+use tokio_util::io::StreamReader;
+use enclose::enclose;
+
+use crate::connection::config::DatabaseConfig;
+use crate::connection::connection_manager::InnerWalManager;
+use crate::connection::libsql::{open_conn, MakeLibSqlConn};
+use crate::connection::{Connection as _, MakeConnection as _};
+use crate::error::LoadDumpError;
+use crate::replication::{FrameNo, ReplicationLogger};
+use crate::stats::Stats;
+use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName, ResolveNamespacePathFn, RestoreOption};
+use crate::namespace::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper};
+use crate::namespace::meta_store::MetaStoreHandle;
+use crate::namespace::broadcasters::BroadcasterHandle;
+use crate::database::{PrimaryConnection, PrimaryConnectionMaker};
+use crate::{StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT};
+
+use super::{BaseNamespaceConfig, PrimaryExtraConfig};
+
+const WASM_TABLE_CREATE: &str =
+    "CREATE TABLE libsql_wasm_func_table (name text PRIMARY KEY, body text) WITHOUT ROWID;";
+
+#[tracing::instrument(skip_all)]
+pub(super) async fn make_primary_connection_maker(
+    primary_config: &PrimaryExtraConfig,
+    base_config: &BaseNamespaceConfig,
+    meta_store_handle: &MetaStoreHandle,
+    db_path: &Path,
+    name: &NamespaceName,
+    restore_option: RestoreOption,
+    block_writes: Arc<AtomicBool>,
+    join_set: &mut JoinSet<anyhow::Result<()>>,
+    resolve_attach_path: ResolveNamespacePathFn,
+    broadcaster: BroadcasterHandle,
+    make_wal_manager: Arc<dyn Fn() -> InnerWalManager + Sync + Send + 'static>,
+) -> crate::Result<(PrimaryConnectionMaker, ReplicationWalWrapper, Arc<Stats>)> {
+    let db_config = meta_store_handle.get();
+    let bottomless_db_id = NamespaceBottomlessDbId::from_config(&db_config);
+    // FIXME: figure out how to do it per-db
+    let mut is_dirty = {
+        let sentinel_path = db_path.join(".sentinel");
+        if sentinel_path.try_exists()? {
+            true
+        } else {
+            tokio::fs::File::create(&sentinel_path).await?;
+            false
+        }
+    };
+
+    // FIXME: due to a bug in logger::checkpoint_db we call regular checkpointing code
+    // instead of our virtual WAL one. It's a bit tangled to fix right now, because
+    // we need WAL context for checkpointing, and WAL context needs the ReplicationLogger...
+    // So instead we checkpoint early, *before* bottomless gets initialized. That way
+    // we're sure bottomless won't try to back up any existing WAL frames and will instead
+    // treat the existing db file as the source of truth.
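+    //
+    // In short, the intended sequence below is (a sketch; all names are the
+    // helpers defined or called in this file):
+    //
+    //     checkpoint_db(&db_path.join("data"))?;    // flush the WAL into the main db file
+    //     init_bottomless_replicator(..).await?;    // only then hand the file to bottomless
+    //     is_dirty |= did_recover;                  // a restore marks the replication log dirty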
+
+    let bottomless_replicator = match primary_config.bottomless_replication {
+        Some(ref options) => {
+            tracing::debug!("Checkpointing before initializing bottomless");
+            crate::replication::primary::logger::checkpoint_db(&db_path.join("data"))?;
+            tracing::debug!("Checkpointed before initializing bottomless");
+            let options = make_bottomless_options(options, bottomless_db_id, name.clone());
+            let (replicator, did_recover) =
+                init_bottomless_replicator(db_path.join("data"), options, &restore_option)
+                    .await?;
+            tracing::debug!("Completed init of bottomless replicator");
+            is_dirty |= did_recover;
+            Some(replicator)
+        }
+        None => None,
+    };
+
+    tracing::debug!("Checking fresh db");
+    let is_fresh_db = check_fresh_db(&db_path)?;
+    // switch frame-count checkpoint to time-based one
+    let auto_checkpoint = if primary_config.checkpoint_interval.is_some() {
+        0
+    } else {
+        DEFAULT_AUTO_CHECKPOINT
+    };
+
+    let logger = Arc::new(ReplicationLogger::open(
+        &db_path,
+        primary_config.max_log_size,
+        primary_config.max_log_duration,
+        is_dirty,
+        auto_checkpoint,
+        primary_config.scripted_backup.clone(),
+        name.clone(),
+        None,
+    )?);
+
+    tracing::debug!("sending stats");
+
+    let stats = make_stats(
+        &db_path,
+        join_set,
+        meta_store_handle.clone(),
+        base_config.stats_sender.clone(),
+        name.clone(),
+        logger.new_frame_notifier.subscribe(),
+    )
+    .await?;
+
+    tracing::debug!("Making replication wal wrapper");
+    let wal_wrapper = make_replication_wal_wrapper(bottomless_replicator, logger.clone());
+
+    tracing::debug!("Opening libsql connection");
+
+    let connection_maker = MakeLibSqlConn::new(
+        db_path.to_path_buf(),
+        wal_wrapper.clone(),
+        stats.clone(),
+        broadcaster,
+        meta_store_handle.clone(),
+        base_config.extensions.clone(),
+        base_config.max_response_size,
+        base_config.max_total_response_size,
+        auto_checkpoint,
+        logger.new_frame_notifier.subscribe(),
+        None,
+        block_writes,
+        resolve_attach_path,
+        make_wal_manager.clone(),
+    )
+    .await?
+    .throttled(
+        base_config.max_concurrent_connections.clone(),
+        Some(DB_CREATE_TIMEOUT),
+        base_config.max_total_response_size,
+        base_config.max_concurrent_requests,
+    );
+
+    tracing::debug!("Completed opening libsql connection");
+
+    // This must happen after we create the connection maker: the connection maker holds on to a
+    // connection to ensure that no other connection is closing while we try to open the dump,
+    // as that would cause a SQLITE_LOCKED error.
+    match restore_option {
+        RestoreOption::Dump(_) if !is_fresh_db => {
+            Err(LoadDumpError::LoadDumpExistingDb)?;
+        }
+        RestoreOption::Dump(dump) => {
+            let conn = connection_maker.create().await?;
+            tracing::debug!("Loading dump");
+            load_dump(dump, conn).await?;
+            tracing::debug!("Done loading dump");
+        }
+        _ => { /* other cases were already handled when creating bottomless */ }
+    }
+
+    join_set.spawn(run_periodic_compactions(logger.clone()));
+
+    tracing::debug!("Done making primary connection");
+
+    Ok((connection_maker, wal_wrapper, stats))
+}
+
+pub(super) fn make_bottomless_options(
+    options: &Options,
+    namespace_db_id: NamespaceBottomlessDbId,
+    name: NamespaceName,
+) -> Options {
+    let mut options = options.clone();
+    let mut db_id = match namespace_db_id {
+        NamespaceBottomlessDbId::Namespace(id) => id,
+        // FIXME(marin): I don't like this; if bottomless is enabled, a proper config must be passed.
+        NamespaceBottomlessDbId::NotProvided => options.db_id.unwrap_or_default(),
+    };
+
+    db_id = format!("ns-{db_id}:{name}");
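+    // e.g. a db_id of "1234" for namespace "prod" yields the bottomless id "ns-1234:prod"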
+    options.db_id = Some(db_id);
+    options
+}
+
+async fn init_bottomless_replicator(
+    path: impl AsRef<Path>,
+    options: bottomless::replicator::Options,
+    restore_option: &RestoreOption,
+) -> anyhow::Result<(bottomless::replicator::Replicator, bool)> {
+    tracing::debug!("Initializing bottomless replication");
+    let path = path
+        .as_ref()
+        .to_str()
+        .ok_or_else(|| anyhow::anyhow!("Invalid db path"))?
+        .to_owned();
+    let mut replicator = bottomless::replicator::Replicator::with_options(path, options).await?;
+
+    let (generation, timestamp) = match restore_option {
+        RestoreOption::Latest | RestoreOption::Dump(_) => (None, None),
+        RestoreOption::Generation(generation) => (Some(*generation), None),
+        RestoreOption::PointInTime(timestamp) => (None, Some(*timestamp)),
+    };
+
+    let (action, did_recover) = replicator.restore(generation, timestamp).await?;
+    match action {
+        bottomless::replicator::RestoreAction::SnapshotMainDbFile => {
+            replicator.new_generation().await;
+            if let Some(_handle) = replicator.snapshot_main_db_file(true).await? {
+                tracing::trace!("got snapshot handle after restore with generation upgrade");
+            }
+            // Restoration process only leaves the local WAL file if it was
+            // detected to be newer than its remote counterpart.
+            replicator.maybe_replicate_wal().await?
+        }
+        bottomless::replicator::RestoreAction::ReuseGeneration(gen) => {
+            replicator.set_generation(gen);
+        }
+    }
+
+    Ok((replicator, did_recover))
+}
+
+async fn run_periodic_compactions(logger: Arc<ReplicationLogger>) -> anyhow::Result<()> {
+    // calling `ReplicationLogger::maybe_compact()` is cheap if the compaction does not actually
+    // take place, so we can afford to poll it very often for simplicity
+    let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(1000));
+    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
+
+    loop {
+        interval.tick().await;
+        let handle = BLOCKING_RT.spawn_blocking(enclose! {(logger) move || {
+            logger.maybe_compact()
+        }});
+        handle
+            .await
+            .expect("Compaction task crashed")
+            .context("Compaction failed")?;
+    }
+}
+
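+// For illustration, the dump is expected to have the usual `sqlite3 .dump`
+// shape, i.e. all statements wrapped in a single transaction:
+//
+//     BEGIN TRANSACTION;
+//     CREATE TABLE t(x);
+//     INSERT INTO t VALUES (1);
+//     COMMIT;
+//
+// The `NoTxn` and `NoCommit` errors below enforce exactly that shape.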
+async fn load_dump<S>(dump: S, conn: PrimaryConnection) -> crate::Result<(), LoadDumpError>
+where
+    S: Stream<Item = std::io::Result<Bytes>> + Unpin,
+{
+    let mut reader = tokio::io::BufReader::new(StreamReader::new(dump));
+    let mut curr = String::new();
+    let mut line = String::new();
+    let mut skipped_wasm_table = false;
+    let mut n_stmt = 0;
+    let mut line_id = 0;
+
+    while let Ok(n) = reader.read_line(&mut curr).await {
+        line_id += 1;
+        if n == 0 {
+            break;
+        }
+        let trimmed = curr.trim();
+        if trimmed.is_empty() || trimmed.starts_with("--") {
+            curr.clear();
+            continue;
+        }
+        // FIXME: it's a well-known bug that a comment ending with a semicolon is handled incorrectly by the current dump processing code
+        let statement_end = trimmed.ends_with(';');
+
+        // we want to concat the original (non-trimmed) lines, as trimming would join them all
+        // into one single-line statement, which is incorrect if trailing comments are present
+        line.push_str(&curr);
+        curr.clear();
+
+        // This is a hack to ignore the libsql_wasm_func_table table because it is already created
+        // by the system.
+        if !skipped_wasm_table && line.trim() == WASM_TABLE_CREATE {
+            skipped_wasm_table = true;
+            line.clear();
+            continue;
+        }
+
+        if statement_end {
+            n_stmt += 1;
+            // the dump must be performed within a txn
+            if n_stmt > 2 && conn.is_autocommit().await.unwrap() {
+                return Err(LoadDumpError::NoTxn);
+            }
+
+            line = tokio::task::spawn_blocking({
+                let conn = conn.clone();
+                move || -> crate::Result<String, LoadDumpError> {
+                    conn.with_raw(|conn| conn.execute(&line, ())).map_err(|e| {
+                        LoadDumpError::Internal(format!("line: {}, error: {}", line_id, e))
+                    })?;
+                    Ok(line)
+                }
+            })
+            .await??;
+            line.clear();
+        } else {
+            line.push(' ');
+        }
+    }
+    tracing::debug!("loaded {} lines from dump", line_id);
+
+    if !conn.is_autocommit().await.unwrap() {
+        tokio::task::spawn_blocking({
+            let conn = conn.clone();
+            move || -> crate::Result<(), LoadDumpError> {
+                conn.with_raw(|conn| conn.execute("rollback", ()))?;
+                Ok(())
+            }
+        })
+        .await??;
+        return Err(LoadDumpError::NoCommit);
+    }
+
+    Ok(())
+}
+
+fn check_fresh_db(path: &Path) -> crate::Result<bool> {
+    let is_fresh = !path.join("wallog").try_exists()?;
+    Ok(is_fresh)
+}
+
+pub(super) async fn make_stats(
+    db_path: &Path,
+    join_set: &mut JoinSet<anyhow::Result<()>>,
+    meta_store_handle: MetaStoreHandle,
+    stats_sender: StatsSender,
+    name: NamespaceName,
+    mut current_frame_no: watch::Receiver<Option<FrameNo>>,
+) -> anyhow::Result<Arc<Stats>> {
+    tracing::debug!("creating stats type");
+    let stats = Stats::new(name.clone(), db_path, join_set).await?;
+
+    // the storage monitor is optional, so we ignore the error here.
+    tracing::debug!("stats created, sending stats");
+    let _ = stats_sender
+        .send((name.clone(), meta_store_handle, Arc::downgrade(&stats)))
+        .await;
+
+    join_set.spawn({
+        let stats = stats.clone();
+        // initialize the current_frame_no value
+        current_frame_no
+            .borrow_and_update()
+            .map(|fno| stats.set_current_frame_no(fno));
+        async move {
+            while current_frame_no.changed().await.is_ok() {
+                current_frame_no
+                    .borrow_and_update()
+                    .map(|fno| stats.set_current_frame_no(fno));
+            }
+            Ok(())
+        }
+    });
+
+    join_set.spawn(run_storage_monitor(
+        db_path.into(),
+        Arc::downgrade(&stats),
+    ));
+
+    tracing::debug!("done sending stats, and creating bg tasks");
+
+    Ok(stats)
+}
+
+// Periodically check the storage used by the database and save it in the Stats structure.
+// TODO: Once we have a separate fiber that does WAL checkpoints, running this routine
+// right after checkpointing is exactly where it should be done.
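+// Note: the computation below assumes SQLite's default 4 KiB page size; a
+// database created with a different `page_size` pragma would be misestimated.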
+async fn run_storage_monitor(
+    db_path: PathBuf,
+    stats: Weak<Stats>,
+) -> anyhow::Result<()> {
+    // on initialization, the database file doesn't exist yet, so we wait a bit for it to be
+    // created
+    tokio::time::sleep(Duration::from_secs(1)).await;
+
+    let duration = tokio::time::Duration::from_secs(60);
+    let db_path: Arc<Path> = db_path.into();
+    loop {
+        let db_path = db_path.clone();
+        let Some(stats) = stats.upgrade() else {
+            return Ok(());
+        };
+
+        let _ = tokio::task::spawn_blocking(move || {
+            // because closing the last connection interferes with opening a new one, we lazily
+            // initialize a connection here, and keep it alive for the entirety of the program. If we
+            // fail to open it, we wait for `duration` and try again later.
+            match open_conn(&db_path, Sqlite3WalManager::new(), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), None) {
+                Ok(mut conn) => {
+                    if let Ok(tx) = conn.transaction() {
+                        let page_count = tx.query_row("pragma page_count;", [], |row| { row.get::<usize, u64>(0) });
+                        let freelist_count = tx.query_row("pragma freelist_count;", [], |row| { row.get::<usize, u64>(0) });
+                        if let (Ok(page_count), Ok(freelist_count)) = (page_count, freelist_count) {
+                            let storage_bytes_used = (page_count - freelist_count) * 4096;
+                            stats.set_storage_bytes_used(storage_bytes_used);
+                        }
+                    }
+                },
+                Err(e) => {
+                    tracing::warn!("failed to open connection for storage monitor: {e}, trying again in {duration:?}");
+                },
+            }
+        }).await;
+
+        tokio::time::sleep(duration).await;
+    }
+}
+
+pub(super) async fn cleanup_primary(
+    base: &BaseNamespaceConfig,
+    primary_config: &PrimaryExtraConfig,
+    namespace: &NamespaceName,
+    db_config: &DatabaseConfig,
+    prune_all: bool,
+    bottomless_db_id_init: NamespaceBottomlessDbIdInit,
+) -> crate::Result<()> {
+    let ns_path = base.base_path.join("dbs").join(namespace.as_str());
+    if let Some(ref options) = primary_config.bottomless_replication {
+        let bottomless_db_id = match bottomless_db_id_init {
+            NamespaceBottomlessDbIdInit::Provided(db_id) => db_id,
+            NamespaceBottomlessDbIdInit::FetchFromConfig => {
+                NamespaceBottomlessDbId::from_config(db_config)
+            }
+        };
+        let options = make_bottomless_options(options, bottomless_db_id, namespace.clone());
+        let replicator = bottomless::replicator::Replicator::with_options(
+            ns_path.join("data").to_str().unwrap(),
+            options,
+        )
+        .await?;
+        if prune_all {
+            let delete_all = replicator.delete_all(None).await?;
+            // perform hard deletion in the background
+            tokio::spawn(delete_all.commit());
+        } else {
+            // for soft delete make sure that local db is fully backed up
+            replicator.savepoint().confirmed().await?;
+        }
+    }
+
+    if ns_path.try_exists()? 
{ + tracing::debug!("removing database directory: {}", ns_path.display()); + tokio::fs::remove_dir_all(ns_path).await?; + } + + Ok(()) +} diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index a240c3e410..e5db335ff6 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -1,22 +1,51 @@ +use std::path::{Path, PathBuf}; use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use chrono::NaiveDateTime; use futures::Future; +use tokio::sync::Semaphore; + +use crate::connection::config::DatabaseConfig; +use crate::replication::script_backup_manager::ScriptBackupManager; +use crate::StatsSender; use super::broadcasters::BroadcasterHandle; use super::meta_store::MetaStoreHandle; -use super::{ - NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption, -}; +use super::{Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; +mod helpers; mod primary; mod replica; mod schema; +pub mod fork; pub use primary::PrimaryConfigurator; pub use replica::ReplicaConfigurator; pub use schema::SchemaConfigurator; -type DynConfigurator = dyn ConfigureNamespace + Send + Sync + 'static; +#[derive(Clone, Debug)] +pub struct BaseNamespaceConfig { + pub(crate) base_path: Arc, + pub(crate) extensions: Arc<[PathBuf]>, + pub(crate) stats_sender: StatsSender, + pub(crate) max_response_size: u64, + pub(crate) max_total_response_size: u64, + pub(crate) max_concurrent_connections: Arc, + pub(crate) max_concurrent_requests: u64, +} + +#[derive(Clone)] +pub struct PrimaryExtraConfig { + pub(crate) max_log_size: u64, + pub(crate) max_log_duration: Option, + pub(crate) bottomless_replication: Option, + pub(crate) scripted_backup: Option, + pub(crate) checkpoint_interval: Option, +} + +pub type DynConfigurator = dyn ConfigureNamespace + Send + Sync + 'static; pub(crate) struct NamespaceConfigurators { replica_configurator: Option>, @@ -27,9 +56,6 @@ pub(crate) struct NamespaceConfigurators { impl Default for NamespaceConfigurators { fn default() -> Self { Self::empty() - .with_primary(PrimaryConfigurator) - .with_replica(ReplicaConfigurator) - .with_schema(SchemaConfigurator) } } @@ -42,17 +68,17 @@ impl NamespaceConfigurators { } } - pub fn with_primary(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { + pub fn with_primary(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { self.primary_configurator = Some(Box::new(c)); self } - pub fn with_replica(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { + pub fn with_replica(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { self.replica_configurator = Some(Box::new(c)); self } - pub fn with_schema(mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> Self { + pub fn with_schema(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { self.schema_configurator = Some(Box::new(c)); self } @@ -73,7 +99,6 @@ impl NamespaceConfigurators { pub trait ConfigureNamespace { fn setup<'a>( &'a self, - ns_config: &'a NamespaceConfig, db_config: MetaStoreHandle, restore_option: RestoreOption, name: &'a NamespaceName, @@ -81,5 +106,23 @@ pub trait ConfigureNamespace { resolve_attach_path: ResolveNamespacePathFn, store: NamespaceStore, broadcaster: BroadcasterHandle, - ) -> Pin> + Send + 'a>>; + ) -> Pin> + Send + 'a>>; + + fn cleanup<'a>( 
+ &'a self, + namespace: &'a NamespaceName, + db_config: &'a DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> Pin> + Send + 'a>>; + + fn fork<'a>( + &'a self, + from_ns: &'a Namespace, + from_config: MetaStoreHandle, + to_ns: NamespaceName, + to_config: MetaStoreHandle, + timestamp: Option, + store: NamespaceStore, + ) -> Pin> + Send + 'a>>; } diff --git a/libsql-server/src/namespace/configurator/primary.rs b/libsql-server/src/namespace/configurator/primary.rs index f28d288a97..4351f6a3ac 100644 --- a/libsql-server/src/namespace/configurator/primary.rs +++ b/libsql-server/src/namespace/configurator/primary.rs @@ -4,22 +4,117 @@ use std::{path::Path, pin::Pin, sync::Arc}; use futures::prelude::Future; use tokio::task::JoinSet; +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; use crate::connection::MakeConnection; use crate::database::{Database, PrimaryDatabase}; -use crate::namespace::{Namespace, NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; -use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::configurator::helpers::make_primary_connection_maker; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{ + Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, + ResetCb, ResolveNamespacePathFn, RestoreOption, +}; use crate::run_periodic_checkpoint; use crate::schema::{has_pending_migration_task, setup_migration_table}; -use super::ConfigureNamespace; +use super::helpers::cleanup_primary; +use super::{BaseNamespaceConfig, ConfigureNamespace, PrimaryExtraConfig}; + +pub struct PrimaryConfigurator { + base: BaseNamespaceConfig, + primary_config: PrimaryExtraConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, +} + +impl PrimaryConfigurator { + pub fn new( + base: BaseNamespaceConfig, + primary_config: PrimaryExtraConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + ) -> Self { + Self { + base, + primary_config, + make_wal_manager, + } + } + + #[tracing::instrument(skip_all, fields(namespace))] + async fn try_new_primary( + &self, + namespace: NamespaceName, + meta_store_handle: MetaStoreHandle, + restore_option: RestoreOption, + resolve_attach_path: ResolveNamespacePathFn, + db_path: Arc, + broadcaster: BroadcasterHandle, + ) -> crate::Result { + let mut join_set = JoinSet::new(); + + tokio::fs::create_dir_all(&db_path).await?; + + let block_writes = Arc::new(AtomicBool::new(false)); + let (connection_maker, wal_wrapper, stats) = make_primary_connection_maker( + &self.primary_config, + &self.base, + &meta_store_handle, + &db_path, + &namespace, + restore_option, + block_writes.clone(), + &mut join_set, + resolve_attach_path, + broadcaster, + self.make_wal_manager.clone(), + ) + .await?; + let connection_maker = Arc::new(connection_maker); + + if meta_store_handle.get().shared_schema_name.is_some() { + let block_writes = block_writes.clone(); + let conn = connection_maker.create().await?; + tokio::task::spawn_blocking(move || { + conn.with_raw(|conn| -> crate::Result<()> { + setup_migration_table(conn)?; + if has_pending_migration_task(conn)? 
{
+                        block_writes.store(true, Ordering::SeqCst);
+                    }
+                    Ok(())
+                })
+            })
+            .await
+            .unwrap()?;
+        }
+
+        if let Some(checkpoint_interval) = self.primary_config.checkpoint_interval {
+            join_set.spawn(run_periodic_checkpoint(
+                connection_maker.clone(),
+                checkpoint_interval,
+                namespace.clone(),
+            ));
+        }
+
+        tracing::debug!("Done making new primary");
-pub struct PrimaryConfigurator;
+        Ok(Namespace {
+            tasks: join_set,
+            db: Database::Primary(PrimaryDatabase {
+                wal_wrapper,
+                connection_maker,
+                block_writes,
+            }),
+            name: namespace,
+            stats,
+            db_config_store: meta_store_handle,
+            path: db_path.into(),
+        })
+    }
+}

 impl ConfigureNamespace for PrimaryConfigurator {
     fn setup<'a>(
         &'a self,
-        config: &'a NamespaceConfig,
         meta_store_handle: MetaStoreHandle,
         restore_option: RestoreOption,
         name: &'a NamespaceName,
@@ -27,102 +122,74 @@ impl ConfigureNamespace for PrimaryConfigurator {
         resolve_attach_path: ResolveNamespacePathFn,
         _store: NamespaceStore,
         broadcaster: BroadcasterHandle,
-    ) -> Pin<Box<dyn Future<Output = crate::Result<Namespace>> + Send + 'a>>
-    {
+    ) -> Pin<Box<dyn Future<Output = crate::Result<Namespace>> + Send + 'a>> {
         Box::pin(async move {
-            let db_path: Arc<Path> = config.base_path.join("dbs").join(name.as_str()).into();
+            let db_path: Arc<Path> = self.base.base_path.join("dbs").join(name.as_str()).into();
             let fresh_namespace = !db_path.try_exists()?;
             // FIXME: make that truly atomic. explore the idea of using temp directories, and it's implications
-            match try_new_primary(
-                config,
-                name.clone(),
-                meta_store_handle,
-                restore_option,
-                resolve_attach_path,
-                db_path.clone(),
-                broadcaster,
-            )
+            match self
+                .try_new_primary(
+                    name.clone(),
+                    meta_store_handle,
+                    restore_option,
+                    resolve_attach_path,
+                    db_path.clone(),
+                    broadcaster,
+                )
                 .await
-            {
-                Ok(this) => Ok(this),
-                Err(e) if fresh_namespace => {
-                    tracing::error!("an error occured while deleting creating namespace, cleaning...");
-                    if let Err(e) = tokio::fs::remove_dir_all(&db_path).await {
-                        tracing::error!("failed to remove dirty namespace directory: {e}")
-                    }
-                    Err(e)
+            {
+                Ok(this) => Ok(this),
+                Err(e) if fresh_namespace => {
+                    tracing::error!(
+                        "an error occurred while creating the namespace, cleaning up..."
+                    );
+                    if let Err(e) = tokio::fs::remove_dir_all(&db_path).await {
+                        tracing::error!("failed to remove dirty namespace directory: {e}")
                     }
-                }
-                Err(e) => Err(e),
+                    Err(e)
                 }
+                Err(e) => Err(e),
+            }
         })
     }
-}
-#[tracing::instrument(skip_all, fields(namespace))]
-async fn try_new_primary(
-    ns_config: &NamespaceConfig,
-    namespace: NamespaceName,
-    meta_store_handle: MetaStoreHandle,
-    restore_option: RestoreOption,
-    resolve_attach_path: ResolveNamespacePathFn,
-    db_path: Arc<Path>,
-    broadcaster: BroadcasterHandle,
-) -> crate::Result<Namespace> {
-    let mut join_set = JoinSet::new();
-
-    tokio::fs::create_dir_all(&db_path).await?;
-
-    let block_writes = Arc::new(AtomicBool::new(false));
-    let (connection_maker, wal_wrapper, stats) = Namespace::make_primary_connection_maker(
-        ns_config,
-        &meta_store_handle,
-        &db_path,
-        &namespace,
-        restore_option,
-        block_writes.clone(),
-        &mut join_set,
-        resolve_attach_path,
-        broadcaster,
-    )
-    .await?;
-    let connection_maker = Arc::new(connection_maker);
-
-    if meta_store_handle.get().shared_schema_name.is_some() {
-        let block_writes = block_writes.clone();
-        let conn = connection_maker.create().await?;
-        tokio::task::spawn_blocking(move || {
-            conn.with_raw(|conn| -> crate::Result<()> {
-                setup_migration_table(conn)?;
-                if has_pending_migration_task(conn)? 
{ - block_writes.store(true, Ordering::SeqCst); - } - Ok(()) - }) + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + db_config: &'a DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + cleanup_primary( + &self.base, + &self.primary_config, + namespace, + db_config, + prune_all, + bottomless_db_id_init, + ).await }) - .await - .unwrap()?; - } - - if let Some(checkpoint_interval) = ns_config.checkpoint_interval { - join_set.spawn(run_periodic_checkpoint( - connection_maker.clone(), - checkpoint_interval, - namespace.clone(), - )); } - tracing::debug!("Done making new primary"); - - Ok(Namespace { - tasks: join_set, - db: Database::Primary(PrimaryDatabase { - wal_wrapper, - connection_maker, - block_writes, - }), - name: namespace, - stats, - db_config_store: meta_store_handle, - path: db_path.into(), - }) + fn fork<'a>( + &'a self, + from_ns: &'a Namespace, + from_config: MetaStoreHandle, + to_ns: NamespaceName, + to_config: MetaStoreHandle, + timestamp: Option, + store: NamespaceStore, + ) -> Pin> + Send + 'a>> { + Box::pin(super::fork::fork( + from_ns, + from_config, + to_ns, + to_config, + timestamp, + store, + &self.primary_config, + self.base.base_path.clone())) + } } + diff --git a/libsql-server/src/namespace/configurator/replica.rs b/libsql-server/src/namespace/configurator/replica.rs index 4d3ca1dadf..61dd48b0bf 100644 --- a/libsql-server/src/namespace/configurator/replica.rs +++ b/libsql-server/src/namespace/configurator/replica.rs @@ -2,29 +2,51 @@ use std::pin::Pin; use std::sync::Arc; use futures::Future; +use hyper::Uri; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use tokio::task::JoinSet; +use tonic::transport::Channel; +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; use crate::connection::write_proxy::MakeWriteProxyConn; use crate::connection::MakeConnection; use crate::database::{Database, ReplicaDatabase}; use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::configurator::helpers::make_stats; use crate::namespace::meta_store::MetaStoreHandle; -use crate::namespace::{Namespace, RestoreOption}; -use crate::namespace::{ - make_stats, NamespaceConfig, NamespaceName, NamespaceStore, ResetCb, ResetOp, - ResolveNamespacePathFn, -}; +use crate::namespace::{Namespace, NamespaceBottomlessDbIdInit, RestoreOption}; +use crate::namespace::{NamespaceName, NamespaceStore, ResetCb, ResetOp, ResolveNamespacePathFn}; use crate::{DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT}; -use super::ConfigureNamespace; +use super::{BaseNamespaceConfig, ConfigureNamespace}; -pub struct ReplicaConfigurator; +pub struct ReplicaConfigurator { + base: BaseNamespaceConfig, + channel: Channel, + uri: Uri, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, +} + +impl ReplicaConfigurator { + pub fn new( + base: BaseNamespaceConfig, + channel: Channel, + uri: Uri, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + ) -> Self { + Self { + base, + channel, + uri, + make_wal_manager, + } + } +} impl ConfigureNamespace for ReplicaConfigurator { fn setup<'a>( &'a self, - config: &'a NamespaceConfig, meta_store_handle: MetaStoreHandle, restore_option: RestoreOption, name: &'a NamespaceName, @@ -32,13 +54,12 @@ impl ConfigureNamespace for ReplicaConfigurator { resolve_attach_path: ResolveNamespacePathFn, store: NamespaceStore, broadcaster: BroadcasterHandle, - ) -> 
Pin> + Send + 'a>> - { + ) -> Pin> + Send + 'a>> { Box::pin(async move { tracing::debug!("creating replica namespace"); - let db_path = config.base_path.join("dbs").join(name.as_str()); - let channel = config.channel.clone().expect("bad replica config"); - let uri = config.uri.clone().expect("bad replica config"); + let db_path = self.base.base_path.join("dbs").join(name.as_str()); + let channel = self.channel.clone(); + let uri = self.uri.clone(); let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); let client = crate::replication::replicator_client::Client::new( @@ -48,45 +69,46 @@ impl ConfigureNamespace for ReplicaConfigurator { meta_store_handle.clone(), store.clone(), ) - .await?; + .await?; let applied_frame_no_receiver = client.current_frame_no_notifier.subscribe(); let mut replicator = libsql_replication::replicator::Replicator::new( client, db_path.join("data"), DEFAULT_AUTO_CHECKPOINT, - config.encryption_config.clone(), + None, ) - .await?; + .await?; tracing::debug!("try perform handshake"); // force a handshake now, to retrieve the primary's current replication index match replicator.try_perform_handshake().await { Err(libsql_replication::replicator::Error::Meta( - libsql_replication::meta::Error::LogIncompatible, + libsql_replication::meta::Error::LogIncompatible, )) => { tracing::error!( "trying to replicate incompatible logs, reseting replica and nuking db dir" ); std::fs::remove_dir_all(&db_path).unwrap(); - return self.setup( - config, - meta_store_handle, - restore_option, - name, - reset, - resolve_attach_path, - store, - broadcaster, - ) + return self + .setup( + meta_store_handle, + restore_option, + name, + reset, + resolve_attach_path, + store, + broadcaster, + ) .await; - } + } Err(e) => Err(e)?, Ok(_) => (), } tracing::debug!("done performing handshake"); - let primary_current_replicatio_index = replicator.client_mut().primary_replication_index; + let primary_current_replicatio_index = + replicator.client_mut().primary_replication_index; let mut join_set = JoinSet::new(); let namespace = name.clone(); @@ -144,36 +166,35 @@ impl ConfigureNamespace for ReplicaConfigurator { &db_path, &mut join_set, meta_store_handle.clone(), - config.stats_sender.clone(), + self.base.stats_sender.clone(), name.clone(), applied_frame_no_receiver.clone(), - config.encryption_config.clone(), ) - .await?; + .await?; let connection_maker = MakeWriteProxyConn::new( db_path.clone(), - config.extensions.clone(), + self.base.extensions.clone(), channel.clone(), uri.clone(), stats.clone(), broadcaster, meta_store_handle.clone(), applied_frame_no_receiver, - config.max_response_size, - config.max_total_response_size, + self.base.max_response_size, + self.base.max_total_response_size, primary_current_replicatio_index, - config.encryption_config.clone(), + None, resolve_attach_path, - config.make_wal_manager.clone(), + self.make_wal_manager.clone(), ) - .await? - .throttled( - config.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - config.max_total_response_size, - config.max_concurrent_requests, - ); + .await? 
+ .throttled( + self.base.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + self.base.max_total_response_size, + self.base.max_concurrent_requests, + ); Ok(Namespace { tasks: join_set, @@ -187,4 +208,35 @@ impl ConfigureNamespace for ReplicaConfigurator { }) }) } + + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + _db_config: &DatabaseConfig, + _prune_all: bool, + _bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let ns_path = self.base.base_path.join("dbs").join(namespace.as_str()); + if ns_path.try_exists()? { + tracing::debug!("removing database directory: {}", ns_path.display()); + tokio::fs::remove_dir_all(ns_path).await?; + } + Ok(()) + }) + } + + fn fork<'a>( + &'a self, + _from_ns: &'a Namespace, + _from_config: MetaStoreHandle, + _to_ns: NamespaceName, + _to_config: MetaStoreHandle, + _timestamp: Option, + _store: NamespaceStore, + ) -> Pin> + Send + 'a>> { + Box::pin(std::future::ready(Err(crate::Error::Fork( + super::fork::ForkError::ForkReplica, + )))) + } } diff --git a/libsql-server/src/namespace/configurator/schema.rs b/libsql-server/src/namespace/configurator/schema.rs index 864b75239f..e55c706fec 100644 --- a/libsql-server/src/namespace/configurator/schema.rs +++ b/libsql-server/src/namespace/configurator/schema.rs @@ -3,22 +3,36 @@ use std::sync::{atomic::AtomicBool, Arc}; use futures::prelude::Future; use tokio::task::JoinSet; +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; use crate::database::{Database, SchemaDatabase}; use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::{ - Namespace, NamespaceConfig, NamespaceName, NamespaceStore, + Namespace, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption, }; use crate::namespace::broadcasters::BroadcasterHandle; +use crate::schema::SchedulerHandle; -use super::ConfigureNamespace; +use super::helpers::{cleanup_primary, make_primary_connection_maker}; +use super::{BaseNamespaceConfig, ConfigureNamespace, PrimaryExtraConfig}; -pub struct SchemaConfigurator; +pub struct SchemaConfigurator { + base: BaseNamespaceConfig, + primary_config: PrimaryExtraConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + migration_scheduler: SchedulerHandle, +} + +impl SchemaConfigurator { + pub fn new(base: BaseNamespaceConfig, primary_config: PrimaryExtraConfig, make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, migration_scheduler: SchedulerHandle) -> Self { + Self { base, primary_config, make_wal_manager, migration_scheduler } + } +} impl ConfigureNamespace for SchemaConfigurator { fn setup<'a>( &'a self, - ns_config: &'a NamespaceConfig, db_config: MetaStoreHandle, restore_option: RestoreOption, name: &'a NamespaceName, @@ -29,12 +43,13 @@ impl ConfigureNamespace for SchemaConfigurator { ) -> std::pin::Pin> + Send + 'a>> { Box::pin(async move { let mut join_set = JoinSet::new(); - let db_path = ns_config.base_path.join("dbs").join(name.as_str()); + let db_path = self.base.base_path.join("dbs").join(name.as_str()); tokio::fs::create_dir_all(&db_path).await?; - let (connection_maker, wal_manager, stats) = Namespace::make_primary_connection_maker( - ns_config, + let (connection_maker, wal_manager, stats) = make_primary_connection_maker( + &self.primary_config, + &self.base, &db_config, &db_path, &name, @@ -43,12 +58,13 @@ impl ConfigureNamespace for SchemaConfigurator { &mut join_set, resolve_attach_path, broadcaster, + 
self.make_wal_manager.clone() ) .await?; Ok(Namespace { db: Database::Schema(SchemaDatabase::new( - ns_config.migration_scheduler.clone(), + self.migration_scheduler.clone(), name.clone(), connection_maker, wal_manager, @@ -62,4 +78,43 @@ impl ConfigureNamespace for SchemaConfigurator { }) }) } + + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + db_config: &'a DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: crate::namespace::NamespaceBottomlessDbIdInit, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + cleanup_primary( + &self.base, + &self.primary_config, + namespace, + db_config, + prune_all, + bottomless_db_id_init, + ).await + }) + } + + fn fork<'a>( + &'a self, + from_ns: &'a Namespace, + from_config: MetaStoreHandle, + to_ns: NamespaceName, + to_config: MetaStoreHandle, + timestamp: Option, + store: NamespaceStore, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(super::fork::fork( + from_ns, + from_config, + to_ns, + to_config, + timestamp, + store, + &self.primary_config, + self.base.base_path.clone())) + } } diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 5ccda74c54..7cfa6b351c 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -1,52 +1,24 @@ -use std::path::{Path, PathBuf}; -use std::sync::atomic::AtomicBool; -use std::sync::{Arc, Weak}; +use std::path::Path; +use std::sync::Arc; -use anyhow::{Context as _, Error}; -use bottomless::replicator::Options; -use broadcasters::BroadcasterHandle; +use anyhow::Context as _; use bytes::Bytes; use chrono::NaiveDateTime; -use enclose::enclose; use futures_core::{Future, Stream}; -use hyper::Uri; -use libsql_sys::wal::Sqlite3WalManager; -use libsql_sys::EncryptionConfig; -use tokio::io::AsyncBufReadExt; -use tokio::sync::{watch, Semaphore}; use tokio::task::JoinSet; -use tokio::time::Duration; -use tokio_util::io::StreamReader; -use tonic::transport::Channel; use uuid::Uuid; use crate::auth::parse_jwt_keys; use crate::connection::config::DatabaseConfig; -use crate::connection::connection_manager::InnerWalManager; -use crate::connection::libsql::{open_conn, MakeLibSqlConn}; -use crate::connection::{Connection as _, MakeConnection}; -use crate::database::{ - Database, DatabaseKind, PrimaryConnection, PrimaryConnectionMaker, -}; -use crate::error::LoadDumpError; -use crate::replication::script_backup_manager::ScriptBackupManager; -use crate::replication::{FrameNo, ReplicationLogger}; -use crate::schema::SchedulerHandle; +use crate::connection::Connection as _; +use crate::database::Database; use crate::stats::Stats; -use crate::{ - StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT, -}; -pub use fork::ForkError; - -use self::fork::{ForkTask, PointInTimeRestore}; use self::meta_store::MetaStoreHandle; pub use self::name::NamespaceName; -use self::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; pub use self::store::NamespaceStore; pub mod broadcasters; -mod fork; pub mod meta_store; mod name; pub mod replication_wal; @@ -101,51 +73,6 @@ impl Namespace { &self.name } - /// completely remove resources associated with the namespace - pub(crate) async fn cleanup( - ns_config: &NamespaceConfig, - name: &NamespaceName, - db_config: &DatabaseConfig, - prune_all: bool, - bottomless_db_id_init: NamespaceBottomlessDbIdInit, - ) -> crate::Result<()> { - let ns_path = ns_config.base_path.join("dbs").join(name.as_str()); - match ns_config.db_kind { - DatabaseKind::Primary => { - if let Some(ref 
options) = ns_config.bottomless_replication { - let bottomless_db_id = match bottomless_db_id_init { - NamespaceBottomlessDbIdInit::Provided(db_id) => db_id, - NamespaceBottomlessDbIdInit::FetchFromConfig => { - NamespaceBottomlessDbId::from_config(&db_config) - } - }; - let options = make_bottomless_options(options, bottomless_db_id, name.clone()); - let replicator = bottomless::replicator::Replicator::with_options( - ns_path.join("data").to_str().unwrap(), - options, - ) - .await?; - if prune_all { - let delete_all = replicator.delete_all(None).await?; - // perform hard deletion in the background - tokio::spawn(delete_all.commit()); - } else { - // for soft delete make sure that local db is fully backed up - replicator.savepoint().confirmed().await?; - } - } - } - DatabaseKind::Replica => (), - } - - if ns_path.try_exists()? { - tracing::debug!("removing database directory: {}", ns_path.display()); - tokio::fs::remove_dir_all(ns_path).await?; - } - - Ok(()) - } - async fn destroy(mut self) -> anyhow::Result<()> { self.tasks.shutdown().await; self.db.destroy(); @@ -195,293 +122,11 @@ impl Namespace { pub fn config_changed(&self) -> impl Future { self.db_config_store.changed() } - - #[tracing::instrument(skip_all)] - async fn make_primary_connection_maker( - ns_config: &NamespaceConfig, - meta_store_handle: &MetaStoreHandle, - db_path: &Path, - name: &NamespaceName, - restore_option: RestoreOption, - block_writes: Arc, - join_set: &mut JoinSet>, - resolve_attach_path: ResolveNamespacePathFn, - broadcaster: BroadcasterHandle, - ) -> crate::Result<(PrimaryConnectionMaker, ReplicationWalWrapper, Arc)> { - let db_config = meta_store_handle.get(); - let bottomless_db_id = NamespaceBottomlessDbId::from_config(&db_config); - // FIXME: figure how to to it per-db - let mut is_dirty = { - let sentinel_path = db_path.join(".sentinel"); - if sentinel_path.try_exists()? { - true - } else { - tokio::fs::File::create(&sentinel_path).await?; - false - } - }; - - // FIXME: due to a bug in logger::checkpoint_db we call regular checkpointing code - // instead of our virtual WAL one. It's a bit tangled to fix right now, because - // we need WAL context for checkpointing, and WAL context needs the ReplicationLogger... - // So instead we checkpoint early, *before* bottomless gets initialized. That way - // we're sure bottomless won't try to back up any existing WAL frames and will instead - // treat the existing db file as the source of truth. 
- - let bottomless_replicator = match ns_config.bottomless_replication { - Some(ref options) => { - tracing::debug!("Checkpointing before initializing bottomless"); - crate::replication::primary::logger::checkpoint_db(&db_path.join("data"))?; - tracing::debug!("Checkpointed before initializing bottomless"); - let options = make_bottomless_options(options, bottomless_db_id, name.clone()); - let (replicator, did_recover) = - init_bottomless_replicator(db_path.join("data"), options, &restore_option) - .await?; - tracing::debug!("Completed init of bottomless replicator"); - is_dirty |= did_recover; - Some(replicator) - } - None => None, - }; - - tracing::debug!("Checking fresh db"); - let is_fresh_db = check_fresh_db(&db_path)?; - // switch frame-count checkpoint to time-based one - let auto_checkpoint = if ns_config.checkpoint_interval.is_some() { - 0 - } else { - DEFAULT_AUTO_CHECKPOINT - }; - - let logger = Arc::new(ReplicationLogger::open( - &db_path, - ns_config.max_log_size, - ns_config.max_log_duration, - is_dirty, - auto_checkpoint, - ns_config.scripted_backup.clone(), - name.clone(), - ns_config.encryption_config.clone(), - )?); - - tracing::debug!("sending stats"); - - let stats = make_stats( - &db_path, - join_set, - meta_store_handle.clone(), - ns_config.stats_sender.clone(), - name.clone(), - logger.new_frame_notifier.subscribe(), - ns_config.encryption_config.clone(), - ) - .await?; - - tracing::debug!("Making replication wal wrapper"); - let wal_wrapper = make_replication_wal_wrapper(bottomless_replicator, logger.clone()); - - tracing::debug!("Opening libsql connection"); - - let connection_maker = MakeLibSqlConn::new( - db_path.to_path_buf(), - wal_wrapper.clone(), - stats.clone(), - broadcaster, - meta_store_handle.clone(), - ns_config.extensions.clone(), - ns_config.max_response_size, - ns_config.max_total_response_size, - auto_checkpoint, - logger.new_frame_notifier.subscribe(), - ns_config.encryption_config.clone(), - block_writes, - resolve_attach_path, - ns_config.make_wal_manager.clone(), - ) - .await? - .throttled( - ns_config.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - ns_config.max_total_response_size, - ns_config.max_concurrent_requests, - ); - - tracing::debug!("Completed opening libsql connection"); - - // this must happen after we create the connection maker. The connection maker old on a - // connection to ensure that no other connection is closing while we try to open the dump. - // that would cause a SQLITE_LOCKED error. 
- match restore_option { - RestoreOption::Dump(_) if !is_fresh_db => { - Err(LoadDumpError::LoadDumpExistingDb)?; - } - RestoreOption::Dump(dump) => { - let conn = connection_maker.create().await?; - tracing::debug!("Loading dump"); - load_dump(dump, conn).await?; - tracing::debug!("Done loading dump"); - } - _ => { /* other cases were already handled when creating bottomless */ } - } - - join_set.spawn(run_periodic_compactions(logger.clone())); - - tracing::debug!("Done making primary connection"); - - Ok((connection_maker, wal_wrapper, stats)) - } - - async fn fork( - ns_config: &NamespaceConfig, - from_ns: &Namespace, - from_config: MetaStoreHandle, - to_ns: NamespaceName, - to_config: MetaStoreHandle, - timestamp: Option, - store: NamespaceStore, - ) -> crate::Result { - let from_config = from_config.get(); - match ns_config.db_kind { - DatabaseKind::Primary => { - let bottomless_db_id = NamespaceBottomlessDbId::from_config(&from_config); - let restore_to = if let Some(timestamp) = timestamp { - if let Some(ref options) = ns_config.bottomless_replication { - Some(PointInTimeRestore { - timestamp, - replicator_options: make_bottomless_options( - options, - bottomless_db_id.clone(), - from_ns.name().clone(), - ), - }) - } else { - return Err(crate::Error::Fork(ForkError::BackupServiceNotConfigured)); - } - } else { - None - }; - - let logger = match &from_ns.db { - Database::Primary(db) => db.wal_wrapper.wrapper().logger(), - Database::Schema(db) => db.wal_wrapper.wrapper().logger(), - _ => { - return Err(crate::Error::Fork(ForkError::Internal(Error::msg( - "Invalid source database type for fork", - )))); - } - }; - - let fork_task = ForkTask { - base_path: ns_config.base_path.clone(), - to_namespace: to_ns.clone(), - logger, - restore_to, - to_config, - store, - }; - - let ns = fork_task.fork().await?; - Ok(ns) - } - DatabaseKind::Replica => Err(ForkError::ForkReplica.into()), - } - } -} - -pub struct NamespaceConfig { - /// Default database kind the store should be Creating - pub(crate) db_kind: DatabaseKind, - // Common config - pub(crate) base_path: Arc, - pub(crate) max_log_size: u64, - pub(crate) max_log_duration: Option, - pub(crate) extensions: Arc<[PathBuf]>, - pub(crate) stats_sender: StatsSender, - pub(crate) max_response_size: u64, - pub(crate) max_total_response_size: u64, - pub(crate) checkpoint_interval: Option, - pub(crate) max_concurrent_connections: Arc, - pub(crate) max_concurrent_requests: u64, - pub(crate) encryption_config: Option, - - // Replica specific config - /// grpc channel for replica - pub channel: Option, - /// grpc uri - pub uri: Option, - - // primary only config - pub(crate) bottomless_replication: Option, - pub(crate) scripted_backup: Option, - pub(crate) migration_scheduler: SchedulerHandle, - pub(crate) make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, } pub type DumpStream = Box> + Send + Sync + 'static + Unpin>; -fn make_bottomless_options( - options: &Options, - namespace_db_id: NamespaceBottomlessDbId, - name: NamespaceName, -) -> Options { - let mut options = options.clone(); - let mut db_id = match namespace_db_id { - NamespaceBottomlessDbId::Namespace(id) => id, - // FIXME(marin): I don't like that, if bottomless is enabled, proper config must be passed. 
- NamespaceBottomlessDbId::NotProvided => options.db_id.unwrap_or_default(), - }; - - db_id = format!("ns-{db_id}:{name}"); - options.db_id = Some(db_id); - options -} - -async fn make_stats( - db_path: &Path, - join_set: &mut JoinSet>, - meta_store_handle: MetaStoreHandle, - stats_sender: StatsSender, - name: NamespaceName, - mut current_frame_no: watch::Receiver>, - encryption_config: Option, -) -> anyhow::Result> { - tracing::debug!("creating stats type"); - let stats = Stats::new(name.clone(), db_path, join_set).await?; - - // the storage monitor is optional, so we ignore the error here. - tracing::debug!("stats created, sending stats"); - let _ = stats_sender - .send((name.clone(), meta_store_handle, Arc::downgrade(&stats))) - .await; - - join_set.spawn({ - let stats = stats.clone(); - // initialize the current_frame_no value - current_frame_no - .borrow_and_update() - .map(|fno| stats.set_current_frame_no(fno)); - async move { - while current_frame_no.changed().await.is_ok() { - current_frame_no - .borrow_and_update() - .map(|fno| stats.set_current_frame_no(fno)); - } - Ok(()) - } - }); - - join_set.spawn(run_storage_monitor( - db_path.into(), - Arc::downgrade(&stats), - encryption_config, - )); - - tracing::debug!("done sending stats, and creating bg tasks"); - - Ok(stats) -} - #[derive(Default)] pub enum RestoreOption { /// Restore database state from the most recent version found in a backup. @@ -495,189 +140,3 @@ pub enum RestoreOption { /// Granularity depends of how frequently WAL log pages are being snapshotted. PointInTime(NaiveDateTime), } - -const WASM_TABLE_CREATE: &str = - "CREATE TABLE libsql_wasm_func_table (name text PRIMARY KEY, body text) WITHOUT ROWID;"; - -async fn load_dump(dump: S, conn: PrimaryConnection) -> crate::Result<(), LoadDumpError> -where - S: Stream> + Unpin, -{ - let mut reader = tokio::io::BufReader::new(StreamReader::new(dump)); - let mut curr = String::new(); - let mut line = String::new(); - let mut skipped_wasm_table = false; - let mut n_stmt = 0; - let mut line_id = 0; - - while let Ok(n) = reader.read_line(&mut curr).await { - line_id += 1; - if n == 0 { - break; - } - let trimmed = curr.trim(); - if trimmed.is_empty() || trimmed.starts_with("--") { - curr.clear(); - continue; - } - // FIXME: it's well known bug that comment ending with semicolon will be handled incorrectly by currend dump processing code - let statement_end = trimmed.ends_with(';'); - - // we want to concat original(non-trimmed) lines as trimming will join all them in one - // single-line statement which is incorrect if comments in the end are present - line.push_str(&curr); - curr.clear(); - - // This is a hack to ignore the libsql_wasm_func_table table because it is already created - // by the system. 
- if !skipped_wasm_table && line.trim() == WASM_TABLE_CREATE { - skipped_wasm_table = true; - line.clear(); - continue; - } - - if statement_end { - n_stmt += 1; - // dump must be performd within a txn - if n_stmt > 2 && conn.is_autocommit().await.unwrap() { - return Err(LoadDumpError::NoTxn); - } - - line = tokio::task::spawn_blocking({ - let conn = conn.clone(); - move || -> crate::Result { - conn.with_raw(|conn| conn.execute(&line, ())).map_err(|e| { - LoadDumpError::Internal(format!("line: {}, error: {}", line_id, e)) - })?; - Ok(line) - } - }) - .await??; - line.clear(); - } else { - line.push(' '); - } - } - tracing::debug!("loaded {} lines from dump", line_id); - - if !conn.is_autocommit().await.unwrap() { - tokio::task::spawn_blocking({ - let conn = conn.clone(); - move || -> crate::Result<(), LoadDumpError> { - conn.with_raw(|conn| conn.execute("rollback", ()))?; - Ok(()) - } - }) - .await??; - return Err(LoadDumpError::NoCommit); - } - - Ok(()) -} - -pub async fn init_bottomless_replicator( - path: impl AsRef, - options: bottomless::replicator::Options, - restore_option: &RestoreOption, -) -> anyhow::Result<(bottomless::replicator::Replicator, bool)> { - tracing::debug!("Initializing bottomless replication"); - let path = path - .as_ref() - .to_str() - .ok_or_else(|| anyhow::anyhow!("Invalid db path"))? - .to_owned(); - let mut replicator = bottomless::replicator::Replicator::with_options(path, options).await?; - - let (generation, timestamp) = match restore_option { - RestoreOption::Latest | RestoreOption::Dump(_) => (None, None), - RestoreOption::Generation(generation) => (Some(*generation), None), - RestoreOption::PointInTime(timestamp) => (None, Some(*timestamp)), - }; - - let (action, did_recover) = replicator.restore(generation, timestamp).await?; - match action { - bottomless::replicator::RestoreAction::SnapshotMainDbFile => { - replicator.new_generation().await; - if let Some(_handle) = replicator.snapshot_main_db_file(true).await? { - tracing::trace!("got snapshot handle after restore with generation upgrade"); - } - // Restoration process only leaves the local WAL file if it was - // detected to be newer than its remote counterpart. - replicator.maybe_replicate_wal().await? - } - bottomless::replicator::RestoreAction::ReuseGeneration(gen) => { - replicator.set_generation(gen); - } - } - - Ok((replicator, did_recover)) -} - -async fn run_periodic_compactions(logger: Arc) -> anyhow::Result<()> { - // calling `ReplicationLogger::maybe_compact()` is cheap if the compaction does not actually - // take place, so we can afford to poll it very often for simplicity - let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(1000)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - - loop { - interval.tick().await; - let handle = BLOCKING_RT.spawn_blocking(enclose! {(logger) move || { - logger.maybe_compact() - }}); - handle - .await - .expect("Compaction task crashed") - .context("Compaction failed")?; - } -} - -fn check_fresh_db(path: &Path) -> crate::Result { - let is_fresh = !path.join("wallog").try_exists()?; - Ok(is_fresh) -} - -// Periodically check the storage used by the database and save it in the Stats structure. -// TODO: Once we have a separate fiber that does WAL checkpoints, running this routine -// right after checkpointing is exactly where it should be done. 
-async fn run_storage_monitor( - db_path: PathBuf, - stats: Weak, - encryption_config: Option, -) -> anyhow::Result<()> { - // on initialization, the database file doesn't exist yet, so we wait a bit for it to be - // created - tokio::time::sleep(Duration::from_secs(1)).await; - - let duration = tokio::time::Duration::from_secs(60); - let db_path: Arc = db_path.into(); - loop { - let db_path = db_path.clone(); - let Some(stats) = stats.upgrade() else { - return Ok(()); - }; - - let encryption_config = encryption_config.clone(); - let _ = tokio::task::spawn_blocking(move || { - // because closing the last connection interferes with opening a new one, we lazily - // initialize a connection here, and keep it alive for the entirety of the program. If we - // fail to open it, we wait for `duration` and try again later. - match open_conn(&db_path, Sqlite3WalManager::new(), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), encryption_config) { - Ok(mut conn) => { - if let Ok(tx) = conn.transaction() { - let page_count = tx.query_row("pragma page_count;", [], |row| { row.get::(0) }); - let freelist_count = tx.query_row("pragma freelist_count;", [], |row| { row.get::(0) }); - if let (Ok(page_count), Ok(freelist_count)) = (page_count, freelist_count) { - let storage_bytes_used = (page_count - freelist_count) * 4096; - stats.set_storage_bytes_used(storage_bytes_used); - } - } - }, - Err(e) => { - tracing::warn!("failed to open connection for storager monitor: {e}, trying again in {duration:?}"); - }, - } - }).await; - - tokio::time::sleep(duration).await; - } -} diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index fbce8cd78b..a78e4f59b0 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -20,10 +20,10 @@ use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, Nam use crate::stats::Stats; use super::broadcasters::{BroadcasterHandle, BroadcasterRegistry}; -use super::configurator::NamespaceConfigurators; +use super::configurator::{DynConfigurator, NamespaceConfigurators}; use super::meta_store::{MetaStore, MetaStoreHandle}; use super::schema_lock::SchemaLocksRegistry; -use super::{Namespace, NamespaceConfig, ResetCb, ResetOp, ResolveNamespacePathFn, RestoreOption}; +use super::{Namespace, ResetCb, ResetOp, ResolveNamespacePathFn, RestoreOption}; type NamespaceEntry = Arc>>; @@ -46,10 +46,10 @@ pub struct NamespaceStoreInner { allow_lazy_creation: bool, has_shutdown: AtomicBool, snapshot_at_shutdown: bool, - pub config: NamespaceConfig, schema_locks: SchemaLocksRegistry, broadcasters: BroadcasterRegistry, configurators: NamespaceConfigurators, + db_kind: DatabaseKind, } impl NamespaceStore { @@ -57,9 +57,9 @@ impl NamespaceStore { allow_lazy_creation: bool, snapshot_at_shutdown: bool, max_active_namespaces: usize, - config: NamespaceConfig, metadata: MetaStore, configurators: NamespaceConfigurators, + db_kind: DatabaseKind, ) -> crate::Result { tracing::trace!("Max active namespaces: {max_active_namespaces}"); let store = Cache::::builder() @@ -91,10 +91,10 @@ impl NamespaceStore { allow_lazy_creation, has_shutdown: AtomicBool::new(false), snapshot_at_shutdown, - config, schema_locks: Default::default(), broadcasters: Default::default(), configurators, + db_kind, }), }) } @@ -132,14 +132,8 @@ impl NamespaceStore { } } - Namespace::cleanup( - &self.inner.config, - &namespace, - &db_config, - prune_all, - bottomless_db_id_init, - ) - .await?; + self.cleanup(&namespace, &db_config, prune_all, 
bottomless_db_id_init) + .await?; tracing::info!("destroyed namespace: {namespace}"); @@ -181,15 +175,16 @@ impl NamespaceStore { let db_config = self.inner.metadata.handle(namespace.clone()); // destroy on-disk database - Namespace::cleanup( - &self.inner.config, + self.cleanup( &namespace, &db_config.get(), false, NamespaceBottomlessDbIdInit::FetchFromConfig, ) .await?; - let ns = self.make_namespace(&namespace, db_config, restore_option).await?; + let ns = self + .make_namespace(&namespace, db_config, restore_option) + .await?; lock.replace(ns); @@ -289,16 +284,17 @@ impl NamespaceStore { handle .store_and_maybe_flush(Some(to_config.into()), false) .await?; - let to_ns = Namespace::fork( - &self.inner.config, - from_ns, - from_config, - to.clone(), - handle.clone(), - timestamp, - self.clone(), - ) - .await?; + let to_ns = self + .get_configurator(&from_config.get()) + .fork( + from_ns, + from_config, + to.clone(), + handle.clone(), + timestamp, + self.clone(), + ) + .await?; to_lock.replace(to_ns); handle.flush().await?; @@ -377,23 +373,18 @@ impl NamespaceStore { config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { - let configurator = match self.inner.config.db_kind { - DatabaseKind::Primary if config.get().is_shared_schema => { - self.inner.configurators.configure_schema()? - } - DatabaseKind::Primary => self.inner.configurators.configure_primary()?, - DatabaseKind::Replica => self.inner.configurators.configure_replica()?, - }; - let ns = configurator.setup( - &self.inner.config, - config, - restore_option, - namespace, - self.make_reset_cb(), - self.resolve_attach_fn(), - self.clone(), - self.broadcaster(namespace.clone()), - ).await?; + let ns = self + .get_configurator(&config.get()) + .setup( + config, + restore_option, + namespace, + self.make_reset_cb(), + self.resolve_attach_fn(), + self.clone(), + self.broadcaster(namespace.clone()), + ) + .await?; Ok(ns) } @@ -405,7 +396,9 @@ impl NamespaceStore { restore_option: RestoreOption, ) -> crate::Result { let init = async { - let ns = self.make_namespace(namespace, db_config, restore_option).await?; + let ns = self + .make_namespace(namespace, db_config, restore_option) + .await?; Ok(Some(ns)) }; @@ -521,4 +514,26 @@ impl NamespaceStore { pub(crate) fn schema_locks(&self) -> &SchemaLocksRegistry { &self.inner.schema_locks } + + fn get_configurator(&self, db_config: &DatabaseConfig) -> &DynConfigurator { + match self.inner.db_kind { + DatabaseKind::Primary if db_config.is_shared_schema => { + self.inner.configurators.configure_schema().unwrap() + } + DatabaseKind::Primary => self.inner.configurators.configure_primary().unwrap(), + DatabaseKind::Replica => self.inner.configurators.configure_replica().unwrap(), + } + } + + async fn cleanup( + &self, + namespace: &NamespaceName, + db_config: &DatabaseConfig, + prune_all: bool, + bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> crate::Result<()> { + self.get_configurator(db_config) + .cleanup(namespace, db_config, prune_all, bottomless_db_id_init) + .await + } } diff --git a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index 17ce655064..a8195cbbd0 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -809,10 +809,11 @@ mod test { use crate::connection::config::DatabaseConfig; use crate::database::DatabaseKind; use crate::namespace::configurator::{ - NamespaceConfigurators, PrimaryConfigurator, SchemaConfigurator, + BaseNamespaceConfig, NamespaceConfigurators, 
PrimaryConfigurator, PrimaryExtraConfig, + SchemaConfigurator, }; use crate::namespace::meta_store::{metastore_connection_maker, MetaStore}; - use crate::namespace::{NamespaceConfig, RestoreOption}; + use crate::namespace::RestoreOption; use crate::schema::SchedulerHandle; use super::super::migration::has_pending_migration_task; @@ -833,9 +834,9 @@ mod test { false, false, 10, - config, meta_store, - NamespaceConfigurators::default(), + config, + DatabaseKind::Primary ) .await .unwrap(); @@ -912,27 +913,41 @@ mod test { assert!(!block_write.load(std::sync::atomic::Ordering::Relaxed)); } - fn make_config(migration_scheduler: SchedulerHandle, path: &Path) -> NamespaceConfig { - NamespaceConfig { - db_kind: DatabaseKind::Primary, + fn make_config(migration_scheduler: SchedulerHandle, path: &Path) -> NamespaceConfigurators { + let mut configurators = NamespaceConfigurators::empty(); + let base_config = BaseNamespaceConfig { base_path: path.to_path_buf().into(), - max_log_size: 1000000000, - max_log_duration: None, extensions: Arc::new([]), stats_sender: tokio::sync::mpsc::channel(1).0, max_response_size: 100000000000000, max_total_response_size: 100000000000, - checkpoint_interval: None, max_concurrent_connections: Arc::new(Semaphore::new(10)), max_concurrent_requests: 10000, - encryption_config: None, - channel: None, - uri: None, + }; + + let primary_config = PrimaryExtraConfig { + max_log_size: 1000000000, + max_log_duration: None, bottomless_replication: None, scripted_backup: None, + checkpoint_interval: None, + }; + + let make_wal_manager = Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())); + + configurators.with_schema(SchemaConfigurator::new( + base_config.clone(), + primary_config.clone(), + make_wal_manager.clone(), migration_scheduler, - make_wal_manager: Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())), - } + )); + configurators.with_primary(PrimaryConfigurator::new( + base_config, + primary_config, + make_wal_manager.clone(), + )); + + configurators } #[tokio::test] @@ -950,9 +965,9 @@ mod test { false, false, 10, - config, meta_store, - NamespaceConfigurators::default(), + config, + DatabaseKind::Primary ) .await .unwrap(); @@ -1029,9 +1044,16 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store, NamespaceConfigurators::default()) - .await - .unwrap(); + let store = NamespaceStore::new( + false, + false, + 10, + meta_store, + config, + DatabaseKind::Primary, + ) + .await + .unwrap(); store .with("ns".into(), |ns| { @@ -1056,10 +1078,7 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let configurators = NamespaceConfigurators::default() - .with_schema(SchemaConfigurator) - .with_primary(PrimaryConfigurator); - let store = NamespaceStore::new(false, false, 10, config, meta_store, configurators) + let store = NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) .await .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) @@ -1132,9 +1151,16 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, config, meta_store, NamespaceConfigurators::default()) - .await - .unwrap(); + let store = NamespaceStore::new( + false, + false, + 10, + meta_store, + config, + 
DatabaseKind::Primary + ) + .await + .unwrap(); let scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); From 0647711dd81736bcb1fa1f886b3076736becc4a9 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 08:30:53 +0200 Subject: [PATCH 09/70] legacy configurators --- libsql-server/src/http/admin/stats.rs | 2 + libsql-server/src/lib.rs | 425 +++++++++++++++----------- libsql-server/src/namespace/store.rs | 13 +- libsql-server/tests/cluster/mod.rs | 29 +- 4 files changed, 279 insertions(+), 190 deletions(-) diff --git a/libsql-server/src/http/admin/stats.rs b/libsql-server/src/http/admin/stats.rs index f2948d4d7b..5fce92ba0a 100644 --- a/libsql-server/src/http/admin/stats.rs +++ b/libsql-server/src/http/admin/stats.rs @@ -140,10 +140,12 @@ pub(super) async fn handle_stats( State(app_state): State>>, Path(namespace): Path, ) -> crate::Result> { + dbg!(); let stats = app_state .namespaces .stats(NamespaceName::from_string(namespace)?) .await?; + dbg!(); let resp: StatsResponse = stats.as_ref().into(); Ok(Json(resp)) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 8bd3ea4fac..4188365e03 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -4,7 +4,6 @@ use std::alloc::Layout; use std::ffi::c_void; use std::mem::{align_of, size_of}; use std::path::{Path, PathBuf}; -use std::pin::Pin; use std::str::FromStr; use std::sync::{Arc, Weak}; @@ -29,10 +28,10 @@ use auth::Auth; use config::{ AdminApiConfig, DbConfig, HeartbeatConfig, RpcClientConfig, RpcServerConfig, UserApiConfig, }; -use futures::future::ready; use futures::Future; use http::user::UserApi; use hyper::client::HttpConnector; +use hyper::Uri; use hyper_rustls::HttpsConnector; #[cfg(feature = "durable-wal")] use libsql_storage::{DurableWalManager, LockManager}; @@ -41,10 +40,6 @@ use libsql_sys::wal::either::Either as EitherWAL; #[cfg(feature = "durable-wal")] use libsql_sys::wal::either::Either3 as EitherWAL; use libsql_sys::wal::Sqlite3WalManager; -use libsql_wal::checkpointer::LibsqlCheckpointer; -use libsql_wal::registry::WalRegistry; -use libsql_wal::storage::NoStorage; -use libsql_wal::wal::LibsqlWalManager; use namespace::meta_store::MetaStoreHandle; use namespace::NamespaceName; use net::Connector; @@ -55,15 +50,19 @@ use tokio::runtime::Runtime; use tokio::sync::{mpsc, Notify, Semaphore}; use tokio::task::JoinSet; use tokio::time::Duration; +use tonic::transport::Channel; use url::Url; use utils::services::idle_shutdown::IdleShutdownKicker; use self::config::MetaStoreConfig; -use self::connection::connection_manager::InnerWalManager; -use self::namespace::configurator::{BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, ReplicaConfigurator, SchemaConfigurator}; +use self::namespace::configurator::{ + BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, + ReplicaConfigurator, SchemaConfigurator, +}; use self::namespace::NamespaceStore; use self::net::AddrIncoming; use self::replication::script_backup_manager::{CommandHandler, ScriptBackupManager}; +use self::schema::SchedulerHandle; pub mod auth; mod broadcaster; @@ -424,33 +423,44 @@ where let extensions = self.db_config.validate_extensions()?; let user_auth_strategy = self.user_api_config.auth_strategy.clone(); - let service_shutdown = Arc::new(Notify::new()); - let scripted_backup = match self.db_config.snapshot_exec { Some(ref command) => { let (scripted_backup, script_backup_task) = ScriptBackupManager::new(&self.path, 
CommandHandler::new(command.to_string())) .await?; - join_set.spawn(script_backup_task.run()); + self.spawn_until_shutdown(&mut join_set, script_backup_task.run()); Some(scripted_backup) } None => None, }; - let (channel, uri) = match self.rpc_client_config { - Some(ref config) => { - let (channel, uri) = config.configure().await?; - (Some(channel), Some(uri)) - } - None => (None, None), + let db_kind = match self.rpc_client_config { + Some(_) => DatabaseKind::Replica, + _ => DatabaseKind::Primary, }; + let client_config = self.get_client_config().await?; let (scheduler_sender, scheduler_receiver) = mpsc::channel(128); - let (stats_sender, stats_receiver) = mpsc::channel(1024); - // chose the wal backend - let (make_wal_manager, registry_shutdown) = self.configure_wal_manager(&mut join_set)?; + let base_config = BaseNamespaceConfig { + base_path: self.path.clone(), + extensions, + stats_sender, + max_response_size: self.db_config.max_response_size, + max_total_response_size: self.db_config.max_total_response_size, + max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), + max_concurrent_requests: self.db_config.max_concurrent_requests, + }; + + let configurators = self + .make_configurators( + base_config, + scripted_backup, + scheduler_sender.into(), + client_config.clone(), + ) + .await?; let (metastore_conn_maker, meta_store_wal_manager) = metastore_connection_maker(self.meta_store_config.bottomless.clone(), &self.path) @@ -464,60 +474,6 @@ where ) .await?; - let base_config = BaseNamespaceConfig { - base_path: self.path.clone(), - extensions, - stats_sender, - max_response_size: self.db_config.max_response_size, - max_total_response_size: self.db_config.max_total_response_size, - max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), - max_concurrent_requests: self.db_config.max_concurrent_requests, - }; - - let mut configurators = NamespaceConfigurators::default(); - - let db_kind = match channel.clone().zip(uri.clone()) { - // replica mode - Some((channel, uri)) => { - let replica_configurator = ReplicaConfigurator::new( - base_config, - channel, - uri, - make_wal_manager, - ); - configurators.with_replica(replica_configurator); - DatabaseKind::Replica - } - // primary mode - None => { - let primary_config = PrimaryExtraConfig { - max_log_size: self.db_config.max_log_size, - max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), - bottomless_replication: self.db_config.bottomless_replication.clone(), - scripted_backup, - checkpoint_interval: self.db_config.checkpoint_interval, - }; - - let primary_configurator = PrimaryConfigurator::new( - base_config.clone(), - primary_config.clone(), - make_wal_manager.clone(), - ); - - let schema_configurator = SchemaConfigurator::new( - base_config.clone(), - primary_config, - make_wal_manager.clone(), - scheduler_sender.into(), - ); - - configurators.with_schema(schema_configurator); - configurators.with_primary(primary_configurator); - - DatabaseKind::Primary - }, - }; - let namespace_store: NamespaceStore = NamespaceStore::new( db_kind.is_replica(), self.db_config.snapshot_at_shutdown, @@ -528,27 +484,9 @@ where ) .await?; - let meta_conn = metastore_conn_maker()?; - let scheduler = Scheduler::new(namespace_store.clone(), meta_conn).await?; - - join_set.spawn(async move { - scheduler.run(scheduler_receiver).await; - Ok(()) - }); self.spawn_monitoring_tasks(&mut join_set, stats_receiver)?; - // eagerly load the default namespace when namespaces are disabled 
- if self.disable_namespaces && db_kind.is_primary() { - namespace_store - .create( - NamespaceName::default(), - namespace::RestoreOption::Latest, - Default::default(), - ) - .await?; - } - // if namespaces are enabled, then bottomless must have set DB ID if !self.disable_namespaces { if let Some(bottomless) = &self.db_config.bottomless_replication { @@ -563,7 +501,7 @@ where let proxy_service = ProxyService::new(namespace_store.clone(), None, self.disable_namespaces); // Garbage collect proxy clients every 30 seconds - join_set.spawn({ + self.spawn_until_shutdown(&mut join_set, { let clients = proxy_service.clients(); async move { loop { @@ -572,7 +510,8 @@ where } } }); - join_set.spawn(run_rpc_server( + + self.spawn_until_shutdown(&mut join_set, run_rpc_server( proxy_service, config.acceptor, config.tls_config, @@ -584,9 +523,28 @@ where let shutdown_timeout = self.shutdown_timeout.clone(); let shutdown = self.shutdown.clone(); + let service_shutdown = Arc::new(Notify::new()); // setup user-facing rpc services match db_kind { DatabaseKind::Primary => { + // The migration scheduler is only useful on the primary + let meta_conn = metastore_conn_maker()?; + let scheduler = Scheduler::new(namespace_store.clone(), meta_conn).await?; + self.spawn_until_shutdown(&mut join_set, async move { + scheduler.run(scheduler_receiver).await; + Ok(()) + }); + + if self.disable_namespaces { + namespace_store + .create( + NamespaceName::default(), + namespace::RestoreOption::Latest, + Default::default(), + ) + .await?; + } + let replication_svc = ReplicationLogService::new( namespace_store.clone(), idle_shutdown_kicker.clone(), @@ -602,7 +560,7 @@ where ); // Garbage collect proxy clients every 30 seconds - join_set.spawn({ + self.spawn_until_shutdown(&mut join_set, { let clients = proxy_svc.clients(); async move { loop { @@ -623,16 +581,19 @@ where .configure(&mut join_set); } DatabaseKind::Replica => { + dbg!(); + let (channel, uri) = client_config.clone().unwrap(); let replication_svc = - ReplicationLogProxyService::new(channel.clone().unwrap(), uri.clone().unwrap()); + ReplicationLogProxyService::new(channel.clone(), uri.clone()); let proxy_svc = ReplicaProxyService::new( - channel.clone().unwrap(), - uri.clone().unwrap(), + channel, + uri, namespace_store.clone(), user_auth_strategy.clone(), self.disable_namespaces, ); + dbg!(); self.make_services( namespace_store.clone(), idle_shutdown_kicker, @@ -642,6 +603,7 @@ where service_shutdown.clone(), ) .configure(&mut join_set); + dbg!(); } }; @@ -651,7 +613,6 @@ where join_set.shutdown().await; service_shutdown.notify_waiters(); namespace_store.shutdown().await?; - registry_shutdown.await?; Ok::<_, crate::Error>(()) }; @@ -680,100 +641,200 @@ where Ok(()) } - fn setup_shutdown(&self) -> Option { - let shutdown_notify = self.shutdown.clone(); - self.idle_shutdown_timeout.map(|d| { - IdleShutdownKicker::new(d, self.initial_idle_shutdown_timeout, shutdown_notify) - }) - } - - fn configure_wal_manager( + async fn make_configurators( &self, - join_set: &mut JoinSet>, - ) -> anyhow::Result<( - Arc InnerWalManager + Sync + Send + 'static>, - Pin> + Send + Sync + 'static>>, - )> { - let wal_path = self.path.join("wals"); - let enable_libsql_wal_test = { - let is_primary = self.rpc_server_config.is_some(); - let is_libsql_wal_test = std::env::var("LIBSQL_WAL_TEST").is_ok(); - is_primary && is_libsql_wal_test - }; - let use_libsql_wal = - self.use_custom_wal == Some(CustomWAL::LibsqlWal) || enable_libsql_wal_test; - if !use_libsql_wal { - if wal_path.try_exists()? 
{ - anyhow::bail!("database was previously setup to use libsql-wal"); - } - } - - if self.use_custom_wal.is_some() { - if self.db_config.bottomless_replication.is_some() { - anyhow::bail!("bottomless not supported with custom WAL"); - } - if self.rpc_client_config.is_some() { - anyhow::bail!("custom WAL not supported in replica mode"); + base_config: BaseNamespaceConfig, + scripted_backup: Option, + migration_scheduler_handle: SchedulerHandle, + client_config: Option<(Channel, Uri)>, + ) -> anyhow::Result { + match self.use_custom_wal { + Some(CustomWAL::LibsqlWal) => self.libsql_wal_configurators(), + #[cfg(feature = "durable-wal")] + Some(CustomWAL::DurableWal) => self.durable_wal_configurators(), + None => { + self.legacy_configurators( + base_config, + scripted_backup, + migration_scheduler_handle, + client_config, + ) + .await } } + } - let namespace_resolver = |path: &Path| { - NamespaceName::from_string( - path.parent() - .unwrap() - .file_name() - .unwrap() - .to_str() - .unwrap() - .to_string(), - ) - .unwrap() - .into() - }; - - match self.use_custom_wal { - Some(CustomWAL::LibsqlWal) => { - let (sender, receiver) = tokio::sync::mpsc::channel(64); - let registry = Arc::new(WalRegistry::new(wal_path, NoStorage, sender)?); - let checkpointer = LibsqlCheckpointer::new(registry.clone(), receiver, 8); - join_set.spawn(async move { - checkpointer.run().await; - Ok(()) - }); + fn libsql_wal_configurators(&self) -> anyhow::Result { + todo!() + } - let wal = LibsqlWalManager::new(registry.clone(), Arc::new(namespace_resolver)); - let shutdown_notify = self.shutdown.clone(); - let shutdown_fut = Box::pin(async move { - shutdown_notify.notified().await; - registry.shutdown().await?; - Ok(()) - }); + #[cfg(feature = "durable-wal")] + fn durable_wal_configurators(&self) -> anyhow::Result { + todo!(); + } - tracing::info!("using libsql wal"); - Ok((Arc::new(move || EitherWAL::B(wal.clone())), shutdown_fut)) + fn spawn_until_shutdown(&self, join_set: &mut JoinSet>, fut: F) + where + F: Future> + Send + 'static, + { + let shutdown = self.shutdown.clone(); + join_set.spawn(async move { + tokio::select! 
{ + _ = shutdown.notified() => Ok(()), + ret = fut => ret } - #[cfg(feature = "durable-wal")] - Some(CustomWAL::DurableWal) => { - tracing::info!("using durable wal"); - let lock_manager = Arc::new(std::sync::Mutex::new(LockManager::new())); - let wal = DurableWalManager::new( - lock_manager, - namespace_resolver, - self.storage_server_address.clone(), - ); - Ok(( - Arc::new(move || EitherWAL::C(wal.clone())), - Box::pin(ready(Ok(()))), - )) + }); + } + + async fn legacy_configurators( + &self, + base_config: BaseNamespaceConfig, + scripted_backup: Option, + migration_scheduler_handle: SchedulerHandle, + client_config: Option<(Channel, Uri)>, + ) -> anyhow::Result { + let make_wal_manager = Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())); + let mut configurators = NamespaceConfigurators::empty(); + + match client_config { + // replica mode + Some((channel, uri)) => { + let replica_configurator = + ReplicaConfigurator::new(base_config, channel, uri, make_wal_manager); + configurators.with_replica(replica_configurator); } + // primary mode None => { - tracing::info!("using sqlite3 wal"); - Ok(( - Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())), - Box::pin(ready(Ok(()))), - )) + let primary_config = PrimaryExtraConfig { + max_log_size: self.db_config.max_log_size, + max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), + bottomless_replication: self.db_config.bottomless_replication.clone(), + scripted_backup, + checkpoint_interval: self.db_config.checkpoint_interval, + }; + + let primary_configurator = PrimaryConfigurator::new( + base_config.clone(), + primary_config.clone(), + make_wal_manager.clone(), + ); + + let schema_configurator = SchemaConfigurator::new( + base_config.clone(), + primary_config, + make_wal_manager.clone(), + migration_scheduler_handle, + ); + + configurators.with_schema(schema_configurator); + configurators.with_primary(primary_configurator); } } + + Ok(configurators) + } + + fn setup_shutdown(&self) -> Option { + let shutdown_notify = self.shutdown.clone(); + self.idle_shutdown_timeout.map(|d| { + IdleShutdownKicker::new(d, self.initial_idle_shutdown_timeout, shutdown_notify) + }) + } + + // fn configure_wal_manager( + // &self, + // join_set: &mut JoinSet>, + // ) -> anyhow::Result<( + // Arc InnerWalManager + Sync + Send + 'static>, + // Pin> + Send + Sync + 'static>>, + // )> { + // let wal_path = self.path.join("wals"); + // let enable_libsql_wal_test = { + // let is_primary = self.rpc_server_config.is_some(); + // let is_libsql_wal_test = std::env::var("LIBSQL_WAL_TEST").is_ok(); + // is_primary && is_libsql_wal_test + // }; + // let use_libsql_wal = + // self.use_custom_wal == Some(CustomWAL::LibsqlWal) || enable_libsql_wal_test; + // if !use_libsql_wal { + // if wal_path.try_exists()? 
{ + // anyhow::bail!("database was previously setup to use libsql-wal"); + // } + // } + // + // if self.use_custom_wal.is_some() { + // if self.db_config.bottomless_replication.is_some() { + // anyhow::bail!("bottomless not supported with custom WAL"); + // } + // if self.rpc_client_config.is_some() { + // anyhow::bail!("custom WAL not supported in replica mode"); + // } + // } + // + // let namespace_resolver = |path: &Path| { + // NamespaceName::from_string( + // path.parent() + // .unwrap() + // .file_name() + // .unwrap() + // .to_str() + // .unwrap() + // .to_string(), + // ) + // .unwrap() + // .into() + // }; + // + // match self.use_custom_wal { + // Some(CustomWAL::LibsqlWal) => { + // let (sender, receiver) = tokio::sync::mpsc::channel(64); + // let registry = Arc::new(WalRegistry::new(wal_path, NoStorage, sender)?); + // let checkpointer = LibsqlCheckpointer::new(registry.clone(), receiver, 8); + // join_set.spawn(async move { + // checkpointer.run().await; + // Ok(()) + // }); + // + // let wal = LibsqlWalManager::new(registry.clone(), Arc::new(namespace_resolver)); + // let shutdown_notify = self.shutdown.clone(); + // let shutdown_fut = Box::pin(async move { + // shutdown_notify.notified().await; + // registry.shutdown().await?; + // Ok(()) + // }); + // + // tracing::info!("using libsql wal"); + // Ok((Arc::new(move || EitherWAL::B(wal.clone())), shutdown_fut)) + // } + // #[cfg(feature = "durable-wal")] + // Some(CustomWAL::DurableWal) => { + // tracing::info!("using durable wal"); + // let lock_manager = Arc::new(std::sync::Mutex::new(LockManager::new())); + // let wal = DurableWalManager::new( + // lock_manager, + // namespace_resolver, + // self.storage_server_address.clone(), + // ); + // Ok(( + // Arc::new(move || EitherWAL::C(wal.clone())), + // Box::pin(ready(Ok(()))), + // )) + // } + // None => { + // tracing::info!("using sqlite3 wal"); + // Ok(( + // Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())), + // Box::pin(ready(Ok(()))), + // )) + // } + // } + // } + + async fn get_client_config(&self) -> anyhow::Result> { + match self.rpc_client_config { + Some(ref config) => Ok(Some(config.configure().await?)), + None => Ok(None), + } } } diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index a78e4f59b0..b2b5d33032 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -327,6 +327,7 @@ impl NamespaceStore { where Fun: FnOnce(&Namespace) -> R, { + dbg!(); if namespace != NamespaceName::default() && !self.inner.metadata.exists(&namespace) && !self.inner.allow_lazy_creation @@ -334,6 +335,7 @@ impl NamespaceStore { return Err(Error::NamespaceDoesntExist(namespace.to_string())); } + dbg!(); let f = { let name = namespace.clone(); move |ns: NamespaceEntry| async move { @@ -346,7 +348,9 @@ impl NamespaceStore { } }; + dbg!(); let handle = self.inner.metadata.handle(namespace.to_owned()); + dbg!(); f(self .load_namespace(&namespace, handle, RestoreOption::Latest) .await?) 
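
As an aside on the dispatch used throughout this patch: get_configurator picks one of three configurators from the database kind plus the shared-schema flag. Below is a minimal, self-contained sketch of that shape. The trait and struct names are illustrative stand-ins, not the actual DynConfigurator/NamespaceConfigurators definitions from namespace::configurator.

    // Illustrative stand-ins only; the real types live in namespace::configurator.
    trait Configurator {
        fn describe(&self) -> &'static str;
    }

    enum DbKind {
        Primary,
        Replica,
    }

    struct Configurators {
        primary: Box<dyn Configurator>,
        replica: Box<dyn Configurator>,
        schema: Box<dyn Configurator>,
    }

    impl Configurators {
        // Mirrors the match in get_configurator: shared-schema databases are a
        // special case of primaries; replicas always take the replica path.
        fn pick(&self, kind: DbKind, is_shared_schema: bool) -> &dyn Configurator {
            match kind {
                DbKind::Primary if is_shared_schema => &*self.schema,
                DbKind::Primary => &*self.primary,
                DbKind::Replica => &*self.replica,
            }
        }
    }
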
@@ -373,6 +377,7 @@ impl NamespaceStore { config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { + dbg!(); let ns = self .get_configurator(&config.get()) .setup( @@ -386,6 +391,7 @@ impl NamespaceStore { ) .await?; + dbg!(); Ok(ns) } @@ -395,13 +401,17 @@ impl NamespaceStore { db_config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { + dbg!(); let init = async { + dbg!(); let ns = self .make_namespace(namespace, db_config, restore_option) .await?; + dbg!(); Ok(Some(ns)) }; + dbg!(); let before_load = Instant::now(); let ns = self .inner @@ -410,7 +420,8 @@ impl NamespaceStore { namespace.clone(), init.map_ok(|ns| Arc::new(RwLock::new(ns))), ) - .await?; + .await.map_err(|e| dbg!(e))?; + dbg!(); NAMESPACE_LOAD_LATENCY.record(before_load.elapsed()); Ok(ns) diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index 1171d4a5d0..8f214bd05e 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ -149,23 +149,29 @@ fn sync_many_replica() { let mut sim = Builder::new() .simulation_duration(Duration::from_secs(1000)) .build(); + dbg!(); make_cluster(&mut sim, NUM_REPLICA, true); + dbg!(); sim.client("client", async { let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; + dbg!(); conn.execute("create table test (x)", ()).await?; + dbg!(); conn.execute("insert into test values (42)", ()).await?; + dbg!(); async fn get_frame_no(url: &str) -> Option { let client = Client::new(); + dbg!(); Some( - client - .get(url) - .await - .unwrap() - .json::() - .await + dbg!(client + .get(url) + .await + .unwrap() + .json::() + .await) .unwrap() .get("replication_index")? .as_u64() @@ -173,6 +179,7 @@ fn sync_many_replica() { ) } + dbg!(); let primary_fno = loop { if let Some(fno) = get_frame_no("http://primary:9090/v1/namespaces/default/stats").await { @@ -180,13 +187,15 @@ fn sync_many_replica() { } }; + dbg!(); // wait for all replicas to sync let mut join_set = JoinSet::new(); for i in 0..NUM_REPLICA { join_set.spawn(async move { let uri = format!("http://replica{i}:9090/v1/namespaces/default/stats"); + dbg!(); loop { - if let Some(replica_fno) = get_frame_no(&uri).await { + if let Some(replica_fno) = dbg!(get_frame_no(&uri).await) { if replica_fno == primary_fno { break; } @@ -196,8 +205,10 @@ fn sync_many_replica() { }); } + dbg!(); while join_set.join_next().await.is_some() {} + dbg!(); for i in 0..NUM_REPLICA { let db = Database::open_remote_with_connector( format!("http://replica{i}:8080"), @@ -212,8 +223,10 @@ fn sync_many_replica() { )); } + dbg!(); let client = Client::new(); + dbg!(); let stats = client .get("http://primary:9090/v1/namespaces/default/stats") .await? 
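
The replica loop in the test above busy-polls each stats endpoint until replication_index matches the primary's frame number. A condensed sketch of that catch-up wait, with the HTTP call abstracted behind a closure (get_fno stands in for the test's get_frame_no helper) and a small sleep between attempts, which the test itself omits:

    use std::future::Future;
    use std::time::Duration;

    // Poll until the reported frame number equals `target`; None means the
    // stats endpoint had no replication_index yet, so we just retry.
    async fn wait_for_catch_up<F, Fut>(target: u64, mut get_fno: F)
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Option<u64>>,
    {
        loop {
            match get_fno().await {
                Some(fno) if fno == target => break,
                _ => tokio::time::sleep(Duration::from_millis(100)).await,
            }
        }
    }
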
@@ -221,12 +234,14 @@ fn sync_many_replica() { .await .unwrap(); + dbg!(); let stat = stats .get("embedded_replica_frames_replicated") .unwrap() .as_u64() .unwrap(); + dbg!(); assert_eq!(stat, 0); Ok(()) From 76558e8df4d1927e7bd9fee7da4821f1a07a87f6 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 13:33:23 +0400 Subject: [PATCH 10/70] fix behaviour of VACUUM for vector indices to make rowid consistent between shadow tables and base table --- libsql-sqlite3/src/vacuum.c | 26 ++++++++++++++++++ libsql-sqlite3/src/vectorIndex.c | 28 +------------------- libsql-sqlite3/test/libsql_vector_index.test | 9 +++++-- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/libsql-sqlite3/src/vacuum.c b/libsql-sqlite3/src/vacuum.c index c0ae4bc1e1..d927a8d5a6 100644 --- a/libsql-sqlite3/src/vacuum.c +++ b/libsql-sqlite3/src/vacuum.c @@ -17,6 +17,10 @@ #include "sqliteInt.h" #include "vdbeInt.h" +#ifndef SQLITE_OMIT_VECTOR +#include "vectorIndexInt.h" +#endif + #if !defined(SQLITE_OMIT_VACUUM) && !defined(SQLITE_OMIT_ATTACH) /* @@ -294,6 +298,27 @@ SQLITE_NOINLINE int sqlite3RunVacuum( if( rc!=SQLITE_OK ) goto end_of_vacuum; db->init.iDb = 0; +#ifndef SQLITE_OMIT_VECTOR + // shadow tables for vector index will be populated automatically during CREATE INDEX command + // so we must skip them at this step + if( sqlite3FindTable(db, VECTOR_INDEX_GLOBAL_META_TABLE, zDbMain) != NULL ){ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name NOT IN (SELECT name||'_shadow' FROM " VECTOR_INDEX_GLOBAL_META_TABLE ")", + zDbMain + ); + }else{ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + zDbMain + ); + } +#else /* Loop through the tables in the main database. For each, do ** an "INSERT INTO vacuum_db.xxx SELECT * FROM main.xxx;" to copy ** the contents to the temporary database. 
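
The net effect of the new branch: when the vector-index meta table exists, the table-copy statement generator excludes every "<index name>_shadow" table, since the later CREATE INDEX replay repopulates them (and, with this patch, with rowids consistent with the base table). A small sketch of the same selection predicate from Rust, assuming rusqlite, with the meta-table name passed in since the C code takes it from the VECTOR_INDEX_GLOBAL_META_TABLE macro:

    use rusqlite::{Connection, Result};

    // List the tables VACUUM would copy: real tables with storage, minus the
    // vector-index shadow tables named in the meta table.
    fn tables_to_copy(conn: &Connection, meta_table: &str) -> Result<Vec<String>> {
        let sql = format!(
            "SELECT name FROM sqlite_schema \
             WHERE type = 'table' AND coalesce(rootpage, 1) > 0 \
             AND name NOT IN (SELECT name || '_shadow' FROM {meta_table})"
        );
        let mut stmt = conn.prepare(&sql)?;
        let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
        rows.collect()
    }
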
@@ -305,6 +330,7 @@ SQLITE_NOINLINE int sqlite3RunVacuum( "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); +#endif assert( (db->mDbFlags & DBFLAG_Vacuum)!=0 ); db->mDbFlags &= ~DBFLAG_Vacuum; if( rc!=SQLITE_OK ) goto end_of_vacuum; diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 78266ed462..d520419a41 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -49,11 +49,6 @@ ** VectorIdxParams utilities ****************************************************************************/ -// VACUUM creates tables and indices first and only then populate data -// we need to ignore inserts from 'INSERT INTO vacuum.t SELECT * FROM t' statements because -// all shadow tables will be populated by VACUUM process during regular process of table copy -#define IsVacuum(db) ((db->mDbFlags&DBFLAG_Vacuum)!=0) - void vectorIdxParamsInit(VectorIdxParams *pParams, u8 *pBinBuf, int nBinSize) { assert( nBinSize <= VECTOR_INDEX_PARAMS_BUF_SIZE ); @@ -772,10 +767,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { // this is done to prevent unrecoverable situations where index were dropped but index parameters deletion failed and second attempt will fail on first step int rcIdx, rcParams; - if( IsVacuum(db) ){ - return SQLITE_OK; - } - assert( zDbSName != NULL ); rcIdx = diskAnnDropIndex(db, zDbSName, zIdxName); @@ -786,10 +777,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { assert( zDbSName != NULL ); - if( IsVacuum(db) ){ - return SQLITE_OK; - } - return diskAnnClearIndex(db, zDbSName, zIdxName); } @@ -799,7 +786,7 @@ int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { * this made intentionally in order to natively support upload of SQLite dumps * * dump populates tables first and create indices after - * so we must omit them because shadow tables already filled + * so we must omit index refill setp because shadow tables already filled * * 1. in case of any error :-1 returned (and pParse errMsg is populated with some error message) * 2. 
if vector index must not be created : 0 returned @@ -817,10 +804,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int hasLibsqlVectorIdxFn = 0, hasCollation = 0; const char *pzErrMsg; - if( IsVacuum(pParse->db) ){ - return CREATE_IGNORE; - } - assert( zDbSName != NULL ); sqlite3 *db = pParse->db; @@ -970,7 +953,6 @@ int vectorIndexSearch( VectorIdxParams idxParams; vectorIdxParamsInit(&idxParams, NULL, 0); - assert( !IsVacuum(db) ); assert( zDbSName != NULL ); if( argc != 3 ){ @@ -1055,10 +1037,6 @@ int vectorIndexInsert( int rc; VectorInRow vectorInRow; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - rc = vectorInRowAlloc(pCur->db, pRecord, &vectorInRow, pzErrMsg); if( rc != SQLITE_OK ){ return rc; @@ -1078,10 +1056,6 @@ int vectorIndexDelete( ){ VectorInRow payload; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - payload.pVector = NULL; payload.nKeys = r->nField - 1; payload.pKeyValues = r->aMem + 1; diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 19d31ba19c..7308b2d93f 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -236,12 +236,17 @@ do_execsql_test vector-attach { do_execsql_test vector-vacuum { CREATE TABLE t_vacuum ( emb FLOAT32(2) ); - INSERT INTO t_vacuum VALUES (vector('[1,2]')), (vector('[3,4]')); + INSERT INTO t_vacuum VALUES (vector('[1,2]')), (vector('[3,4]')), (vector('[5,6]')); CREATE INDEX t_vacuum_idx ON t_vacuum(libsql_vector_idx(emb)); VACUUM; SELECT COUNT(*) FROM t_vacuum; SELECT COUNT(*) FROM t_vacuum_idx_shadow; -} {2 2} + DELETE FROM t_vacuum WHERE rowid = 2; + VACUUM; + SELECT * FROM vector_top_k('t_vacuum_idx', vector('[1,2]'), 3); + SELECT * FROM vector_top_k('t_vacuum_idx', vector('[5,6]'), 3); + SELECT * FROM vector_top_k('t_vacuum_idx', vector('[3,4]'), 3); +} {3 3 1 2 2 1 2 1} do_execsql_test vector-many-columns { CREATE TABLE t_many ( i INTEGER PRIMARY KEY, e1 FLOAT32(2), e2 FLOAT32(2) ); From 853143d77b94b26653a87d45134e2b77476096b3 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 13:48:21 +0400 Subject: [PATCH 11/70] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 55 ++++++++++--------- libsql-ffi/bundled/bindings/bindgen.rs | 26 ++++++++- libsql-ffi/bundled/src/sqlite3.c | 55 ++++++++++--------- 3 files changed, 79 insertions(+), 57 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index c22f35046f..ec692baa53 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -69,6 +69,7 @@ ** src/test2.c ** src/test3.c ** src/test8.c +** src/vacuum.c ** src/vdbe.c ** src/vdbeInt.h ** src/vdbeapi.c @@ -155950,6 +155951,10 @@ SQLITE_PRIVATE void sqlite3UpsertDoUpdate( /* #include "sqliteInt.h" */ /* #include "vdbeInt.h" */ +#ifndef SQLITE_OMIT_VECTOR +/* #include "vectorIndexInt.h" */ +#endif + #if !defined(SQLITE_OMIT_VACUUM) && !defined(SQLITE_OMIT_ATTACH) /* @@ -156227,6 +156232,27 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( if( rc!=SQLITE_OK ) goto end_of_vacuum; db->init.iDb = 0; +#ifndef SQLITE_OMIT_VECTOR + // shadow tables for vector index will be populated automatically during CREATE INDEX command + // so we must skip them at this step + if( sqlite3FindTable(db, VECTOR_INDEX_GLOBAL_META_TABLE, zDbMain) != NULL ){ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO 
vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name NOT IN (SELECT name||'_shadow' FROM " VECTOR_INDEX_GLOBAL_META_TABLE ")", + zDbMain + ); + }else{ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + zDbMain + ); + } +#else /* Loop through the tables in the main database. For each, do ** an "INSERT INTO vacuum_db.xxx SELECT * FROM main.xxx;" to copy ** the contents to the temporary database. @@ -156238,6 +156264,7 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); +#endif assert( (db->mDbFlags & DBFLAG_Vacuum)!=0 ); db->mDbFlags &= ~DBFLAG_Vacuum; if( rc!=SQLITE_OK ) goto end_of_vacuum; @@ -213656,11 +213683,6 @@ int vectorF64ParseSqliteBlob( ** VectorIdxParams utilities ****************************************************************************/ -// VACUUM creates tables and indices first and only then populate data -// we need to ignore inserts from 'INSERT INTO vacuum.t SELECT * FROM t' statements because -// all shadow tables will be populated by VACUUM process during regular process of table copy -#define IsVacuum(db) ((db->mDbFlags&DBFLAG_Vacuum)!=0) - void vectorIdxParamsInit(VectorIdxParams *pParams, u8 *pBinBuf, int nBinSize) { assert( nBinSize <= VECTOR_INDEX_PARAMS_BUF_SIZE ); @@ -214379,10 +214401,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { // this is done to prevent unrecoverable situations where index were dropped but index parameters deletion failed and second attempt will fail on first step int rcIdx, rcParams; - if( IsVacuum(db) ){ - return SQLITE_OK; - } - assert( zDbSName != NULL ); rcIdx = diskAnnDropIndex(db, zDbSName, zIdxName); @@ -214393,10 +214411,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { assert( zDbSName != NULL ); - if( IsVacuum(db) ){ - return SQLITE_OK; - } - return diskAnnClearIndex(db, zDbSName, zIdxName); } @@ -214406,7 +214420,7 @@ int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { * this made intentionally in order to natively support upload of SQLite dumps * * dump populates tables first and create indices after - * so we must omit them because shadow tables already filled + * so we must omit index refill setp because shadow tables already filled * * 1. in case of any error :-1 returned (and pParse errMsg is populated with some error message) * 2. 
if vector index must not be created : 0 returned @@ -214424,10 +214438,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int hasLibsqlVectorIdxFn = 0, hasCollation = 0; const char *pzErrMsg; - if( IsVacuum(pParse->db) ){ - return CREATE_IGNORE; - } - assert( zDbSName != NULL ); sqlite3 *db = pParse->db; @@ -214577,7 +214587,6 @@ int vectorIndexSearch( VectorIdxParams idxParams; vectorIdxParamsInit(&idxParams, NULL, 0); - assert( !IsVacuum(db) ); assert( zDbSName != NULL ); if( argc != 3 ){ @@ -214662,10 +214671,6 @@ int vectorIndexInsert( int rc; VectorInRow vectorInRow; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - rc = vectorInRowAlloc(pCur->db, pRecord, &vectorInRow, pzErrMsg); if( rc != SQLITE_OK ){ return rc; @@ -214685,10 +214690,6 @@ int vectorIndexDelete( ){ VectorInRow payload; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - payload.pVector = NULL; payload.nKeys = r->nField - 1; payload.pKeyValues = r->aMem + 1; diff --git a/libsql-ffi/bundled/bindings/bindgen.rs b/libsql-ffi/bundled/bindings/bindgen.rs index e11d453281..cc73807f33 100644 --- a/libsql-ffi/bundled/bindings/bindgen.rs +++ b/libsql-ffi/bundled/bindings/bindgen.rs @@ -24,10 +24,10 @@ extern "C" { } pub const __GNUC_VA_LIST: i32 = 1; -pub const SQLITE_VERSION: &[u8; 7] = b"3.44.0\0"; -pub const SQLITE_VERSION_NUMBER: i32 = 3044000; +pub const SQLITE_VERSION: &[u8; 7] = b"3.45.1\0"; +pub const SQLITE_VERSION_NUMBER: i32 = 3045001; pub const SQLITE_SOURCE_ID: &[u8; 85] = - b"2023-11-01 11:23:50 17129ba1ff7f0daf37100ee82d507aef7827cf38de1866e2633096ae6ad8alt1\0"; + b"2024-01-30 16:01:20 e876e51a0ed5c5b3126f52e532044363a014bc594cfefa87ffb5b82257ccalt1\0"; pub const LIBSQL_VERSION: &[u8; 6] = b"0.2.3\0"; pub const SQLITE_OK: i32 = 0; pub const SQLITE_ERROR: i32 = 1; @@ -356,6 +356,7 @@ pub const SQLITE_DETERMINISTIC: i32 = 2048; pub const SQLITE_DIRECTONLY: i32 = 524288; pub const SQLITE_SUBTYPE: i32 = 1048576; pub const SQLITE_INNOCUOUS: i32 = 2097152; +pub const SQLITE_RESULT_SUBTYPE: i32 = 16777216; pub const SQLITE_WIN32_DATA_DIRECTORY_TYPE: i32 = 1; pub const SQLITE_WIN32_TEMP_DIRECTORY_TYPE: i32 = 2; pub const SQLITE_TXN_NONE: i32 = 0; @@ -408,6 +409,7 @@ pub const SQLITE_TESTCTRL_PENDING_BYTE: i32 = 11; pub const SQLITE_TESTCTRL_ASSERT: i32 = 12; pub const SQLITE_TESTCTRL_ALWAYS: i32 = 13; pub const SQLITE_TESTCTRL_RESERVE: i32 = 14; +pub const SQLITE_TESTCTRL_JSON_SELFCHECK: i32 = 14; pub const SQLITE_TESTCTRL_OPTIMIZATIONS: i32 = 15; pub const SQLITE_TESTCTRL_ISKEYWORD: i32 = 16; pub const SQLITE_TESTCTRL_SCRATCHMALLOC: i32 = 17; @@ -3133,6 +3135,24 @@ pub struct Fts5ExtensionApi { piCol: *mut ::std::os::raw::c_int, ), >, + pub xQueryToken: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut Fts5Context, + iPhrase: ::std::os::raw::c_int, + iToken: ::std::os::raw::c_int, + ppToken: *mut *const ::std::os::raw::c_char, + pnToken: *mut ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int, + >, + pub xInstToken: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut Fts5Context, + iIdx: ::std::os::raw::c_int, + iToken: ::std::os::raw::c_int, + arg2: *mut *const ::std::os::raw::c_char, + arg3: *mut ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int, + >, } #[repr(C)] #[derive(Debug, Copy, Clone)] diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index c22f35046f..ec692baa53 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -69,6 +69,7 @@ ** src/test2.c ** src/test3.c ** src/test8.c +** 
src/vacuum.c ** src/vdbe.c ** src/vdbeInt.h ** src/vdbeapi.c @@ -155950,6 +155951,10 @@ SQLITE_PRIVATE void sqlite3UpsertDoUpdate( /* #include "sqliteInt.h" */ /* #include "vdbeInt.h" */ +#ifndef SQLITE_OMIT_VECTOR +/* #include "vectorIndexInt.h" */ +#endif + #if !defined(SQLITE_OMIT_VACUUM) && !defined(SQLITE_OMIT_ATTACH) /* @@ -156227,6 +156232,27 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( if( rc!=SQLITE_OK ) goto end_of_vacuum; db->init.iDb = 0; +#ifndef SQLITE_OMIT_VECTOR + // shadow tables for vector index will be populated automatically during CREATE INDEX command + // so we must skip them at this step + if( sqlite3FindTable(db, VECTOR_INDEX_GLOBAL_META_TABLE, zDbMain) != NULL ){ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name NOT IN (SELECT name||'_shadow' FROM " VECTOR_INDEX_GLOBAL_META_TABLE ")", + zDbMain + ); + }else{ + rc = execSqlF(db, pzErrMsg, + "SELECT'INSERT INTO vacuum_db.'||quote(name)" + "||' SELECT*FROM\"%w\".'||quote(name)" + "FROM vacuum_db.sqlite_schema " + "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + zDbMain + ); + } +#else /* Loop through the tables in the main database. For each, do ** an "INSERT INTO vacuum_db.xxx SELECT * FROM main.xxx;" to copy ** the contents to the temporary database. @@ -156238,6 +156264,7 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); +#endif assert( (db->mDbFlags & DBFLAG_Vacuum)!=0 ); db->mDbFlags &= ~DBFLAG_Vacuum; if( rc!=SQLITE_OK ) goto end_of_vacuum; @@ -213656,11 +213683,6 @@ int vectorF64ParseSqliteBlob( ** VectorIdxParams utilities ****************************************************************************/ -// VACUUM creates tables and indices first and only then populate data -// we need to ignore inserts from 'INSERT INTO vacuum.t SELECT * FROM t' statements because -// all shadow tables will be populated by VACUUM process during regular process of table copy -#define IsVacuum(db) ((db->mDbFlags&DBFLAG_Vacuum)!=0) - void vectorIdxParamsInit(VectorIdxParams *pParams, u8 *pBinBuf, int nBinSize) { assert( nBinSize <= VECTOR_INDEX_PARAMS_BUF_SIZE ); @@ -214379,10 +214401,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { // this is done to prevent unrecoverable situations where index were dropped but index parameters deletion failed and second attempt will fail on first step int rcIdx, rcParams; - if( IsVacuum(db) ){ - return SQLITE_OK; - } - assert( zDbSName != NULL ); rcIdx = diskAnnDropIndex(db, zDbSName, zIdxName); @@ -214393,10 +214411,6 @@ int vectorIndexDrop(sqlite3 *db, const char *zDbSName, const char *zIdxName) { int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { assert( zDbSName != NULL ); - if( IsVacuum(db) ){ - return SQLITE_OK; - } - return diskAnnClearIndex(db, zDbSName, zIdxName); } @@ -214406,7 +214420,7 @@ int vectorIndexClear(sqlite3 *db, const char *zDbSName, const char *zIdxName) { * this made intentionally in order to natively support upload of SQLite dumps * * dump populates tables first and create indices after - * so we must omit them because shadow tables already filled + * so we must omit index refill setp because shadow tables already filled * * 1. in case of any error :-1 returned (and pParse errMsg is populated with some error message) * 2. 
if vector index must not be created : 0 returned @@ -214424,10 +214438,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int hasLibsqlVectorIdxFn = 0, hasCollation = 0; const char *pzErrMsg; - if( IsVacuum(pParse->db) ){ - return CREATE_IGNORE; - } - assert( zDbSName != NULL ); sqlite3 *db = pParse->db; @@ -214577,7 +214587,6 @@ int vectorIndexSearch( VectorIdxParams idxParams; vectorIdxParamsInit(&idxParams, NULL, 0); - assert( !IsVacuum(db) ); assert( zDbSName != NULL ); if( argc != 3 ){ @@ -214662,10 +214671,6 @@ int vectorIndexInsert( int rc; VectorInRow vectorInRow; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - rc = vectorInRowAlloc(pCur->db, pRecord, &vectorInRow, pzErrMsg); if( rc != SQLITE_OK ){ return rc; @@ -214685,10 +214690,6 @@ int vectorIndexDelete( ){ VectorInRow payload; - if( IsVacuum(pCur->db) ){ - return SQLITE_OK; - } - payload.pVector = NULL; payload.nKeys = r->nField - 1; payload.pKeyValues = r->aMem + 1; From 2115277f8c6c8ab8494dc70fa1413d30b7f362f4 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 14:51:32 +0400 Subject: [PATCH 12/70] fix bug --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 2 +- libsql-ffi/bundled/src/sqlite3.c | 2 +- libsql-sqlite3/src/vacuum.c | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index ec692baa53..c25985af8a 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -156248,7 +156248,7 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( "SELECT'INSERT INTO vacuum_db.'||quote(name)" "||' SELECT*FROM\"%w\".'||quote(name)" "FROM vacuum_db.sqlite_schema " - "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); } diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index ec692baa53..c25985af8a 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -156248,7 +156248,7 @@ SQLITE_PRIVATE SQLITE_NOINLINE int sqlite3RunVacuum( "SELECT'INSERT INTO vacuum_db.'||quote(name)" "||' SELECT*FROM\"%w\".'||quote(name)" "FROM vacuum_db.sqlite_schema " - "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); } diff --git a/libsql-sqlite3/src/vacuum.c b/libsql-sqlite3/src/vacuum.c index d927a8d5a6..f8e848aca6 100644 --- a/libsql-sqlite3/src/vacuum.c +++ b/libsql-sqlite3/src/vacuum.c @@ -314,7 +314,7 @@ SQLITE_NOINLINE int sqlite3RunVacuum( "SELECT'INSERT INTO vacuum_db.'||quote(name)" "||' SELECT*FROM\"%w\".'||quote(name)" "FROM vacuum_db.sqlite_schema " - "WHERE type='table'AND coalesce(rootpage,1)>0 AND name", + "WHERE type='table'AND coalesce(rootpage,1)>0", zDbMain ); } From b12431c33a172ba0c96f35cfd35e36ce931b5c00 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 13:30:21 +0200 Subject: [PATCH 13/70] configure durable wal --- libsql-server/src/lib.rs | 98 ++++++++++++++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 4188365e03..9ee8e3b908 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -55,6 +55,7 @@ use url::Url; use utils::services::idle_shutdown::IdleShutdownKicker; use self::config::MetaStoreConfig; +use 
self::connection::connection_manager::InnerWalManager; use self::namespace::configurator::{ BaseNamespaceConfig, NamespaceConfigurators, PrimaryConfigurator, PrimaryExtraConfig, ReplicaConfigurator, SchemaConfigurator, @@ -336,7 +337,8 @@ where config.heartbeat_url.as_deref().unwrap_or(""), config.heartbeat_period, ); - join_set.spawn({ + + self.spawn_until_shutdown_on(join_set, { let heartbeat_auth = config.heartbeat_auth.clone(); let heartbeat_period = config.heartbeat_period; let heartbeat_url = if let Some(url) = &config.heartbeat_url { @@ -428,7 +430,7 @@ where let (scripted_backup, script_backup_task) = ScriptBackupManager::new(&self.path, CommandHandler::new(command.to_string())) .await?; - self.spawn_until_shutdown(&mut join_set, script_backup_task.run()); + self.spawn_until_shutdown_on(&mut join_set, script_backup_task.run()); Some(scripted_backup) } None => None, @@ -484,7 +486,6 @@ where ) .await?; - self.spawn_monitoring_tasks(&mut join_set, stats_receiver)?; // if namespaces are enabled, then bottomless must have set DB ID @@ -501,7 +502,7 @@ where let proxy_service = ProxyService::new(namespace_store.clone(), None, self.disable_namespaces); // Garbage collect proxy clients every 30 seconds - self.spawn_until_shutdown(&mut join_set, { + self.spawn_until_shutdown_on(&mut join_set, { let clients = proxy_service.clients(); async move { loop { @@ -511,14 +512,17 @@ where } }); - self.spawn_until_shutdown(&mut join_set, run_rpc_server( - proxy_service, - config.acceptor, - config.tls_config, - idle_shutdown_kicker.clone(), - namespace_store.clone(), - self.disable_namespaces, - )); + self.spawn_until_shutdown_on( + &mut join_set, + run_rpc_server( + proxy_service, + config.acceptor, + config.tls_config, + idle_shutdown_kicker.clone(), + namespace_store.clone(), + self.disable_namespaces, + ), + ); } let shutdown_timeout = self.shutdown_timeout.clone(); @@ -530,7 +534,7 @@ where // The migration scheduler is only useful on the primary let meta_conn = metastore_conn_maker()?; let scheduler = Scheduler::new(namespace_store.clone(), meta_conn).await?; - self.spawn_until_shutdown(&mut join_set, async move { + self.spawn_until_shutdown_on(&mut join_set, async move { scheduler.run(scheduler_receiver).await; Ok(()) }); @@ -560,7 +564,7 @@ where ); // Garbage collect proxy clients every 30 seconds - self.spawn_until_shutdown(&mut join_set, { + self.spawn_until_shutdown_on(&mut join_set, { let clients = proxy_svc.clients(); async move { loop { @@ -583,8 +587,7 @@ where DatabaseKind::Replica => { dbg!(); let (channel, uri) = client_config.clone().unwrap(); - let replication_svc = - ReplicationLogProxyService::new(channel.clone(), uri.clone()); + let replication_svc = ReplicationLogProxyService::new(channel.clone(), uri.clone()); let proxy_svc = ReplicaProxyService::new( channel, uri, @@ -651,7 +654,12 @@ where match self.use_custom_wal { Some(CustomWAL::LibsqlWal) => self.libsql_wal_configurators(), #[cfg(feature = "durable-wal")] - Some(CustomWAL::DurableWal) => self.durable_wal_configurators(), + Some(CustomWAL::DurableWal) => self.durable_wal_configurators( + base_config, + scripted_backup, + migration_scheduler_handle, + client_config, + ), None => { self.legacy_configurators( base_config, @@ -669,11 +677,44 @@ where } #[cfg(feature = "durable-wal")] - fn durable_wal_configurators(&self) -> anyhow::Result { - todo!(); + fn durable_wal_configurators( + &self, + base_config: BaseNamespaceConfig, + scripted_backup: Option, + migration_scheduler_handle: SchedulerHandle, + client_config: 
Option<(Channel, Uri)>, + ) -> anyhow::Result { + tracing::info!("using durable wal"); + let lock_manager = Arc::new(std::sync::Mutex::new(LockManager::new())); + let namespace_resolver = |path: &Path| { + NamespaceName::from_string( + path.parent() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap() + .to_string(), + ) + .unwrap() + .into() + }; + let wal = DurableWalManager::new( + lock_manager, + namespace_resolver, + self.storage_server_address.clone(), + ); + let make_wal_manager = Arc::new(move || EitherWAL::C(wal.clone())); + self.configurators_common( + client_config, + base_config, + make_wal_manager, + scripted_backup, + migration_scheduler_handle, + ) } - fn spawn_until_shutdown(&self, join_set: &mut JoinSet>, fut: F) + fn spawn_until_shutdown_on(&self, join_set: &mut JoinSet>, fut: F) where F: Future> + Send + 'static, { @@ -694,6 +735,23 @@ where client_config: Option<(Channel, Uri)>, ) -> anyhow::Result { let make_wal_manager = Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())); + self.configurators_common( + client_config, + base_config, + make_wal_manager, + scripted_backup, + migration_scheduler_handle, + ) + } + + fn configurators_common( + &self, + client_config: Option<(Channel, Uri)>, + base_config: BaseNamespaceConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + scripted_backup: Option, + migration_scheduler_handle: SchedulerHandle, + ) -> anyhow::Result { let mut configurators = NamespaceConfigurators::empty(); match client_config { From 066f1527572d3e3012457e588e9d28e8e74656ec Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 14:37:10 +0200 Subject: [PATCH 14/70] configure libsql_wal --- libsql-server/src/lib.rs | 213 ++++++++++++++++++++++++++++++--------- 1 file changed, 165 insertions(+), 48 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 9ee8e3b908..f5788dcebb 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -28,6 +28,7 @@ use auth::Auth; use config::{ AdminApiConfig, DbConfig, HeartbeatConfig, RpcClientConfig, RpcServerConfig, UserApiConfig, }; +use futures::future::{pending, ready}; use futures::Future; use http::user::UserApi; use hyper::client::HttpConnector; @@ -40,6 +41,10 @@ use libsql_sys::wal::either::Either as EitherWAL; #[cfg(feature = "durable-wal")] use libsql_sys::wal::either::Either3 as EitherWAL; use libsql_sys::wal::Sqlite3WalManager; +use libsql_wal::checkpointer::LibsqlCheckpointer; +use libsql_wal::registry::WalRegistry; +use libsql_wal::storage::NoStorage; +use libsql_wal::wal::LibsqlWalManager; use namespace::meta_store::MetaStoreHandle; use namespace::NamespaceName; use net::Connector; @@ -458,9 +463,10 @@ where let configurators = self .make_configurators( base_config, - scripted_backup, - scheduler_sender.into(), client_config.clone(), + &mut join_set, + scheduler_sender.into(), + scripted_backup, ) .await?; @@ -596,7 +602,6 @@ where self.disable_namespaces, ); - dbg!(); self.make_services( namespace_store.clone(), idle_shutdown_kicker, @@ -647,42 +652,125 @@ where async fn make_configurators( &self, base_config: BaseNamespaceConfig, - scripted_backup: Option, - migration_scheduler_handle: SchedulerHandle, client_config: Option<(Channel, Uri)>, + join_set: &mut JoinSet>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, ) -> anyhow::Result { + let wal_path = base_config.base_path.join("wals"); + let enable_libsql_wal_test = { + let is_primary = self.rpc_server_config.is_some(); + let is_libsql_wal_test = 
std::env::var("LIBSQL_WAL_TEST").is_ok(); + is_primary && is_libsql_wal_test + }; + let use_libsql_wal = + self.use_custom_wal == Some(CustomWAL::LibsqlWal) || enable_libsql_wal_test; + if !use_libsql_wal { + if wal_path.try_exists()? { + anyhow::bail!("database was previously setup to use libsql-wal"); + } + } + + if self.use_custom_wal.is_some() { + if self.db_config.bottomless_replication.is_some() { + anyhow::bail!("bottomless not supported with custom WAL"); + } + if self.rpc_client_config.is_some() { + anyhow::bail!("custom WAL not supported in replica mode"); + } + } + match self.use_custom_wal { - Some(CustomWAL::LibsqlWal) => self.libsql_wal_configurators(), + Some(CustomWAL::LibsqlWal) => self.libsql_wal_configurators( + base_config, + client_config, + join_set, + migration_scheduler_handle, + scripted_backup, + wal_path, + ), #[cfg(feature = "durable-wal")] Some(CustomWAL::DurableWal) => self.durable_wal_configurators( base_config, - scripted_backup, - migration_scheduler_handle, client_config, + migration_scheduler_handle, + scripted_backup, ), None => { self.legacy_configurators( base_config, - scripted_backup, - migration_scheduler_handle, client_config, + migration_scheduler_handle, + scripted_backup, ) .await } } } - fn libsql_wal_configurators(&self) -> anyhow::Result { - todo!() + fn libsql_wal_configurators( + &self, + base_config: BaseNamespaceConfig, + client_config: Option<(Channel, Uri)>, + join_set: &mut JoinSet>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, + wal_path: PathBuf, + ) -> anyhow::Result { + tracing::info!("using libsql wal"); + let (sender, receiver) = tokio::sync::mpsc::channel(64); + let registry = Arc::new(WalRegistry::new(wal_path, NoStorage, sender)?); + let checkpointer = LibsqlCheckpointer::new(registry.clone(), receiver, 8); + self.spawn_until_shutdown_on(join_set, async move { + checkpointer.run().await; + Ok(()) + }); + + let namespace_resolver = |path: &Path| { + NamespaceName::from_string( + path.parent() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap() + .to_string(), + ) + .unwrap() + .into() + }; + let wal = LibsqlWalManager::new(registry.clone(), Arc::new(namespace_resolver)); + + self.spawn_until_shutdown_with_teardown(join_set, pending(), async move { + registry.shutdown().await?; + Ok(()) + }); + + let make_wal_manager = Arc::new(move || EitherWAL::B(wal.clone())); + let mut configurators = NamespaceConfigurators::empty(); + + match client_config { + Some(_) => todo!("configure replica"), + // configure primary + None => self.configure_primary_common( + base_config, + &mut configurators, + make_wal_manager, + migration_scheduler_handle, + scripted_backup, + ), + } + + Ok(configurators) } #[cfg(feature = "durable-wal")] fn durable_wal_configurators( &self, base_config: BaseNamespaceConfig, - scripted_backup: Option, - migration_scheduler_handle: SchedulerHandle, client_config: Option<(Channel, Uri)>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, ) -> anyhow::Result { tracing::info!("using durable wal"); let lock_manager = Arc::new(std::sync::Mutex::new(LockManager::new())); @@ -706,22 +794,37 @@ where ); let make_wal_manager = Arc::new(move || EitherWAL::C(wal.clone())); self.configurators_common( - client_config, base_config, + client_config, make_wal_manager, - scripted_backup, migration_scheduler_handle, + scripted_backup, ) } fn spawn_until_shutdown_on(&self, join_set: &mut JoinSet>, fut: F) where F: Future> + Send + 'static, + { + 
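// `ready(Ok(()))` is an immediately-ready, no-op teardown: tasks spawned through
// this helper keep their previous shutdown behavior while sharing the
// teardown-aware code path below.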
self.spawn_until_shutdown_with_teardown(join_set, fut, ready(Ok(()))) + } + + /// run the passed future until shutdown is called, then call the passed teardown future + fn spawn_until_shutdown_with_teardown( + &self, + join_set: &mut JoinSet>, + fut: F, + teardown: T, + ) where + F: Future> + Send + 'static, + T: Future> + Send + 'static, { let shutdown = self.shutdown.clone(); join_set.spawn(async move { tokio::select! { - _ = shutdown.notified() => Ok(()), + _ = shutdown.notified() => { + teardown.await + }, ret = fut => ret } }); @@ -730,30 +833,29 @@ where async fn legacy_configurators( &self, base_config: BaseNamespaceConfig, - scripted_backup: Option, - migration_scheduler_handle: SchedulerHandle, client_config: Option<(Channel, Uri)>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, ) -> anyhow::Result { let make_wal_manager = Arc::new(|| EitherWAL::A(Sqlite3WalManager::default())); self.configurators_common( - client_config, base_config, + client_config, make_wal_manager, - scripted_backup, migration_scheduler_handle, + scripted_backup, ) } fn configurators_common( &self, - client_config: Option<(Channel, Uri)>, base_config: BaseNamespaceConfig, + client_config: Option<(Channel, Uri)>, make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, - scripted_backup: Option, migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, ) -> anyhow::Result { let mut configurators = NamespaceConfigurators::empty(); - match client_config { // replica mode Some((channel, uri)) => { @@ -762,34 +864,49 @@ where configurators.with_replica(replica_configurator); } // primary mode - None => { - let primary_config = PrimaryExtraConfig { - max_log_size: self.db_config.max_log_size, - max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), - bottomless_replication: self.db_config.bottomless_replication.clone(), - scripted_backup, - checkpoint_interval: self.db_config.checkpoint_interval, - }; + None => self.configure_primary_common( + base_config, + &mut configurators, + make_wal_manager, + migration_scheduler_handle, + scripted_backup, + ), + } - let primary_configurator = PrimaryConfigurator::new( - base_config.clone(), - primary_config.clone(), - make_wal_manager.clone(), - ); + Ok(configurators) + } - let schema_configurator = SchemaConfigurator::new( - base_config.clone(), - primary_config, - make_wal_manager.clone(), - migration_scheduler_handle, - ); + fn configure_primary_common( + &self, + base_config: BaseNamespaceConfig, + configurators: &mut NamespaceConfigurators, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + migration_scheduler_handle: SchedulerHandle, + scripted_backup: Option, + ) { + let primary_config = PrimaryExtraConfig { + max_log_size: self.db_config.max_log_size, + max_log_duration: self.db_config.max_log_duration.map(Duration::from_secs_f32), + bottomless_replication: self.db_config.bottomless_replication.clone(), + scripted_backup, + checkpoint_interval: self.db_config.checkpoint_interval, + }; - configurators.with_schema(schema_configurator); - configurators.with_primary(primary_configurator); - } - } + let primary_configurator = PrimaryConfigurator::new( + base_config.clone(), + primary_config.clone(), + make_wal_manager.clone(), + ); - Ok(configurators) + let schema_configurator = SchemaConfigurator::new( + base_config.clone(), + primary_config, + make_wal_manager.clone(), + migration_scheduler_handle, + ); + + configurators.with_schema(schema_configurator); + 
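// Schema and primary namespaces are registered side by side here: both are built
// from the same primary config and WAL manager, and only the schema configurator
// additionally takes the migration scheduler handle.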
configurators.with_primary(primary_configurator); } fn setup_shutdown(&self) -> Option { From b5dba7241531d3c5bb58a093189b72b9ea47fb25 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 16:52:23 +0200 Subject: [PATCH 15/70] partial implmentation of LibsqlWalReplicationConfigurator --- .../configurator/libsql_wal_replica.rs | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 libsql-server/src/namespace/configurator/libsql_wal_replica.rs diff --git a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs new file mode 100644 index 0000000000..6b2519cf33 --- /dev/null +++ b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs @@ -0,0 +1,138 @@ +use std::pin::Pin; +use std::future::Future; +use std::sync::Arc; + +use chrono::prelude::NaiveDateTime; +use hyper::Uri; +use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; +use libsql_wal::io::StdIO; +use libsql_wal::registry::WalRegistry; +use libsql_wal::storage::NoStorage; +use tokio::task::JoinSet; +use tonic::transport::Channel; + +use crate::connection::config::DatabaseConfig; +use crate::connection::connection_manager::InnerWalManager; +use crate::connection::write_proxy::MakeWriteProxyConn; +use crate::connection::MakeConnection; +use crate::database::{Database, ReplicaDatabase}; +use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::configurator::helpers::make_stats; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::{ + Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, + ResolveNamespacePathFn, RestoreOption, +}; +use crate::DEFAULT_AUTO_CHECKPOINT; + +use super::{BaseNamespaceConfig, ConfigureNamespace}; + +pub struct LibsqlWalReplicaConfigurator { + base: BaseNamespaceConfig, + registry: Arc>, + uri: Uri, + channel: Channel, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, +} + +impl ConfigureNamespace for LibsqlWalReplicaConfigurator { + fn setup<'a>( + &'a self, + db_config: MetaStoreHandle, + restore_option: RestoreOption, + name: &'a NamespaceName, + reset: ResetCb, + resolve_attach_path: ResolveNamespacePathFn, + store: NamespaceStore, + broadcaster: BroadcasterHandle, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + tracing::debug!("creating replica namespace"); + let db_path = self.base.base_path.join("dbs").join(name.as_str()); + let channel = self.channel.clone(); + let uri = self.uri.clone(); + + let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); + // TODO! setup replication + + let mut join_set = JoinSet::new(); + let namespace = name.clone(); + + let stats = make_stats( + &db_path, + &mut join_set, + db_config.clone(), + self.base.stats_sender.clone(), + name.clone(), + applied_frame_no_receiver.clone(), + ) + .await?; + + let connection_maker = MakeWriteProxyConn::new( + db_path.clone(), + self.base.extensions.clone(), + channel.clone(), + uri.clone(), + stats.clone(), + broadcaster, + db_config.clone(), + applied_frame_no_receiver, + self.base.max_response_size, + self.base.max_total_response_size, + primary_current_replication_index, + None, + resolve_attach_path, + self.make_wal_manager.clone(), + ) + .await? 
+ .throttled( + self.base.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + self.base.max_total_response_size, + self.base.max_concurrent_requests, + ); + + Ok(Namespace { + tasks: join_set, + db: Database::Replica(ReplicaDatabase { + connection_maker: Arc::new(connection_maker), + }), + name: name.clone(), + stats, + db_config_store: db_config, + path: db_path.into(), + }) + }) + } + + fn cleanup<'a>( + &'a self, + namespace: &'a NamespaceName, + _db_config: &DatabaseConfig, + _prune_all: bool, + _bottomless_db_id_init: NamespaceBottomlessDbIdInit, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let ns_path = self.base.base_path.join("dbs").join(namespace.as_str()); + if ns_path.try_exists()? { + tracing::debug!("removing database directory: {}", ns_path.display()); + tokio::fs::remove_dir_all(ns_path).await?; + } + Ok(()) + }) + } + + fn fork<'a>( + &'a self, + _from_ns: &'a Namespace, + _from_config: MetaStoreHandle, + _to_ns: NamespaceName, + _to_config: MetaStoreHandle, + _timestamp: Option, + _store: NamespaceStore, + ) -> Pin> + Send + 'a>> { + Box::pin(std::future::ready(Err(crate::Error::Fork( + super::fork::ForkError::ForkReplica, + )))) + } +} From ded5ba7f859f6f6b94f1a1b6614e31521a591f01 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 16:54:14 +0200 Subject: [PATCH 16/70] fmt + remove dbgs --- libsql-server/src/http/admin/stats.rs | 2 - libsql-server/src/lib.rs | 2 - .../src/namespace/configurator/fork.rs | 7 ++- .../src/namespace/configurator/helpers.rs | 60 +++++++++--------- .../configurator/libsql_wal_replica.rs | 18 +++--- .../src/namespace/configurator/mod.rs | 18 ++++-- .../src/namespace/configurator/primary.rs | 13 ++-- .../src/namespace/configurator/schema.rs | 29 ++++++--- libsql-server/src/namespace/mod.rs | 2 +- libsql-server/src/namespace/store.rs | 13 +--- libsql-server/src/schema/scheduler.rs | 63 ++++++------------- libsql-server/tests/cluster/mod.rs | 29 +++------ 12 files changed, 111 insertions(+), 145 deletions(-) diff --git a/libsql-server/src/http/admin/stats.rs b/libsql-server/src/http/admin/stats.rs index 5fce92ba0a..f2948d4d7b 100644 --- a/libsql-server/src/http/admin/stats.rs +++ b/libsql-server/src/http/admin/stats.rs @@ -140,12 +140,10 @@ pub(super) async fn handle_stats( State(app_state): State>>, Path(namespace): Path, ) -> crate::Result> { - dbg!(); let stats = app_state .namespaces .stats(NamespaceName::from_string(namespace)?) 
.await?; - dbg!(); let resp: StatsResponse = stats.as_ref().into(); Ok(Json(resp)) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index f5788dcebb..d26921dd00 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -591,7 +591,6 @@ where .configure(&mut join_set); } DatabaseKind::Replica => { - dbg!(); let (channel, uri) = client_config.clone().unwrap(); let replication_svc = ReplicationLogProxyService::new(channel.clone(), uri.clone()); let proxy_svc = ReplicaProxyService::new( @@ -611,7 +610,6 @@ where service_shutdown.clone(), ) .configure(&mut join_set); - dbg!(); } }; diff --git a/libsql-server/src/namespace/configurator/fork.rs b/libsql-server/src/namespace/configurator/fork.rs index 26a0b99b61..03f2ac03d8 100644 --- a/libsql-server/src/namespace/configurator/fork.rs +++ b/libsql-server/src/namespace/configurator/fork.rs @@ -58,7 +58,7 @@ pub(super) async fn fork( Database::Schema(db) => db.wal_wrapper.wrapper().logger(), _ => { return Err(crate::Error::Fork(ForkError::Internal(anyhow::Error::msg( - "Invalid source database type for fork", + "Invalid source database type for fork", )))); } }; @@ -114,7 +114,7 @@ pub struct ForkTask { pub to_namespace: NamespaceName, pub to_config: MetaStoreHandle, pub restore_to: Option, - pub store: NamespaceStore + pub store: NamespaceStore, } pub struct PointInTimeRestore { @@ -156,7 +156,8 @@ impl ForkTask { let dest_path = self.base_path.join("dbs").join(self.to_namespace.as_str()); tokio::fs::rename(temp_dir.path(), dest_path).await?; - self.store.make_namespace(&self.to_namespace, self.to_config, RestoreOption::Latest) + self.store + .make_namespace(&self.to_namespace, self.to_config, RestoreOption::Latest) .await .map_err(|e| ForkError::CreateNamespace(Box::new(e))) } diff --git a/libsql-server/src/namespace/configurator/helpers.rs b/libsql-server/src/namespace/configurator/helpers.rs index f43fa8a192..a5a4c5121d 100644 --- a/libsql-server/src/namespace/configurator/helpers.rs +++ b/libsql-server/src/namespace/configurator/helpers.rs @@ -6,26 +6,29 @@ use std::time::Duration; use anyhow::Context as _; use bottomless::replicator::Options; use bytes::Bytes; +use enclose::enclose; use futures::Stream; use libsql_sys::wal::Sqlite3WalManager; use tokio::io::AsyncBufReadExt as _; use tokio::sync::watch; use tokio::task::JoinSet; use tokio_util::io::StreamReader; -use enclose::enclose; use crate::connection::config::DatabaseConfig; use crate::connection::connection_manager::InnerWalManager; use crate::connection::libsql::{open_conn, MakeLibSqlConn}; use crate::connection::{Connection as _, MakeConnection as _}; +use crate::database::{PrimaryConnection, PrimaryConnectionMaker}; use crate::error::LoadDumpError; +use crate::namespace::broadcasters::BroadcasterHandle; +use crate::namespace::meta_store::MetaStoreHandle; +use crate::namespace::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; +use crate::namespace::{ + NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName, ResolveNamespacePathFn, + RestoreOption, +}; use crate::replication::{FrameNo, ReplicationLogger}; use crate::stats::Stats; -use crate::namespace::{NamespaceBottomlessDbId, NamespaceBottomlessDbIdInit, NamespaceName, ResolveNamespacePathFn, RestoreOption}; -use crate::namespace::replication_wal::{make_replication_wal_wrapper, ReplicationWalWrapper}; -use crate::namespace::meta_store::MetaStoreHandle; -use crate::namespace::broadcasters::BroadcasterHandle; -use crate::database::{PrimaryConnection, 
PrimaryConnectionMaker}; use crate::{StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT}; use super::{BaseNamespaceConfig, PrimaryExtraConfig}; @@ -74,8 +77,7 @@ pub(super) async fn make_primary_connection_maker( tracing::debug!("Checkpointed before initializing bottomless"); let options = make_bottomless_options(options, bottomless_db_id, name.clone()); let (replicator, did_recover) = - init_bottomless_replicator(db_path.join("data"), options, &restore_option) - .await?; + init_bottomless_replicator(db_path.join("data"), options, &restore_option).await?; tracing::debug!("Completed init of bottomless replicator"); is_dirty |= did_recover; Some(replicator) @@ -93,14 +95,14 @@ pub(super) async fn make_primary_connection_maker( }; let logger = Arc::new(ReplicationLogger::open( - &db_path, - primary_config.max_log_size, - primary_config.max_log_duration, - is_dirty, - auto_checkpoint, - primary_config.scripted_backup.clone(), - name.clone(), - None, + &db_path, + primary_config.max_log_size, + primary_config.max_log_duration, + is_dirty, + auto_checkpoint, + primary_config.scripted_backup.clone(), + name.clone(), + None, )?); tracing::debug!("sending stats"); @@ -113,7 +115,7 @@ pub(super) async fn make_primary_connection_maker( name.clone(), logger.new_frame_notifier.subscribe(), ) - .await?; + .await?; tracing::debug!("Making replication wal wrapper"); let wal_wrapper = make_replication_wal_wrapper(bottomless_replicator, logger.clone()); @@ -136,13 +138,13 @@ pub(super) async fn make_primary_connection_maker( resolve_attach_path, make_wal_manager.clone(), ) - .await? - .throttled( - base_config.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - base_config.max_total_response_size, - base_config.max_concurrent_requests, - ); + .await? + .throttled( + base_config.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + base_config.max_total_response_size, + base_config.max_concurrent_requests, + ); tracing::debug!("Completed opening libsql connection"); @@ -356,10 +358,7 @@ pub(super) async fn make_stats( } }); - join_set.spawn(run_storage_monitor( - db_path.into(), - Arc::downgrade(&stats), - )); + join_set.spawn(run_storage_monitor(db_path.into(), Arc::downgrade(&stats))); tracing::debug!("done sending stats, and creating bg tasks"); @@ -369,10 +368,7 @@ pub(super) async fn make_stats( // Periodically check the storage used by the database and save it in the Stats structure. // TODO: Once we have a separate fiber that does WAL checkpoints, running this routine // right after checkpointing is exactly where it should be done. 
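// Note: the monitor receives only a `Weak<Stats>` (see the `Arc::downgrade` at the
// spawn site above), so a long-lived monitor task cannot by itself keep a torn-down
// namespace's stats alive. A minimal wiring sketch, assuming an existing
// `Arc<Stats>` named `stats`:
//
//     join_set.spawn(run_storage_monitor(db_path.clone(), Arc::downgrade(&stats)));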
-async fn run_storage_monitor( - db_path: PathBuf, - stats: Weak, -) -> anyhow::Result<()> { +async fn run_storage_monitor(db_path: PathBuf, stats: Weak) -> anyhow::Result<()> { // on initialization, the database file doesn't exist yet, so we wait a bit for it to be // created tokio::time::sleep(Duration::from_secs(1)).await; diff --git a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs index 6b2519cf33..f26738ec2a 100644 --- a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs +++ b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs @@ -1,5 +1,5 @@ -use std::pin::Pin; use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use chrono::prelude::NaiveDateTime; @@ -66,7 +66,7 @@ impl ConfigureNamespace for LibsqlWalReplicaConfigurator { name.clone(), applied_frame_no_receiver.clone(), ) - .await?; + .await?; let connection_maker = MakeWriteProxyConn::new( db_path.clone(), @@ -84,13 +84,13 @@ impl ConfigureNamespace for LibsqlWalReplicaConfigurator { resolve_attach_path, self.make_wal_manager.clone(), ) - .await? - .throttled( - self.base.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - self.base.max_total_response_size, - self.base.max_concurrent_requests, - ); + .await? + .throttled( + self.base.max_concurrent_connections.clone(), + Some(DB_CREATE_TIMEOUT), + self.base.max_total_response_size, + self.base.max_concurrent_requests, + ); Ok(Namespace { tasks: join_set, diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index e5db335ff6..9122fc18de 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -13,13 +13,17 @@ use crate::StatsSender; use super::broadcasters::BroadcasterHandle; use super::meta_store::MetaStoreHandle; -use super::{Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption}; +use super::{ + Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, + ResolveNamespacePathFn, RestoreOption, +}; +pub mod fork; mod helpers; +mod libsql_wal_replica; mod primary; mod replica; mod schema; -pub mod fork; pub use primary::PrimaryConfigurator; pub use replica::ReplicaConfigurator; @@ -68,12 +72,18 @@ impl NamespaceConfigurators { } } - pub fn with_primary(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { + pub fn with_primary( + &mut self, + c: impl ConfigureNamespace + Send + Sync + 'static, + ) -> &mut Self { self.primary_configurator = Some(Box::new(c)); self } - pub fn with_replica(&mut self, c: impl ConfigureNamespace + Send + Sync + 'static) -> &mut Self { + pub fn with_replica( + &mut self, + c: impl ConfigureNamespace + Send + Sync + 'static, + ) -> &mut Self { self.replica_configurator = Some(Box::new(c)); self } diff --git a/libsql-server/src/namespace/configurator/primary.rs b/libsql-server/src/namespace/configurator/primary.rs index 4351f6a3ac..6c245a6e8f 100644 --- a/libsql-server/src/namespace/configurator/primary.rs +++ b/libsql-server/src/namespace/configurator/primary.rs @@ -12,8 +12,8 @@ use crate::namespace::broadcasters::BroadcasterHandle; use crate::namespace::configurator::helpers::make_primary_connection_maker; use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::{ - Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, - ResetCb, ResolveNamespacePathFn, RestoreOption, + 
Namespace, NamespaceBottomlessDbIdInit, NamespaceName, NamespaceStore, ResetCb, + ResolveNamespacePathFn, RestoreOption, }; use crate::run_periodic_checkpoint; use crate::schema::{has_pending_migration_task, setup_migration_table}; @@ -168,7 +168,8 @@ impl ConfigureNamespace for PrimaryConfigurator { db_config, prune_all, bottomless_db_id_init, - ).await + ) + .await }) } @@ -186,10 +187,10 @@ impl ConfigureNamespace for PrimaryConfigurator { from_config, to_ns, to_config, - timestamp, + timestamp, store, &self.primary_config, - self.base.base_path.clone())) + self.base.base_path.clone(), + )) } } - diff --git a/libsql-server/src/namespace/configurator/schema.rs b/libsql-server/src/namespace/configurator/schema.rs index e55c706fec..98e679513a 100644 --- a/libsql-server/src/namespace/configurator/schema.rs +++ b/libsql-server/src/namespace/configurator/schema.rs @@ -6,12 +6,11 @@ use tokio::task::JoinSet; use crate::connection::config::DatabaseConfig; use crate::connection::connection_manager::InnerWalManager; use crate::database::{Database, SchemaDatabase}; +use crate::namespace::broadcasters::BroadcasterHandle; use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::{ - Namespace, NamespaceName, NamespaceStore, - ResetCb, ResolveNamespacePathFn, RestoreOption, + Namespace, NamespaceName, NamespaceStore, ResetCb, ResolveNamespacePathFn, RestoreOption, }; -use crate::namespace::broadcasters::BroadcasterHandle; use crate::schema::SchedulerHandle; use super::helpers::{cleanup_primary, make_primary_connection_maker}; @@ -25,8 +24,18 @@ pub struct SchemaConfigurator { } impl SchemaConfigurator { - pub fn new(base: BaseNamespaceConfig, primary_config: PrimaryExtraConfig, make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, migration_scheduler: SchedulerHandle) -> Self { - Self { base, primary_config, make_wal_manager, migration_scheduler } + pub fn new( + base: BaseNamespaceConfig, + primary_config: PrimaryExtraConfig, + make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + migration_scheduler: SchedulerHandle, + ) -> Self { + Self { + base, + primary_config, + make_wal_manager, + migration_scheduler, + } } } @@ -58,7 +67,7 @@ impl ConfigureNamespace for SchemaConfigurator { &mut join_set, resolve_attach_path, broadcaster, - self.make_wal_manager.clone() + self.make_wal_manager.clone(), ) .await?; @@ -94,7 +103,8 @@ impl ConfigureNamespace for SchemaConfigurator { db_config, prune_all, bottomless_db_id_init, - ).await + ) + .await }) } @@ -112,9 +122,10 @@ impl ConfigureNamespace for SchemaConfigurator { from_config, to_ns, to_config, - timestamp, + timestamp, store, &self.primary_config, - self.base.base_path.clone())) + self.base.base_path.clone(), + )) } } diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 7cfa6b351c..2a2e3eb211 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -19,12 +19,12 @@ pub use self::name::NamespaceName; pub use self::store::NamespaceStore; pub mod broadcasters; +pub(crate) mod configurator; pub mod meta_store; mod name; pub mod replication_wal; mod schema_lock; mod store; -pub(crate) mod configurator; pub type ResetCb = Box; pub type ResolveNamespacePathFn = diff --git a/libsql-server/src/namespace/store.rs b/libsql-server/src/namespace/store.rs index b2b5d33032..a78e4f59b0 100644 --- a/libsql-server/src/namespace/store.rs +++ b/libsql-server/src/namespace/store.rs @@ -327,7 +327,6 @@ impl NamespaceStore { where Fun: FnOnce(&Namespace) -> R, { - 
dbg!(); if namespace != NamespaceName::default() && !self.inner.metadata.exists(&namespace) && !self.inner.allow_lazy_creation @@ -335,7 +334,6 @@ impl NamespaceStore { return Err(Error::NamespaceDoesntExist(namespace.to_string())); } - dbg!(); let f = { let name = namespace.clone(); move |ns: NamespaceEntry| async move { @@ -348,9 +346,7 @@ impl NamespaceStore { } }; - dbg!(); let handle = self.inner.metadata.handle(namespace.to_owned()); - dbg!(); f(self .load_namespace(&namespace, handle, RestoreOption::Latest) .await?) @@ -377,7 +373,6 @@ impl NamespaceStore { config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { - dbg!(); let ns = self .get_configurator(&config.get()) .setup( @@ -391,7 +386,6 @@ impl NamespaceStore { ) .await?; - dbg!(); Ok(ns) } @@ -401,17 +395,13 @@ impl NamespaceStore { db_config: MetaStoreHandle, restore_option: RestoreOption, ) -> crate::Result { - dbg!(); let init = async { - dbg!(); let ns = self .make_namespace(namespace, db_config, restore_option) .await?; - dbg!(); Ok(Some(ns)) }; - dbg!(); let before_load = Instant::now(); let ns = self .inner @@ -420,8 +410,7 @@ impl NamespaceStore { namespace.clone(), init.map_ok(|ns| Arc::new(RwLock::new(ns))), ) - .await.map_err(|e| dbg!(e))?; - dbg!(); + .await?; NAMESPACE_LOAD_LATENCY.record(before_load.elapsed()); Ok(ns) diff --git a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index a8195cbbd0..57916bb9a5 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -830,16 +830,10 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new( - false, - false, - 10, - meta_store, - config, - DatabaseKind::Primary - ) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -961,16 +955,10 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new( - false, - false, - 10, - meta_store, - config, - DatabaseKind::Primary - ) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -1044,16 +1032,10 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new( - false, - false, - 10, - meta_store, - config, - DatabaseKind::Primary, - ) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); store .with("ns".into(), |ns| { @@ -1078,9 +1060,10 @@ mod test { .unwrap(); let (sender, mut receiver) = mpsc::channel(100); let config = make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); let mut scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); @@ -1151,16 +1134,10 @@ mod test { .unwrap(); let (sender, _receiver) = mpsc::channel(100); let config 
= make_config(sender.clone().into(), tmp.path()); - let store = NamespaceStore::new( - false, - false, - 10, - meta_store, - config, - DatabaseKind::Primary - ) - .await - .unwrap(); + let store = + NamespaceStore::new(false, false, 10, meta_store, config, DatabaseKind::Primary) + .await + .unwrap(); let scheduler = Scheduler::new(store.clone(), maker().unwrap()) .await .unwrap(); diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index 8f214bd05e..1171d4a5d0 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ -149,29 +149,23 @@ fn sync_many_replica() { let mut sim = Builder::new() .simulation_duration(Duration::from_secs(1000)) .build(); - dbg!(); make_cluster(&mut sim, NUM_REPLICA, true); - dbg!(); sim.client("client", async { let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; - dbg!(); conn.execute("create table test (x)", ()).await?; - dbg!(); conn.execute("insert into test values (42)", ()).await?; - dbg!(); async fn get_frame_no(url: &str) -> Option { let client = Client::new(); - dbg!(); Some( - dbg!(client - .get(url) - .await - .unwrap() - .json::() - .await) + client + .get(url) + .await + .unwrap() + .json::() + .await .unwrap() .get("replication_index")? .as_u64() @@ -179,7 +173,6 @@ fn sync_many_replica() { ) } - dbg!(); let primary_fno = loop { if let Some(fno) = get_frame_no("http://primary:9090/v1/namespaces/default/stats").await { @@ -187,15 +180,13 @@ fn sync_many_replica() { } }; - dbg!(); // wait for all replicas to sync let mut join_set = JoinSet::new(); for i in 0..NUM_REPLICA { join_set.spawn(async move { let uri = format!("http://replica{i}:9090/v1/namespaces/default/stats"); - dbg!(); loop { - if let Some(replica_fno) = dbg!(get_frame_no(&uri).await) { + if let Some(replica_fno) = get_frame_no(&uri).await { if replica_fno == primary_fno { break; } @@ -205,10 +196,8 @@ fn sync_many_replica() { }); } - dbg!(); while join_set.join_next().await.is_some() {} - dbg!(); for i in 0..NUM_REPLICA { let db = Database::open_remote_with_connector( format!("http://replica{i}:8080"), @@ -223,10 +212,8 @@ fn sync_many_replica() { )); } - dbg!(); let client = Client::new(); - dbg!(); let stats = client .get("http://primary:9090/v1/namespaces/default/stats") .await? 
@@ -234,14 +221,12 @@ fn sync_many_replica() { .await .unwrap(); - dbg!(); let stat = stats .get("embedded_replica_frames_replicated") .unwrap() .as_u64() .unwrap(); - dbg!(); assert_eq!(stat, 0); Ok(()) From e5b8c31005069982de27ad0b58b109929f45bf4b Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 17:01:34 +0200 Subject: [PATCH 17/70] comment out libsql-wal replica configurator --- libsql-server/src/lib.rs | 34 +++--- .../configurator/libsql_wal_replica.rs | 115 +++++++++--------- .../src/namespace/configurator/mod.rs | 2 +- 3 files changed, 79 insertions(+), 72 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index d26921dd00..9bf0419932 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -745,21 +745,27 @@ where }); let make_wal_manager = Arc::new(move || EitherWAL::B(wal.clone())); - let mut configurators = NamespaceConfigurators::empty(); + // let mut configurators = NamespaceConfigurators::empty(); + + // match client_config { + // Some(_) => todo!("configure replica"), + // // configure primary + // None => self.configure_primary_common( + // base_config, + // &mut configurators, + // make_wal_manager, + // migration_scheduler_handle, + // scripted_backup, + // ), + // } - match client_config { - Some(_) => todo!("configure replica"), - // configure primary - None => self.configure_primary_common( - base_config, - &mut configurators, - make_wal_manager, - migration_scheduler_handle, - scripted_backup, - ), - } - - Ok(configurators) + self.configurators_common( + base_config, + client_config, + make_wal_manager, + migration_scheduler_handle, + scripted_backup, + ) } #[cfg(feature = "durable-wal")] diff --git a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs index f26738ec2a..6ab6cc52ef 100644 --- a/libsql-server/src/namespace/configurator/libsql_wal_replica.rs +++ b/libsql-server/src/namespace/configurator/libsql_wal_replica.rs @@ -46,63 +46,64 @@ impl ConfigureNamespace for LibsqlWalReplicaConfigurator { store: NamespaceStore, broadcaster: BroadcasterHandle, ) -> Pin> + Send + 'a>> { - Box::pin(async move { - tracing::debug!("creating replica namespace"); - let db_path = self.base.base_path.join("dbs").join(name.as_str()); - let channel = self.channel.clone(); - let uri = self.uri.clone(); - - let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); - // TODO! setup replication - - let mut join_set = JoinSet::new(); - let namespace = name.clone(); - - let stats = make_stats( - &db_path, - &mut join_set, - db_config.clone(), - self.base.stats_sender.clone(), - name.clone(), - applied_frame_no_receiver.clone(), - ) - .await?; - - let connection_maker = MakeWriteProxyConn::new( - db_path.clone(), - self.base.extensions.clone(), - channel.clone(), - uri.clone(), - stats.clone(), - broadcaster, - db_config.clone(), - applied_frame_no_receiver, - self.base.max_response_size, - self.base.max_total_response_size, - primary_current_replication_index, - None, - resolve_attach_path, - self.make_wal_manager.clone(), - ) - .await? 
- .throttled( - self.base.max_concurrent_connections.clone(), - Some(DB_CREATE_TIMEOUT), - self.base.max_total_response_size, - self.base.max_concurrent_requests, - ); - - Ok(Namespace { - tasks: join_set, - db: Database::Replica(ReplicaDatabase { - connection_maker: Arc::new(connection_maker), - }), - name: name.clone(), - stats, - db_config_store: db_config, - path: db_path.into(), - }) - }) + todo!() + // Box::pin(async move { + // tracing::debug!("creating replica namespace"); + // let db_path = self.base.base_path.join("dbs").join(name.as_str()); + // let channel = self.channel.clone(); + // let uri = self.uri.clone(); + // + // let rpc_client = ReplicationLogClient::with_origin(channel.clone(), uri.clone()); + // // TODO! setup replication + // + // let mut join_set = JoinSet::new(); + // let namespace = name.clone(); + // + // let stats = make_stats( + // &db_path, + // &mut join_set, + // db_config.clone(), + // self.base.stats_sender.clone(), + // name.clone(), + // applied_frame_no_receiver.clone(), + // ) + // .await?; + // + // let connection_maker = MakeWriteProxyConn::new( + // db_path.clone(), + // self.base.extensions.clone(), + // channel.clone(), + // uri.clone(), + // stats.clone(), + // broadcaster, + // db_config.clone(), + // applied_frame_no_receiver, + // self.base.max_response_size, + // self.base.max_total_response_size, + // primary_current_replication_index, + // None, + // resolve_attach_path, + // self.make_wal_manager.clone(), + // ) + // .await? + // .throttled( + // self.base.max_concurrent_connections.clone(), + // Some(DB_CREATE_TIMEOUT), + // self.base.max_total_response_size, + // self.base.max_concurrent_requests, + // ); + // + // Ok(Namespace { + // tasks: join_set, + // db: Database::Replica(ReplicaDatabase { + // connection_maker: Arc::new(connection_maker), + // }), + // name: name.clone(), + // stats, + // db_config_store: db_config, + // path: db_path.into(), + // }) + // }) } fn cleanup<'a>( diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index 9122fc18de..0f8dcbd481 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -20,7 +20,7 @@ use super::{ pub mod fork; mod helpers; -mod libsql_wal_replica; +// mod libsql_wal_replica; mod primary; mod replica; mod schema; From 6e7fb9f06a901fe20a340cf9e54b1523c44f8eb5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 6 Aug 2024 18:24:20 +0200 Subject: [PATCH 18/70] restore encryption config we don't actually care, but let's do it for completeness --- libsql-server/src/lib.rs | 1 + .../src/namespace/configurator/helpers.rs | 23 +++++++++++++++---- .../src/namespace/configurator/mod.rs | 2 ++ .../src/namespace/configurator/primary.rs | 8 ++++++- .../src/namespace/configurator/replica.rs | 1 + .../src/namespace/configurator/schema.rs | 1 + libsql-server/src/schema/scheduler.rs | 1 + 7 files changed, 31 insertions(+), 6 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 9bf0419932..4b97b442f5 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -458,6 +458,7 @@ where max_total_response_size: self.db_config.max_total_response_size, max_concurrent_connections: Arc::new(Semaphore::new(self.max_concurrent_connections)), max_concurrent_requests: self.db_config.max_concurrent_requests, + encryption_config: self.db_config.encryption_config.clone(), }; let configurators = self diff --git a/libsql-server/src/namespace/configurator/helpers.rs 
b/libsql-server/src/namespace/configurator/helpers.rs index a5a4c5121d..355b1b1472 100644 --- a/libsql-server/src/namespace/configurator/helpers.rs +++ b/libsql-server/src/namespace/configurator/helpers.rs @@ -9,6 +9,7 @@ use bytes::Bytes; use enclose::enclose; use futures::Stream; use libsql_sys::wal::Sqlite3WalManager; +use libsql_sys::EncryptionConfig; use tokio::io::AsyncBufReadExt as _; use tokio::sync::watch; use tokio::task::JoinSet; @@ -49,6 +50,7 @@ pub(super) async fn make_primary_connection_maker( resolve_attach_path: ResolveNamespacePathFn, broadcaster: BroadcasterHandle, make_wal_manager: Arc InnerWalManager + Sync + Send + 'static>, + encryption_config: Option, ) -> crate::Result<(PrimaryConnectionMaker, ReplicationWalWrapper, Arc)> { let db_config = meta_store_handle.get(); let bottomless_db_id = NamespaceBottomlessDbId::from_config(&db_config); @@ -102,7 +104,7 @@ pub(super) async fn make_primary_connection_maker( auto_checkpoint, primary_config.scripted_backup.clone(), name.clone(), - None, + encryption_config.clone(), )?); tracing::debug!("sending stats"); @@ -114,6 +116,7 @@ pub(super) async fn make_primary_connection_maker( base_config.stats_sender.clone(), name.clone(), logger.new_frame_notifier.subscribe(), + base_config.encryption_config.clone(), ) .await?; @@ -133,7 +136,7 @@ pub(super) async fn make_primary_connection_maker( base_config.max_total_response_size, auto_checkpoint, logger.new_frame_notifier.subscribe(), - None, + encryption_config, block_writes, resolve_attach_path, make_wal_manager.clone(), @@ -332,6 +335,7 @@ pub(super) async fn make_stats( stats_sender: StatsSender, name: NamespaceName, mut current_frame_no: watch::Receiver>, + encryption_config: Option, ) -> anyhow::Result> { tracing::debug!("creating stats type"); let stats = Stats::new(name.clone(), db_path, join_set).await?; @@ -358,7 +362,11 @@ pub(super) async fn make_stats( } }); - join_set.spawn(run_storage_monitor(db_path.into(), Arc::downgrade(&stats))); + join_set.spawn(run_storage_monitor( + db_path.into(), + Arc::downgrade(&stats), + encryption_config, + )); tracing::debug!("done sending stats, and creating bg tasks"); @@ -368,7 +376,11 @@ pub(super) async fn make_stats( // Periodically check the storage used by the database and save it in the Stats structure. // TODO: Once we have a separate fiber that does WAL checkpoints, running this routine // right after checkpointing is exactly where it should be done. -async fn run_storage_monitor(db_path: PathBuf, stats: Weak) -> anyhow::Result<()> { +async fn run_storage_monitor( + db_path: PathBuf, + stats: Weak, + encryption_config: Option, +) -> anyhow::Result<()> { // on initialization, the database file doesn't exist yet, so we wait a bit for it to be // created tokio::time::sleep(Duration::from_secs(1)).await; @@ -381,11 +393,12 @@ async fn run_storage_monitor(db_path: PathBuf, stats: Weak) -> anyhow::Re return Ok(()); }; + let encryption_config = encryption_config.clone(); let _ = tokio::task::spawn_blocking(move || { // because closing the last connection interferes with opening a new one, we lazily // initialize a connection here, and keep it alive for the entirety of the program. If we // fail to open it, we wait for `duration` and try again later. 
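// The read-only open below is deliberate: the monitor only reads pragmas such as
// `pragma page_count`, and `SQLITE_OPEN_READ_ONLY` guarantees this long-lived
// handle can never write to the live database.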
- match open_conn(&db_path, Sqlite3WalManager::new(), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), None) { + match open_conn(&db_path, Sqlite3WalManager::new(), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), encryption_config) { Ok(mut conn) => { if let Ok(tx) = conn.transaction() { let page_count = tx.query_row("pragma page_count;", [], |row| { row.get::(0) }); diff --git a/libsql-server/src/namespace/configurator/mod.rs b/libsql-server/src/namespace/configurator/mod.rs index 0f8dcbd481..b96d5a3824 100644 --- a/libsql-server/src/namespace/configurator/mod.rs +++ b/libsql-server/src/namespace/configurator/mod.rs @@ -5,6 +5,7 @@ use std::time::Duration; use chrono::NaiveDateTime; use futures::Future; +use libsql_sys::EncryptionConfig; use tokio::sync::Semaphore; use crate::connection::config::DatabaseConfig; @@ -38,6 +39,7 @@ pub struct BaseNamespaceConfig { pub(crate) max_total_response_size: u64, pub(crate) max_concurrent_connections: Arc, pub(crate) max_concurrent_requests: u64, + pub(crate) encryption_config: Option, } #[derive(Clone)] diff --git a/libsql-server/src/namespace/configurator/primary.rs b/libsql-server/src/namespace/configurator/primary.rs index 6c245a6e8f..03cdd2fd7b 100644 --- a/libsql-server/src/namespace/configurator/primary.rs +++ b/libsql-server/src/namespace/configurator/primary.rs @@ -1,7 +1,10 @@ +use std::path::Path; +use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; -use std::{path::Path, pin::Pin, sync::Arc}; +use std::sync::Arc; use futures::prelude::Future; +use libsql_sys::EncryptionConfig; use tokio::task::JoinSet; use crate::connection::config::DatabaseConfig; @@ -49,6 +52,7 @@ impl PrimaryConfigurator { resolve_attach_path: ResolveNamespacePathFn, db_path: Arc, broadcaster: BroadcasterHandle, + encryption_config: Option, ) -> crate::Result { let mut join_set = JoinSet::new(); @@ -67,6 +71,7 @@ impl PrimaryConfigurator { resolve_attach_path, broadcaster, self.make_wal_manager.clone(), + encryption_config, ) .await?; let connection_maker = Arc::new(connection_maker); @@ -135,6 +140,7 @@ impl ConfigureNamespace for PrimaryConfigurator { resolve_attach_path, db_path.clone(), broadcaster, + self.base.encryption_config.clone(), ) .await { diff --git a/libsql-server/src/namespace/configurator/replica.rs b/libsql-server/src/namespace/configurator/replica.rs index 61dd48b0bf..84ebadb897 100644 --- a/libsql-server/src/namespace/configurator/replica.rs +++ b/libsql-server/src/namespace/configurator/replica.rs @@ -169,6 +169,7 @@ impl ConfigureNamespace for ReplicaConfigurator { self.base.stats_sender.clone(), name.clone(), applied_frame_no_receiver.clone(), + self.base.encryption_config.clone(), ) .await?; diff --git a/libsql-server/src/namespace/configurator/schema.rs b/libsql-server/src/namespace/configurator/schema.rs index 98e679513a..f95c8abf51 100644 --- a/libsql-server/src/namespace/configurator/schema.rs +++ b/libsql-server/src/namespace/configurator/schema.rs @@ -68,6 +68,7 @@ impl ConfigureNamespace for SchemaConfigurator { resolve_attach_path, broadcaster, self.make_wal_manager.clone(), + self.base.encryption_config.clone(), ) .await?; diff --git a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index 57916bb9a5..01a3d795d8 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -917,6 +917,7 @@ mod test { max_total_response_size: 100000000000, max_concurrent_connections: Arc::new(Semaphore::new(10)), max_concurrent_requests: 10000, + encryption_config: None, }; 
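// The scheduler tests run against plain, unencrypted databases, so the new
// `encryption_config` field stays `None` here.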
 let primary_config = PrimaryExtraConfig {

From 71c50e198fe0e61880cd2d5917af351c34a9f21e Mon Sep 17 00:00:00 2001
From: Lucio Franco
Date: Tue, 6 Aug 2024 08:54:23 -0400
Subject: [PATCH 19/70] enable more windows CI

---
 .github/workflows/rust.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 26eaba46cf..1903c0baff 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -159,8 +159,8 @@ jobs:
         target/
       key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
       restore-keys: ${{ runner.os }}-cargo-
-    - name: check libsql remote
-      run: cargo check -p libsql --no-default-features -F remote
+    - name: build libsql all features
+      run: cargo build -p libsql --all-features

 # test-rust-wasm:
 #   runs-on: ubuntu-latest

From 07dc9b5b6d36e6eeae188c9a330df34bb99337ce Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Wed, 7 Aug 2024 13:04:09 +0200
Subject: [PATCH 20/70] add LibsqlWalFooter

---
 libsql-wal/src/lib.rs | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/libsql-wal/src/lib.rs b/libsql-wal/src/lib.rs
index df104eda49..d46ade0010 100644
--- a/libsql-wal/src/lib.rs
+++ b/libsql-wal/src/lib.rs
@@ -15,6 +15,22 @@ const LIBSQL_MAGIC: u64 = u64::from_be_bytes(*b"LIBSQL\0\0");
 const LIBSQL_PAGE_SIZE: u16 = 4096;
 const LIBSQL_WAL_VERSION: u16 = 1;
+use zerocopy::byteorder::big_endian::{U64 as bu64, U16 as bu16};
+/// LibsqlFooter is located at the end of the libsql file. It contains libsql specific metadata,
+/// while remaining fully compatible with sqlite (which just ignores that footer)
+///
+/// The fields are in big endian to remain coherent with sqlite
+#[derive(Copy, Clone, Debug, zerocopy::FromBytes, zerocopy::FromZeroes, zerocopy::AsBytes)]
+#[repr(C)]
+pub struct LibsqlFooter {
+    magic: bu64,
+    version: bu16,
+    /// Replication index checkpointed into this file.
+    /// only valid if there are no outstanding segments to checkpoint, since a checkpoint could be
+    /// partial.
+    replication_index: bu64,
+}
+
 #[cfg(any(debug_assertions, test))]
 pub mod test {
     use std::fs::OpenOptions;

From 4069036362f41cc9015dab4a5698f6790e57e1f1 Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Wed, 7 Aug 2024 13:48:40 +0200
Subject: [PATCH 21/70] cancel query when request is dropped

---
 libsql-server/src/connection/libsql.rs | 49 ++++++++++++++++++++++++--
 1 file changed, 47 insertions(+), 2 deletions(-)

diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs
index aadff6190b..1acbbf2588 100644
--- a/libsql-server/src/connection/libsql.rs
+++ b/libsql-server/src/connection/libsql.rs
@@ -391,14 +391,43 @@ where
         ctx: RequestContext,
         builder: B,
     ) -> Result<(B, Program)> {
+        struct Bomb {
+            canceled: Arc<AtomicBool>,
+            defused: bool,
+        }
+
+        impl Drop for Bomb {
+            fn drop(&mut self) {
+                if !self.defused {
+                    tracing::debug!("cancelling request");
+                    self.canceled.store(true, Ordering::Relaxed);
+                }
+            }
+        }
+
+        let canceled = {
+            let cancelled = self.inner.lock().canceled.clone();
+            cancelled.store(false, Ordering::Relaxed);
+            cancelled
+        };
+
+        let mut bomb = Bomb {
+            canceled,
+            defused: false,
+        };
+
         PROGRAM_EXEC_COUNT.increment(1);

         check_program_auth(&ctx, &pgm, &self.inner.lock().config_store.get())?;
         let conn = self.inner.clone();
-        BLOCKING_RT
+        let ret = BLOCKING_RT
             .spawn_blocking(move || Connection::run(conn, pgm, builder))
             .await
-            .unwrap()
+            .unwrap();
+
+        bomb.defused = true;
+
+        ret
     }
 }

@@ -413,6 +442,7 @@ pub(super) struct Connection {
     forced_rollback: bool,
     broadcaster: BroadcasterHandle,
     hooked: bool,
+    canceled: Arc<AtomicBool>,
 }

 fn update_stats(
@@ -475,6 +505,19 @@ impl Connection {
             );
         }

+        let canceled = Arc::new(AtomicBool::new(false));
+
+        conn.progress_handler(100, {
+            let canceled = canceled.clone();
+            Some(move || {
+                let canceled = canceled.load(Ordering::Relaxed);
+                if canceled {
+                    tracing::debug!("request canceled");
+                }
+                canceled
+            })
+        });
+
         let this = Self {
             conn,
             stats,
@@ -486,6 +529,7 @@
             forced_rollback: false,
             broadcaster,
             hooked: false,
+            canceled,
         };

         for ext in extensions.iter() {
@@ -795,6 +839,7 @@ mod test {
            forced_rollback: false,
            broadcaster: Default::default(),
            hooked: false,
+           canceled: Arc::new(false.into()),
        };
        let conn = Arc::new(Mutex::new(conn));

From 5924766712c783136b634ff325238a0fb1858cae Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Wed, 7 Aug 2024 13:15:12 +0200
Subject: [PATCH 22/70] write footer on checkpoint

---
 libsql-wal/src/lib.rs          |  8 ++++----
 libsql-wal/src/segment/list.rs | 18 +++++++++++++++++-
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/libsql-wal/src/lib.rs b/libsql-wal/src/lib.rs
index d46ade0010..1c0dc63566 100644
--- a/libsql-wal/src/lib.rs
+++ b/libsql-wal/src/lib.rs
@@ -15,7 +15,7 @@ const LIBSQL_MAGIC: u64 = u64::from_be_bytes(*b"LIBSQL\0\0");
 const LIBSQL_PAGE_SIZE: u16 = 4096;
 const LIBSQL_WAL_VERSION: u16 = 1;
-use zerocopy::byteorder::big_endian::{U64 as bu64, U16 as bu16};
+use zerocopy::byteorder::big_endian::{U16 as bu16, U64 as bu64};
 /// LibsqlFooter is located at the end of the libsql file. It contains libsql specific metadata,
 /// while remaining fully compatible with sqlite (which just ignores that footer)
 ///
 /// The fields are in big endian to remain coherent with sqlite
 #[derive(Copy, Clone, Debug, zerocopy::FromBytes, zerocopy::FromZeroes, zerocopy::AsBytes)]
 #[repr(C)]
 pub struct LibsqlFooter {
-    magic: bu64,
-    version: bu16,
+    pub magic: bu64,
+    pub version: bu16,
     /// Replication index checkpointed into this file.
/// only valid if there are no outstanding segments to checkpoint, since a checkpoint could be /// partial. - replication_index: bu64, + pub replication_index: bu64, } #[cfg(any(debug_assertions, test))] diff --git a/libsql-wal/src/segment/list.rs b/libsql-wal/src/segment/list.rs index 25dfa3a32a..f1e3252161 100644 --- a/libsql-wal/src/segment/list.rs +++ b/libsql-wal/src/segment/list.rs @@ -15,6 +15,7 @@ use crate::error::Result; use crate::io::buf::{ZeroCopyBoxIoBuf, ZeroCopyBuf}; use crate::io::FileExt; use crate::segment::Frame; +use crate::{LibsqlFooter, LIBSQL_MAGIC, LIBSQL_PAGE_SIZE, LIBSQL_WAL_VERSION}; use super::Segment; @@ -157,6 +158,21 @@ where buf = read_buf.into_inner(); } + // update the footer at the end of the db file. + let footer = LibsqlFooter { + magic: LIBSQL_MAGIC.into(), + version: LIBSQL_WAL_VERSION.into(), + replication_index: last_replication_index.into(), + }; + + let footer_offset = size_after as usize * LIBSQL_PAGE_SIZE as usize; + let (_, ret) = db_file + .write_all_at_async(ZeroCopyBuf::new_init(footer), footer_offset as u64) + .await; + ret?; + + // todo: truncate if necessary + //// todo: make async db_file.sync_all()?; @@ -185,7 +201,7 @@ where Ok(Some(last_replication_index)) } - /// returnsstream pages from the sealed segment list, and what's the lowest replication index + /// returns a stream of pages from the sealed segment list, and what's the lowest replication index /// that was covered. If the returned index is less than start frame_no, the missing frames /// must be read somewhere else. pub async fn stream_pages_from<'a>( From d11ec010df9a049a17eb201785abaa6d8219f9d5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 7 Aug 2024 15:38:56 +0200 Subject: [PATCH 23/70] downgrade debug to trace --- libsql-server/src/connection/libsql.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index 1acbbf2588..d98d6d0f82 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -399,7 +399,7 @@ where impl Drop for Bomb { fn drop(&mut self) { if !self.defused { - tracing::debug!("cancelling request"); + tracing::trace!("cancelling request"); self.canceled.store(true, Ordering::Relaxed); } } @@ -512,7 +512,7 @@ impl Connection { Some(move || { let canceled = canceled.load(Ordering::Relaxed); if canceled { - tracing::debug!("request canceled"); + tracing::trace!("request canceled"); } canceled }) From fc178de41fe56fd3ba845a9dc09bf71b6bf1b2d7 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 7 Aug 2024 15:42:01 +0200 Subject: [PATCH 24/70] add query canceled metric --- libsql-server/src/connection/libsql.rs | 5 ++++- libsql-server/src/metrics.rs | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index d98d6d0f82..aa0604f03a 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -15,7 +15,9 @@ use tokio::sync::watch; use tokio::time::{Duration, Instant}; use crate::error::Error; -use crate::metrics::{DESCRIBE_COUNT, PROGRAM_EXEC_COUNT, VACUUM_COUNT, WAL_CHECKPOINT_COUNT}; +use crate::metrics::{ + DESCRIBE_COUNT, PROGRAM_EXEC_COUNT, QUERY_CANCELED, VACUUM_COUNT, WAL_CHECKPOINT_COUNT, +}; use crate::namespace::broadcasters::BroadcasterHandle; use crate::namespace::meta_store::MetaStoreHandle; use crate::namespace::ResolveNamespacePathFn; @@ -512,6 +514,7 @@ impl Connection { Some(move 
|| {
                 let canceled = canceled.load(Ordering::Relaxed);
                 if canceled {
+                    QUERY_CANCELED.increment(1);
                     tracing::trace!("request canceled");
                 }
                 canceled
diff --git a/libsql-server/src/metrics.rs b/libsql-server/src/metrics.rs
index a71b5ca979..1ac97435b3 100644
--- a/libsql-server/src/metrics.rs
+++ b/libsql-server/src/metrics.rs
@@ -153,3 +153,8 @@ pub static LISTEN_EVENTS_DROPPED: Lazy<Counter> = Lazy::new(|| {
     describe_counter!(NAME, "Number of listen events dropped");
     register_counter!(NAME)
 });
+pub static QUERY_CANCELED: Lazy<Counter> = Lazy::new(|| {
+    const NAME: &str = "libsql_server_query_canceled";
+    describe_counter!(NAME, "Number of canceled queries");
+    register_counter!(NAME)
+});

From 351e6ebfbec97e0a2c12f2d056f2876cb0058e38 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Wed, 7 Aug 2024 18:52:57 +0400
Subject: [PATCH 25/70] add simple integration test

---
 libsql/tests/integration_tests.rs | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs
index 0f8e575949..cc239888d0 100644
--- a/libsql/tests/integration_tests.rs
+++ b/libsql/tests/integration_tests.rs
@@ -596,6 +596,22 @@ async fn debug_print_row() {
     );
 }

+#[tokio::test]
+async fn fts5_invalid_tokenizer() {
+    let db = Database::open(":memory:").unwrap();
+    let conn = db.connect().unwrap();
+    assert!(conn.execute(
+        "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram case_sensitive ')",
+        (),
+    )
+    .await.is_err());
+    assert!(conn.execute(
+        "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram remove_diacritics ')",
+        (),
+    )
+    .await.is_err());
+}
+
 #[cfg(feature = "serde")]
 #[tokio::test]
 async fn deserialize_row() {

From 3e56d28d8614a070dd632c6d54b0cdd1d2e08579 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Wed, 7 Aug 2024 16:57:49 +0400
Subject: [PATCH 26/70] fix potential crash in fts5 - see:
 https://sqlite.org/forum/forumpost/171bcc2bcd

---
 libsql-sqlite3/ext/fts5/fts5_tokenize.c | 60 ++++++++++++++-----------
 1 file changed, 33 insertions(+), 27 deletions(-)

diff --git a/libsql-sqlite3/ext/fts5/fts5_tokenize.c b/libsql-sqlite3/ext/fts5/fts5_tokenize.c
index f12056170f..7e239b6ca5 100644
--- a/libsql-sqlite3/ext/fts5/fts5_tokenize.c
+++ b/libsql-sqlite3/ext/fts5/fts5_tokenize.c
@@ -1290,40 +1290,46 @@ static int fts5TriCreate(
   Fts5Tokenizer **ppOut
 ){
   int rc = SQLITE_OK;
-  TrigramTokenizer *pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew));
-  UNUSED_PARAM(pUnused);
-  if( pNew==0 ){
-    rc = SQLITE_NOMEM;
+  TrigramTokenizer *pNew = 0;
+
+  if( nArg%2 ){
+    rc = SQLITE_ERROR;
   }else{
-    int i;
-    pNew->bFold = 1;
-    pNew->iFoldParam = 0;
-    for(i=0; rc==SQLITE_OK && i<nArg; i+=2){
-      const char *zArg = azArg[i+1];
-      if( 0==sqlite3_stricmp(azArg[i], "case_sensitive") ){
-        if( (zArg[0]!='0' && zArg[0]!='1') || zArg[1] ){
-          rc = SQLITE_ERROR;
+    pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew));
+    UNUSED_PARAM(pUnused);
+    if( pNew==0 ){
+      rc = SQLITE_NOMEM;
+    }else{
+      int i;
+      pNew->bFold = 1;
+      pNew->iFoldParam = 0;
+      for(i=0; rc==SQLITE_OK && i<nArg; i+=2){
+        const char *zArg = azArg[i+1];
+        if( 0==sqlite3_stricmp(azArg[i], "case_sensitive") ){
+          if( (zArg[0]!='0' && zArg[0]!='1') || zArg[1] ){
+            rc = SQLITE_ERROR;
+          }else{
+            pNew->bFold = (zArg[0]=='0');
+          }
+        }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){
+          if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){
+            rc = SQLITE_ERROR;
+          }else{
+            pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0;
+          }
         }else{
-          pNew->bFold = (zArg[0]=='0');
-        }
-      }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){
-        if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){
           rc = SQLITE_ERROR;
-        }else{
-          pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0;
         }
-      }else{
-        rc = SQLITE_ERROR;
       }
-    }

-    if( pNew->iFoldParam!=0 && pNew->bFold==0 ){
-      rc = SQLITE_ERROR;
-    }
+      if( pNew->iFoldParam!=0 && pNew->bFold==0 ){
+        rc = SQLITE_ERROR;
+      }

-    if( rc!=SQLITE_OK ){
-      fts5TriDelete((Fts5Tokenizer*)pNew);
-      pNew = 0;
+      if( rc!=SQLITE_OK ){
+        fts5TriDelete((Fts5Tokenizer*)pNew);
+        pNew = 0;
+      }
     }
   }
   *ppOut = (Fts5Tokenizer*)pNew;

From 7ed14683a177ad913d63f06d5d6848846c6a0d00 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Wed, 7 Aug 2024 18:53:13 +0400
Subject: [PATCH 27/70] build bundles

---
 .../SQLite3MultipleCiphers/src/sqlite3.c      | 61 +++++++++++--------
 libsql-ffi/bundled/bindings/bindgen.rs        | 16 +++--
 libsql-ffi/bundled/src/sqlite3.c              | 61 +++++++++++--------
 3 files changed, 80 insertions(+), 58 deletions(-)

diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c
index 529af0d52e..d7587cc38b 100644
--- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c
+++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c
@@ -28,6 +28,7 @@
 ** README.md
 ** configure
 ** configure.ac
+** ext/fts5/fts5_tokenize.c
 ** ext/jni/src/org/sqlite/jni/capi/CollationNeededCallback.java
 ** ext/jni/src/org/sqlite/jni/capi/CommitHookCallback.java
 ** ext/jni/src/org/sqlite/jni/capi/PreupdateHookCallback.java
@@ -259750,40 +259751,46 @@ static int fts5TriCreate(
   Fts5Tokenizer **ppOut
 ){
   int rc = SQLITE_OK;
-  TrigramTokenizer *pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew));
-  UNUSED_PARAM(pUnused);
-  if( pNew==0 ){
-    rc = SQLITE_NOMEM;
+  TrigramTokenizer *pNew = 0;
+
+  if( nArg%2 ){
+    rc = SQLITE_ERROR;
   }else{
-    int i;
-    pNew->bFold = 1;
-    pNew->iFoldParam = 0;
-    for(i=0; rc==SQLITE_OK && i<nArg; i+=2){
-      const char *zArg = azArg[i+1];
-      if( 0==sqlite3_stricmp(azArg[i], "case_sensitive") ){
-        if( (zArg[0]!='0' && zArg[0]!='1') || zArg[1] ){
-          rc = SQLITE_ERROR;
+    pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew));
+    UNUSED_PARAM(pUnused);
+    if( pNew==0 ){
+      rc = SQLITE_NOMEM;
+    }else{
+      int i;
+      pNew->bFold = 1;
+      pNew->iFoldParam = 0;
+      for(i=0; rc==SQLITE_OK && i<nArg; i+=2){
+        const char *zArg = azArg[i+1];
+        if( 0==sqlite3_stricmp(azArg[i], "case_sensitive") ){
+          if( (zArg[0]!='0' && zArg[0]!='1') || zArg[1] ){
+            rc = SQLITE_ERROR;
+          }else{
+            pNew->bFold = (zArg[0]=='0');
+          }
+        }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){
+          if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){
+            rc = SQLITE_ERROR;
+          }else{
+            pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0;
+          }
         }else{
-          pNew->bFold = (zArg[0]=='0');
-        }
-      }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){
-        if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){
           rc = SQLITE_ERROR;
-        }else{
-          pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0;
         }
-      }else{
-        rc = SQLITE_ERROR;
       }
-    }

-    if( pNew->iFoldParam!=0 && pNew->bFold==0 ){
-      rc = SQLITE_ERROR;
-    }
+      if( pNew->iFoldParam!=0 && pNew->bFold==0 ){
+        rc = SQLITE_ERROR;
+      }

-    if( rc!=SQLITE_OK ){
-      fts5TriDelete((Fts5Tokenizer*)pNew);
-      pNew = 0;
+      if( rc!=SQLITE_OK ){
+        fts5TriDelete((Fts5Tokenizer*)pNew);
+        pNew = 0;
+      }
     }
   }
   *ppOut = (Fts5Tokenizer*)pNew;
diff --git a/libsql-ffi/bundled/bindings/bindgen.rs b/libsql-ffi/bundled/bindings/bindgen.rs
index 9dec505c10..cc73807f33 100644
--- a/libsql-ffi/bundled/bindings/bindgen.rs
+++ b/libsql-ffi/bundled/bindings/bindgen.rs
@@ -940,7 +940,7 @@ extern "C" {
 extern "C" {
     pub fn sqlite3_vmprintf(
         arg1: *const ::std::os::raw::c_char,
-        arg2: va_list,
+        arg2: *mut __va_list_tag,
     ) -> *mut ::std::os::raw::c_char;
 }
 extern "C" {
@@ -956,7 +956,7 @@ extern "C" {
         arg1: ::std::os::raw::c_int,
         arg2: *mut ::std::os::raw::c_char,
         arg3: *const ::std::os::raw::c_char,
-        arg4: va_list,
+        arg4: *mut __va_list_tag,
     ) -> *mut ::std::os::raw::c_char;
 }
 extern "C" {
@@ -2503,7 +2503,7 @@ extern "C" {
     pub fn sqlite3_str_vappendf(
         arg1: *mut sqlite3_str,
         zFormat: *const ::std::os::raw::c_char,
-        arg2: va_list,
+        arg2: *mut __va_list_tag,
     );
 }
 extern "C" {
@@ -3524,4 +3524,12 @@ extern "C" {
 extern "C" {
     pub static sqlite3_wal_manager: libsql_wal_manager;
 }
-pub type __builtin_va_list = *mut ::std::os::raw::c_char;
+pub type __builtin_va_list = [__va_list_tag; 1usize];
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct __va_list_tag {
+    pub gp_offset: ::std::os::raw::c_uint,
+    pub fp_offset: ::std::os::raw::c_uint,
+    pub overflow_arg_area: *mut ::std::os::raw::c_void,
+    pub reg_save_area: *mut ::std::os::raw::c_void,
+}
diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c
index 529af0d52e..d7587cc38b 100644
--- a/libsql-ffi/bundled/src/sqlite3.c
+++ b/libsql-ffi/bundled/src/sqlite3.c
@@ -28,6 +28,7 @@
 ** README.md
 ** configure
 ** configure.ac
+** ext/fts5/fts5_tokenize.c
 ** ext/jni/src/org/sqlite/jni/capi/CollationNeededCallback.java
 ** ext/jni/src/org/sqlite/jni/capi/CommitHookCallback.java
 ** ext/jni/src/org/sqlite/jni/capi/PreupdateHookCallback.java
@@ -259750,40 +259751,46 @@ static int fts5TriCreate(
   Fts5Tokenizer **ppOut
 ){
   int rc = SQLITE_OK;
-  TrigramTokenizer *pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew));
-  UNUSED_PARAM(pUnused);
-  if( pNew==0 ){
-    rc = SQLITE_NOMEM;
+  TrigramTokenizer *pNew = 0;
+
+  if( nArg%2 ){
+    rc = SQLITE_ERROR;
   }else{
-    int i;
-    pNew->bFold = 1;
-    pNew->iFoldParam = 0;
-    for(i=0; rc==SQLITE_OK && i<nArg; i+=2){
-      const char *zArg = azArg[i+1];
-      if( 0==sqlite3_stricmp(azArg[i], "case_sensitive") ){
-        if( (zArg[0]!='0' && zArg[0]!='1') || zArg[1] ){
-          rc = SQLITE_ERROR;
+    pNew = (TrigramTokenizer*)sqlite3_malloc(sizeof(*pNew));
+    UNUSED_PARAM(pUnused);
+    if( pNew==0 ){
+      rc = SQLITE_NOMEM;
+    }else{
+      int i;
+      pNew->bFold = 1;
+      pNew->iFoldParam = 0;
+      for(i=0; rc==SQLITE_OK && i<nArg; i+=2){
+        const char *zArg = azArg[i+1];
+        if( 0==sqlite3_stricmp(azArg[i], "case_sensitive") ){
+          if( (zArg[0]!='0' && zArg[0]!='1') || zArg[1] ){
+            rc = SQLITE_ERROR;
+          }else{
+            pNew->bFold = (zArg[0]=='0');
+          }
+        }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){
+          if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){
+            rc = SQLITE_ERROR;
+          }else{
+            pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0;
+          }
         }else{
-          pNew->bFold = (zArg[0]=='0');
-        }
-      }else if( 0==sqlite3_stricmp(azArg[i], "remove_diacritics") ){
-        if( (zArg[0]!='0' && zArg[0]!='1' && zArg[0]!='2') || zArg[1] ){
           rc = SQLITE_ERROR;
-        }else{
-          pNew->iFoldParam = (zArg[0]!='0') ? 2 : 0;
         }
-      }else{
-        rc = SQLITE_ERROR;
       }
-    }

-    if( pNew->iFoldParam!=0 && pNew->bFold==0 ){
-      rc = SQLITE_ERROR;
-    }
+      if( pNew->iFoldParam!=0 && pNew->bFold==0 ){
+        rc = SQLITE_ERROR;
+      }

-    if( rc!=SQLITE_OK ){
-      fts5TriDelete((Fts5Tokenizer*)pNew);
-      pNew = 0;
+      if( rc!=SQLITE_OK ){
+        fts5TriDelete((Fts5Tokenizer*)pNew);
+        pNew = 0;
+      }
     }
   }
   *ppOut = (Fts5Tokenizer*)pNew;

From 9595315d30ee56d9845e884fd3fae768352967b9 Mon Sep 17 00:00:00 2001
From: ad hoc 
Date: Wed, 7 Aug 2024 20:42:27 +0200
Subject: [PATCH 28/70] init cancel bomb before query exec

---
 libsql-server/src/connection/libsql.rs | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs
index aa0604f03a..9896164e55 100644
--- a/libsql-server/src/connection/libsql.rs
+++ b/libsql-server/src/connection/libsql.rs
@@ -413,14 +413,15 @@ where
             cancelled
         };

+        PROGRAM_EXEC_COUNT.increment(1);
+
+        check_program_auth(&ctx, &pgm, &self.inner.lock().config_store.get())?;
+
+        // create the bomb right before spawning the blocking task.
         let mut bomb = Bomb {
             canceled,
             defused: false,
         };
-
-        PROGRAM_EXEC_COUNT.increment(1);
-
-        check_program_auth(&ctx, &pgm, &self.inner.lock().config_store.get())?;
         let conn = self.inner.clone();
         let ret = BLOCKING_RT
             .spawn_blocking(move || Connection::run(conn, pgm, builder))

From 0d411057ae5d42bffd1e451bfc2e5aa44ab72042 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Thu, 8 Aug 2024 00:33:27 +0400
Subject: [PATCH 29/70] cargo fmt

---
 libsql/tests/integration_tests.rs | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs
index cc239888d0..cdb0a985c3 100644
--- a/libsql/tests/integration_tests.rs
+++ b/libsql/tests/integration_tests.rs
@@ -600,16 +600,20 @@ async fn debug_print_row() {
 async fn fts5_invalid_tokenizer() {
     let db = Database::open(":memory:").unwrap();
     let conn = db.connect().unwrap();
-    assert!(conn.execute(
-        "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram case_sensitive ')",
-        (),
-    )
-    .await.is_err());
-    assert!(conn.execute(
-        "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram remove_diacritics ')",
-        (),
-    )
-    .await.is_err());
+    assert!(conn
+        .execute(
+            "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram case_sensitive ')",
+            (),
+        )
+        .await
+        .is_err());
+    assert!(conn
+        .execute(
+            "CREATE VIRTUAL TABLE t USING fts5(s, tokenize='trigram remove_diacritics ')",
+            (),
+        )
+        .await
+        .is_err());
 }

 #[cfg(feature = "serde")]

From 4085a0d35edd59c0c75bb917f4236dcdeeb0f574 Mon Sep 17 00:00:00 2001
From: Lucio Franco 
Date: Wed, 7 Aug 2024 17:40:39 -0400
Subject: [PATCH 30/70] libsql: downgrade failed prefetch log to debug

---
 libsql/src/replication/remote_client.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libsql/src/replication/remote_client.rs b/libsql/src/replication/remote_client.rs
index dbab056938..d0052f50d9 100644
--- a/libsql/src/replication/remote_client.rs
+++ b/libsql/src/replication/remote_client.rs
@@ -135,7 +135,7 @@ impl RemoteClient {
             (hello_fut.await, None)
         };
         self.prefetched_batch_log_entries = if let Ok(true) = hello.0 {
-            tracing::warn!(
+            tracing::debug!(
                 "Frames prefetching failed because of new session token returned by handshake"
             );
             None

From b0bc6eb2f5686b2e1706dd06b51405e8f0257ecd Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Thu, 8 Aug 2024 12:13:13 +0400
Subject: [PATCH 31/70] publish sqld debug builds to the
separate image name --- .github/workflows/publish-server.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-server.yml b/.github/workflows/publish-server.yml index 10820457b8..d957195e40 100644 --- a/.github/workflows/publish-server.yml +++ b/.github/workflows/publish-server.yml @@ -118,7 +118,7 @@ jobs: context: . platforms: ${{ env.platform }} labels: ${{ steps.meta.outputs.labels }} - outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true + outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}-debug,push-by-digest=true,name-canonical=true,push=true build-args: | BUILD_DEBUG=true - From 51b1b490545524e846e52479743054fbbe2e1660 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:23:09 +0400 Subject: [PATCH 32/70] remove digests artifacts from debug build step --- .github/workflows/publish-server.yml | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/.github/workflows/publish-server.yml b/.github/workflows/publish-server.yml index d957195e40..e1973fe47c 100644 --- a/.github/workflows/publish-server.yml +++ b/.github/workflows/publish-server.yml @@ -121,20 +121,6 @@ jobs: outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}-debug,push-by-digest=true,name-canonical=true,push=true build-args: | BUILD_DEBUG=true - - - name: Export digest - run: | - mkdir -p /tmp/digests - digest="${{ steps.build.outputs.digest }}" - touch "/tmp/digests/${digest#sha256:}" - - - name: Upload digest - uses: actions/upload-artifact@v4 - with: - name: digests-debug-${{ env.PLATFORM_PAIR }} - path: /tmp/digests/* - if-no-files-found: error - retention-days: 1 build-arm64: permissions: write-all From ec7bca5a20eb413ba658eef04e643d94f6fa562a Mon Sep 17 00:00:00 2001 From: wyhaya Date: Thu, 8 Aug 2024 12:24:36 +0800 Subject: [PATCH 33/70] Fix JSON f64 precision --- libsql/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index efae2abea3..3d65f71c73 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -20,7 +20,7 @@ hyper = { workspace = true, features = ["client", "stream"], optional = true } hyper-rustls = { version = "0.25", features = ["webpki-roots"], optional = true } base64 = { version = "0.21", optional = true } serde = { version = "1", features = ["derive"], optional = true } -serde_json = { version = "1", optional = true } +serde_json = { version = "1", features = ["float_roundtrip"], optional = true } async-trait = "0.1" bitflags = { version = "2.4.0", optional = true } tower = { workspace = true, features = ["util"], optional = true } From 5eeba4330901f022963bf50cf35e3e7120cf1a09 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:18:51 +0400 Subject: [PATCH 34/70] improve random row selection --- libsql-sqlite3/src/vectordiskann.c | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 95d473b630..8804aee119 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -442,6 +442,7 @@ int diskAnnCreateIndex( int type, dims; u64 maxNeighborsParam, blockSizeBytes; char *zSql; + const char *zRowidColumnName; char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. 
index_key, index_key1, index_key2, ...)

   if( vectorIdxKeyDefsRender(pKey, "index_key", columnSqlDefs, sizeof(columnSqlDefs)) != 0 ){
@@ -509,6 +510,7 @@ int diskAnnCreateIndex(
       columnSqlDefs,
       columnSqlNames
     );
+    zRowidColumnName = "index_key";
   }else{
     zSql = sqlite3MPrintf(
       db,
@@ -518,9 +520,31 @@ int diskAnnCreateIndex(
       columnSqlDefs,
       columnSqlNames
     );
+    zRowidColumnName = "rowid";
   }
   rc = sqlite3_exec(db, zSql, 0, 0, 0);
   sqlite3DbFree(db, zSql);
+  if( rc != SQLITE_OK ){
+    return rc;
+  }
+  /*
+   * vector blobs are usually pretty huge (more than a page size; for example, a node block for 1024d f32 embeddings with 1-bit compression will occupy ~20KB)
+   * in this case, the main table B-Tree takes on a redundant shape where all leaf nodes have only 1 cell
+   *
+   * as we have a query which selects a random row using the OFFSET/LIMIT trick - we would need to read all these leaf node pages just to skip them
+   * so, in order to remove this overhead for random row selection - we create an index with just a single column
+   * in this case B-Tree leaves will be full of rowids and the overhead for page reads will be very small
+   */
+  zSql = sqlite3MPrintf(
+    db,
+    "CREATE INDEX IF NOT EXISTS \"%w\".%s_shadow_idx ON %s_shadow (%s)",
+    zDbSName,
+    zIdxName,
+    zIdxName,
+    zRowidColumnName
+  );
+  rc = sqlite3_exec(db, zSql, 0, 0, 0);
+  sqlite3DbFree(db, zSql);
   return rc;
 }

From 4b3e7e7544e68a03ec04f38a686791bb76886fd3 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Thu, 8 Aug 2024 14:20:41 +0400
Subject: [PATCH 35/70] fix random row selection query to have db name

---
 libsql-sqlite3/src/vectordiskann.c | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c
index 8804aee119..fc39e00d30 100644
--- a/libsql-sqlite3/src/vectordiskann.c
+++ b/libsql-sqlite3/src/vectordiskann.c
@@ -574,8 +574,8 @@ static int diskAnnSelectRandomShadowRow(const DiskAnnIndex *pIndex, u64 *pRowid)

   zSql = sqlite3MPrintf(
     pIndex->db,
-    "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM %s), 1)",
-    pIndex->zDbSName, pIndex->zShadow, pIndex->zShadow
+    "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM \"%w\".%s), 1)",
+    pIndex->zDbSName, pIndex->zShadow, pIndex->zDbSName, pIndex->zShadow
   );
   if( zSql == NULL ){
     rc = SQLITE_NOMEM_BKPT;

From 0b41b5aa7c93639218caf6f765d6c4d26e0b0567 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Thu, 8 Aug 2024 14:33:31 +0400
Subject: [PATCH 36/70] build bundles

---
 .../SQLite3MultipleCiphers/src/sqlite3.c      | 28 +++++++++++++++++--
 libsql-ffi/bundled/src/sqlite3.c              | 28 +++++++++++++++++--
 2 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c
index d7587cc38b..15d09606fb 100644
--- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c
+++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c
@@ -212002,6 +212002,7 @@ int diskAnnCreateIndex(
   int type, dims;
   u64 maxNeighborsParam, blockSizeBytes;
   char *zSql;
+  const char *zRowidColumnName;
   char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...)
   char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)

   if( vectorIdxKeyDefsRender(pKey, "index_key", columnSqlDefs, sizeof(columnSqlDefs)) != 0 ){
@@ -212069,6 +212070,7 @@ int diskAnnCreateIndex(
       columnSqlDefs,
       columnSqlNames
     );
+    zRowidColumnName = "index_key";
   }else{
     zSql = sqlite3MPrintf(
       db,
@@ -212078,9 +212080,31 @@ int diskAnnCreateIndex(
       columnSqlDefs,
       columnSqlNames
     );
+    zRowidColumnName = "rowid";
   }
   rc = sqlite3_exec(db, zSql, 0, 0, 0);
   sqlite3DbFree(db, zSql);
+  if( rc != SQLITE_OK ){
+    return rc;
+  }
+  /*
+   * vector blobs are usually pretty huge (more than a page size; for example, a node block for 1024d f32 embeddings with 1-bit compression will occupy ~20KB)
+   * in this case, the main table B-Tree takes on a redundant shape where all leaf nodes have only 1 cell
+   *
+   * as we have a query which selects a random row using the OFFSET/LIMIT trick - we would need to read all these leaf node pages just to skip them
+   * so, in order to remove this overhead for random row selection - we create an index with just a single column
+   * in this case B-Tree leaves will be full of rowids and the overhead for page reads will be very small
+   */
+  zSql = sqlite3MPrintf(
+    db,
+    "CREATE INDEX IF NOT EXISTS \"%w\".%s_shadow_idx ON %s_shadow (%s)",
+    zDbSName,
+    zIdxName,
+    zIdxName,
+    zRowidColumnName
+  );
+  rc = sqlite3_exec(db, zSql, 0, 0, 0);
+  sqlite3DbFree(db, zSql);
   return rc;
 }
@@ -212110,8 +212134,8 @@ static int diskAnnSelectRandomShadowRow(const DiskAnnIndex *pIndex, u64 *pRowid)

   zSql = sqlite3MPrintf(
     pIndex->db,
-    "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM %s), 1)",
-    pIndex->zDbSName, pIndex->zShadow, pIndex->zShadow
+    "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM \"%w\".%s), 1)",
+    pIndex->zDbSName, pIndex->zShadow, pIndex->zDbSName, pIndex->zShadow
   );
   if( zSql == NULL ){
     rc = SQLITE_NOMEM_BKPT;
diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c
index d7587cc38b..15d09606fb 100644
--- a/libsql-ffi/bundled/src/sqlite3.c
+++ b/libsql-ffi/bundled/src/sqlite3.c
@@ -212002,6 +212002,7 @@ int diskAnnCreateIndex(
   int type, dims;
   u64 maxNeighborsParam, blockSizeBytes;
   char *zSql;
+  const char *zRowidColumnName;
   char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...)
   char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)

   if( vectorIdxKeyDefsRender(pKey, "index_key", columnSqlDefs, sizeof(columnSqlDefs)) != 0 ){
@@ -212069,6 +212070,7 @@ int diskAnnCreateIndex(
       columnSqlDefs,
       columnSqlNames
     );
+    zRowidColumnName = "index_key";
   }else{
     zSql = sqlite3MPrintf(
       db,
@@ -212078,9 +212080,31 @@ int diskAnnCreateIndex(
       columnSqlDefs,
       columnSqlNames
     );
+    zRowidColumnName = "rowid";
   }
   rc = sqlite3_exec(db, zSql, 0, 0, 0);
   sqlite3DbFree(db, zSql);
+  if( rc != SQLITE_OK ){
+    return rc;
+  }
+  /*
+   * vector blobs are usually pretty huge (more than a page size; for example, a node block for 1024d f32 embeddings with 1-bit compression will occupy ~20KB)
+   * in this case, the main table B-Tree takes on a redundant shape where all leaf nodes have only 1 cell
+   *
+   * as we have a query which selects a random row using the OFFSET/LIMIT trick - we would need to read all these leaf node pages just to skip them
+   * so, in order to remove this overhead for random row selection - we create an index with just a single column
+   * in this case B-Tree leaves will be full of rowids and the overhead for page reads will be very small
+   */
+  zSql = sqlite3MPrintf(
+    db,
+    "CREATE INDEX IF NOT EXISTS \"%w\".%s_shadow_idx ON %s_shadow (%s)",
+    zDbSName,
+    zIdxName,
+    zIdxName,
+    zRowidColumnName
+  );
+  rc = sqlite3_exec(db, zSql, 0, 0, 0);
+  sqlite3DbFree(db, zSql);
   return rc;
 }
@@ -212110,8 +212134,8 @@ static int diskAnnSelectRandomShadowRow(const DiskAnnIndex *pIndex, u64 *pRowid)

   zSql = sqlite3MPrintf(
     pIndex->db,
-    "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM %s), 1)",
-    pIndex->zDbSName, pIndex->zShadow, pIndex->zShadow
+    "SELECT rowid FROM \"%w\".%s LIMIT 1 OFFSET ABS(RANDOM()) %% MAX((SELECT COUNT(*) FROM \"%w\".%s), 1)",
+    pIndex->zDbSName, pIndex->zShadow, pIndex->zDbSName, pIndex->zShadow
   );
   if( zSql == NULL ){
     rc = SQLITE_NOMEM_BKPT;

From f80444ac450d19754bf11de98753be6f43ca8332 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Thu, 8 Aug 2024 14:50:29 +0400
Subject: [PATCH 37/70] fix test

---
 libsql-sqlite3/test/libsql_vector_index.test | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test
index 7308b2d93f..c1a270e4da 100644
--- a/libsql-sqlite3/test/libsql_vector_index.test
+++ b/libsql-sqlite3/test/libsql_vector_index.test
@@ -140,7 +140,7 @@ do_execsql_test vector-sql {
  INSERT INTO t_sql VALUES(vector('[1,2,3]')), (vector('[2,3,4]'));
  SELECT sql FROM sqlite_master WHERE name LIKE '%t_sql%';
  SELECT name FROM libsql_vector_meta_shadow WHERE name = 't_sql_idx';
-} {{CREATE TABLE t_sql( v FLOAT32(3))} {CREATE TABLE t_sql_idx_shadow (index_key INTEGER , data BLOB, PRIMARY KEY (index_key))} {CREATE INDEX t_sql_idx ON t_sql( libsql_vector_idx(v) )} {t_sql_idx}}
+} {{CREATE TABLE t_sql( v FLOAT32(3))} {CREATE TABLE t_sql_idx_shadow (index_key INTEGER , data BLOB, PRIMARY KEY (index_key))} {CREATE INDEX t_sql_idx_shadow_idx ON t_sql_idx_shadow (index_key)} {CREATE INDEX t_sql_idx ON t_sql( libsql_vector_idx(v) )} {t_sql_idx}}

 do_execsql_test vector-drop-index {
  CREATE TABLE t_index_drop( v FLOAT32(3));

From 83d029d7e922618c845f85db5e8a2d09d7f3b188 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Mon, 5 Aug 2024 17:13:19 +0400
Subject: [PATCH 38/70] cleanup code a bit in order to simplify working with
 vector of different types

---
 libsql-sqlite3/src/vector.c        | 69 ++++++++++++++++++++++--------
 libsql-sqlite3/src/vectorInt.h     | 16 ++-----
 libsql-sqlite3/src/vectordiskann.c | 16 +++++--
 libsql-sqlite3/src/vectorfloat32.c | 38 ++--------------
 libsql-sqlite3/src/vectorfloat64.c | 69 ++----------------------------
 5 files changed, 76 insertions(+), 132 deletions(-)

diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c
index d32819cd00..bef34140bf 100644
--- a/libsql-sqlite3/src/vector.c
+++ b/libsql-sqlite3/src/vector.c
@@ -252,11 +252,29 @@ int vectorParseSqliteBlob(
   Vector *pVector,
   char **pzErrMsg
 ){
+  const unsigned char *pBlob;
+  size_t nBlobSize;
+
+  assert( sqlite3_value_type(arg) == SQLITE_BLOB );
+
+  pBlob = sqlite3_value_blob(arg);
+  nBlobSize = sqlite3_value_bytes(arg);
+  if( nBlobSize % 2 == 1 ){
+    nBlobSize--;
+  }
+
+  if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){
+    *pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%llu", pVector->type, pVector->dims, nBlobSize);
+    return SQLITE_ERROR;
+  }
+
   switch (pVector->type) {
     case VECTOR_TYPE_FLOAT32:
-      return vectorF32ParseSqliteBlob(arg, pVector, pzErrMsg);
+      vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize);
+      return 0;
     case VECTOR_TYPE_FLOAT64:
-      return vectorF64ParseSqliteBlob(arg, pVector, pzErrMsg);
+      vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize);
+      return 0;
     default:
       assert(0);
   }
@@ -384,20 +402,47 @@ void vectorMarshalToText(
   }
 }

-void vectorSerialize(
+void vectorSerializeWithType(
   sqlite3_context *context,
   const Vector *pVector
 ){
+  unsigned char *pBlob;
+  size_t nBlobSize, nDataSize;
+
+  assert( pVector->dims <= MAX_VECTOR_SZ );
+
+  nDataSize = vectorDataSize(pVector->type, pVector->dims);
+  nBlobSize = nDataSize;
+  if( pVector->type != VECTOR_TYPE_FLOAT32 ){
+    nBlobSize += (nBlobSize % 2 == 0 ? 1 : 2);
+  }
+
+  if( nBlobSize == 0 ){
+    sqlite3_result_zeroblob(context, 0);
+    return;
+  }
+
+  pBlob = sqlite3_malloc64(nBlobSize);
+  if( pBlob == NULL ){
+    sqlite3_result_error_nomem(context);
+    return;
+  }
+
+  if( pVector->type != VECTOR_TYPE_FLOAT32 ){
+    pBlob[nBlobSize - 1] = pVector->type;
+  }
+
   switch (pVector->type) {
     case VECTOR_TYPE_FLOAT32:
-      vectorF32Serialize(context, pVector);
+      vectorF32SerializeToBlob(pVector, pBlob, nDataSize);
       break;
     case VECTOR_TYPE_FLOAT64:
-      vectorF64Serialize(context, pVector);
+      vectorF64SerializeToBlob(pVector, pBlob, nDataSize);
       break;
     default:
       assert(0);
   }
+  sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free);
 }

 size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){
@@ -412,18 +457,6 @@
   return 0;
 }

-size_t vectorDeserializeFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){
-  switch (pVector->type) {
-    case VECTOR_TYPE_FLOAT32:
-      return vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize);
-    case VECTOR_TYPE_FLOAT64:
-      return vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize);
-    default:
-      assert(0);
-  }
-  return 0;
-}
-
 void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){
   switch (pVector->type) {
     case VECTOR_TYPE_FLOAT32:
@@ -470,7 +503,7 @@ static void vectorFuncHintedType(
       sqlite3_free(pzErrMsg);
       goto out_free_vec;
     }
-    vectorSerialize(context, pVector);
+    vectorSerializeWithType(context, pVector);
 out_free_vec:
   vectorFree(pVector);
 }
diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h
index 8c9138b94f..64703b447f 100644
--- a/libsql-sqlite3/src/vectorInt.h
+++ b/libsql-sqlite3/src/vectorInt.h
@@ -65,13 +65,6 @@
 size_t vectorSerializeToBlob (const Vector *, unsigned char
*, size_t); size_t vectorF32SerializeToBlob(const Vector *, unsigned char *, size_t); size_t vectorF64SerializeToBlob(const Vector *, unsigned char *, size_t); -/* - * Deserializes vector from the blob in little-endian format according to the IEEE-754 standard -*/ -size_t vectorDeserializeFromBlob (Vector *, const unsigned char *, size_t); -size_t vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -size_t vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); - /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) */ @@ -91,16 +84,15 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerialize (sqlite3_context *, const Vector *); -void vectorF32Serialize(sqlite3_context *, const Vector *); -void vectorF64Serialize(sqlite3_context *, const Vector *); +void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); -int vectorF32ParseSqliteBlob(sqlite3_value *, Vector *, char **); -int vectorF64ParseSqliteBlob(sqlite3_value *, Vector *, char **); + +void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 95d473b630..54f3fba0b3 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -47,6 +47,7 @@ ** diskAnnInsert() Insert single new(!) 
vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ +#include "vectorInt.h" #ifndef SQLITE_OMIT_VECTOR #include "math.h" @@ -1490,6 +1491,7 @@ int diskAnnOpenIndex( ){ DiskAnnIndex *pIndex; u64 nBlockSize; + int compressNeighbours; pIndex = sqlite3DbMallocRaw(db, sizeof(DiskAnnIndex)); if( pIndex == NULL ){ return SQLITE_NOMEM; @@ -1536,9 +1538,17 @@ int diskAnnOpenIndex( pIndex->searchL = VECTOR_SEARCH_L_DEFAULT; } pIndex->nNodeVectorSize = vectorDataSize(pIndex->nNodeVectorType, pIndex->nVectorDims); - // will change in future when we will support compression of edges vectors - pIndex->nEdgeVectorType = pIndex->nNodeVectorType; - pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + + compressNeighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( compressNeighbours == 0 ){ + pIndex->nEdgeVectorType = pIndex->nNodeVectorType; + pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + pIndex->nEdgeVectorType = VECTOR_TYPE_1BIT; + pIndex->nEdgeVectorSize = vectorDataSize(VECTOR_TYPE_1BIT, pIndex->nVectorDims); + }else{ + return SQLITE_ERROR; + } *ppIndex = pIndex; return SQLITE_OK; diff --git a/libsql-sqlite3/src/vectorfloat32.c b/libsql-sqlite3/src/vectorfloat32.c index 8aeae2eb23..5d6641991c 100644 --- a/libsql-sqlite3/src/vectorfloat32.c +++ b/libsql-sqlite3/src/vectorfloat32.c @@ -94,26 +94,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -size_t vectorF32DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - float *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT32; - pVector->dims = nBlobSize / sizeof(float); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 0 || pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT32 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF32(pBlob); - pBlob += sizeof(float); - } - return vectorDataSize(pVector->type, pVector->dims); -} - void vectorF32Serialize( sqlite3_context *context, const Vector *pVector @@ -220,32 +200,22 @@ void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF32ParseSqliteBlob( - sqlite3_value *arg, +void vectorF32DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; float *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(float) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f32 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(float) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); pBlob += sizeof(float); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ diff --git a/libsql-sqlite3/src/vectorfloat64.c b/libsql-sqlite3/src/vectorfloat64.c index ced2be1843..1d29c9c3d6 100644 --- a/libsql-sqlite3/src/vectorfloat64.c +++ b/libsql-sqlite3/src/vectorfloat64.c @@ -98,57 +98,6 @@ size_t vectorF64SerializeToBlob( return sizeof(double) * pVector->dims; } -size_t vectorF64DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - double *elems = pVector->data; - unsigned i; - 
pVector->type = VECTOR_TYPE_FLOAT64; - pVector->dims = nBlobSize / sizeof(double); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 1 && pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT64 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF64(pBlob); - pBlob += sizeof(double); - } - return vectorDataSize(pVector->type, pVector->dims); -} - -void vectorF64Serialize( - sqlite3_context *context, - const Vector *pVector -){ - double *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT64 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - // allocate one extra trailing byte with vector blob type metadata - nBlobSize = vectorDataSize(pVector->type, pVector->dims) + 1; - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF64SerializeToBlob(pVector, pBlob, nBlobSize - 1); - pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_DOUBLE_CHAR_LIMIT 32 void vectorF64MarshalToText( sqlite3_context *context, @@ -227,32 +176,22 @@ void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF64ParseSqliteBlob( - sqlite3_value *arg, +void vectorF64DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; double *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(double) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f64 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(double) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); pBlob += sizeof(double); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ From 39e30ead5e755c25971c8d05731e5778bbfaece8 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 5 Aug 2024 18:14:58 +0400 Subject: [PATCH 39/70] add 1bit vector type --- libsql-sqlite3/Makefile.in | 6 +- libsql-sqlite3/src/vector.c | 9 +++ libsql-sqlite3/src/vector1bit.c | 110 ++++++++++++++++++++++++++++ libsql-sqlite3/src/vectorIndex.c | 25 ++++--- libsql-sqlite3/src/vectorIndexInt.h | 42 ++++++----- libsql-sqlite3/src/vectorInt.h | 18 +++-- libsql-sqlite3/src/vectordiskann.c | 21 ++++-- libsql-sqlite3/tool/mksqlite3c.tcl | 1 + 8 files changed, 190 insertions(+), 42 deletions(-) create mode 100644 libsql-sqlite3/src/vector1bit.c diff --git a/libsql-sqlite3/Makefile.in b/libsql-sqlite3/Makefile.in index 4520dda0d2..0afadd458f 100644 --- a/libsql-sqlite3/Makefile.in +++ b/libsql-sqlite3/Makefile.in @@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \ sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \ table.lo threads.lo tokenize.lo treeview.lo trigger.lo \ update.lo userauth.lo upsert.lo util.lo vacuum.lo \ - vector.lo vectorfloat32.lo vectorfloat64.lo \ + vector.lo vectorfloat32.lo vectorfloat64.lo vector1bit.lo \ vectorIndex.lo vectordiskann.lo vectorvtab.lo \ vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \ vdbetrace.lo vdbevtab.lo \ @@ -302,6 +302,7 @@ 
SRC = \ $(TOP)/src/util.c \ $(TOP)/src/vacuum.c \ $(TOP)/src/vector.c \ + $(TOP)/src/vector1bit.c \ $(TOP)/src/vectorInt.h \ $(TOP)/src/vectorfloat32.c \ $(TOP)/src/vectorfloat64.c \ @@ -1138,6 +1139,9 @@ vacuum.lo: $(TOP)/src/vacuum.c $(HDR) vector.lo: $(TOP)/src/vector.c $(HDR) $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector.c +vector1bit.lo: $(TOP)/src/vector1bit.c $(HDR) + $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector1bit.c + vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR) $(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index bef34140bf..9d37dbb4a8 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -41,6 +41,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); + case VECTOR_TYPE_1BIT: + return (dims + 7) / 8; default: assert(0); } @@ -111,6 +113,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceCos(pVector1, pVector2); + case VECTOR_TYPE_1BIT: + return vector1BitDistanceHamming(pVector1, pVector2); default: assert(0); } @@ -381,6 +385,9 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT64: vectorF64Dump(pVector); break; + case VECTOR_TYPE_1BIT: + vector1BitDump(pVector); + break; default: assert(0); } @@ -451,6 +458,8 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); + case VECTOR_TYPE_1BIT: + return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); } diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c new file mode 100644 index 0000000000..c8da5496d2 --- /dev/null +++ b/libsql-sqlite3/src/vector1bit.c @@ -0,0 +1,110 @@ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +** +****************************************************************************** +** +** 1-bit vector format utilities. 
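+**
+** A layout note (an editorial sketch inferred from the routines below, not wording
+** from the original patch): dimension i of a 1-bit vector is stored as bit (i & 7)
+** of byte i / 8, so eight dimensions pack into one byte; a set bit decodes as +1
+** and a cleared bit as -1. For example, a 4-dimensional vector quantized to
+** [+1, -1, -1, +1] packs into the single byte 0b00001001 (low bits first, unused
+** high bits zero).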
+*/
+#ifndef SQLITE_OMIT_VECTOR
+#include "sqliteInt.h"
+
+#include "vectorInt.h"
+
+#include <math.h>
+
+/**************************************************************************
+** Utility routines for debugging
+**************************************************************************/
+
+void vector1BitDump(const Vector *pVec){
+  u8 *elems = pVec->data;
+  unsigned i;
+
+  assert( pVec->type == VECTOR_TYPE_1BIT );
+
+  for(i = 0; i < pVec->dims; i++){
+    printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1);
+  }
+  printf("\n");
+}
+
+/**************************************************************************
+** Utility routines for vector serialization and deserialization
+**************************************************************************/
+
+size_t vector1BitSerializeToBlob(
+  const Vector *pVector,
+  unsigned char *pBlob,
+  size_t nBlobSize
+){
+  u8 *elems = pVector->data;
+  unsigned char *pPtr = pBlob;
+  unsigned i;
+
+  assert( pVector->type == VECTOR_TYPE_1BIT );
+  assert( pVector->dims <= MAX_VECTOR_SZ );
+  assert( nBlobSize >= (pVector->dims + 7) / 8 );
+
+  /* the vector data is already bit-packed, so copy it byte by byte into the blob */
+  for(i = 0; i < (pVector->dims + 7) / 8; i++){
+    pPtr[i] = elems[i];
+  }
+  return (pVector->dims + 7) / 8;
+}
+
+// [sum(map(int, bin(i)[2:])) for i in range(256)]
+static int BitsCount[256] = {
+  0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
+  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
+  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
+  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
+  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
+  4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8,
+};
+
+int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){
+  int sum = 0;
+  u8 *e1 = v1->data;
+  u8 *e2 = v2->data;
+  unsigned i;
+
+  assert( v1->dims == v2->dims );
+  assert( v1->type == VECTOR_TYPE_1BIT );
+  assert( v2->type == VECTOR_TYPE_1BIT );
+
+  /* hamming distance: popcount of the XOR, taken over the packed bytes */
+  for(i = 0; i < (v1->dims + 7) / 8; i++){
+    sum += BitsCount[e1[i] ^ e2[i]];
+  }
+  return sum;
+}
+
+#endif /* !defined(SQLITE_OMIT_VECTOR) */
diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c
index d8b3497781..7ad42a00ba 100644
--- a/libsql-sqlite3/src/vectorIndex.c
+++ b/libsql-sqlite3/src/vectorIndex.c
@@ -396,13 +396,14 @@ struct VectorParamName {
 };

 static struct VectorParamName VECTOR_PARAM_NAMES[] = {
-  { "type",          VECTOR_INDEX_TYPE_PARAM_ID,    0, "diskann", VECTOR_INDEX_TYPE_DISKANN },
-  { "metric",        VECTOR_METRIC_TYPE_PARAM_ID,   0, "cosine",  VECTOR_METRIC_TYPE_COS },
-  { "metric",        VECTOR_METRIC_TYPE_PARAM_ID,   0, "l2",      VECTOR_METRIC_TYPE_L2 },
-  { "alpha",         VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 },
-  { "search_l",      VECTOR_SEARCH_L_PARAM_ID,      1, 0, 0 },
-  { "insert_l",      VECTOR_INSERT_L_PARAM_ID,      1, 0, 0 },
-  { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 },
+  { "type",               VECTOR_INDEX_TYPE_PARAM_ID,         0, "diskann", VECTOR_INDEX_TYPE_DISKANN },
+  { "metric",             VECTOR_METRIC_TYPE_PARAM_ID,        0, "cosine",  VECTOR_METRIC_TYPE_COS },
+  { "metric",             VECTOR_METRIC_TYPE_PARAM_ID,        0, "l2",      VECTOR_METRIC_TYPE_L2 },
+  { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0,
"1bit", VECTOR_TYPE_1BIT }, + { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, + { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, + { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, + { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, }; static int parseVectorIdxParam(const char *zParam, VectorIdxParams *pParams, const char **pErrMsg) { @@ -802,7 +803,7 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int i, rc = SQLITE_OK; int dims, type; int hasLibsqlVectorIdxFn = 0, hasCollation = 0; - const char *pzErrMsg; + const char *pzErrMsg = NULL; assert( zDbSName != NULL ); @@ -914,9 +915,13 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: unsupported for tables without ROWID and composite primary key"); return CREATE_FAIL; } - rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams); + rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams, &pzErrMsg); if( rc != SQLITE_OK ){ - sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + if( pzErrMsg != NULL ){ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann: %s", pzErrMsg); + }else{ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + } return CREATE_FAIL; } rc = insertIndexParameters(db, zDbSName, pIdx->zName, &idxParams); diff --git a/libsql-sqlite3/src/vectorIndexInt.h b/libsql-sqlite3/src/vectorIndexInt.h index 8f73091bb1..bb6c085e94 100644 --- a/libsql-sqlite3/src/vectorIndexInt.h +++ b/libsql-sqlite3/src/vectorIndexInt.h @@ -100,43 +100,45 @@ typedef u8 MetricType; */ /* format version which can help to upgrade vector on-disk format without breaking older version of the db */ -#define VECTOR_FORMAT_PARAM_ID 1 +#define VECTOR_FORMAT_PARAM_ID 1 /* * 1 - initial version */ -#define VECTOR_FORMAT_DEFAULT 1 +#define VECTOR_FORMAT_DEFAULT 1 /* type of the vector index */ -#define VECTOR_INDEX_TYPE_PARAM_ID 2 -#define VECTOR_INDEX_TYPE_DISKANN 1 +#define VECTOR_INDEX_TYPE_PARAM_ID 2 +#define VECTOR_INDEX_TYPE_DISKANN 1 /* type of the underlying vector for the vector index */ -#define VECTOR_TYPE_PARAM_ID 3 +#define VECTOR_TYPE_PARAM_ID 3 /* dimension of the underlying vector for the vector index */ -#define VECTOR_DIM_PARAM_ID 4 +#define VECTOR_DIM_PARAM_ID 4 /* metric type used for comparing two vectors */ -#define VECTOR_METRIC_TYPE_PARAM_ID 5 -#define VECTOR_METRIC_TYPE_COS 1 -#define VECTOR_METRIC_TYPE_L2 2 +#define VECTOR_METRIC_TYPE_PARAM_ID 5 +#define VECTOR_METRIC_TYPE_COS 1 +#define VECTOR_METRIC_TYPE_L2 2 /* block size */ -#define VECTOR_BLOCK_SIZE_PARAM_ID 6 -#define VECTOR_BLOCK_SIZE_DEFAULT 128 +#define VECTOR_BLOCK_SIZE_PARAM_ID 6 +#define VECTOR_BLOCK_SIZE_DEFAULT 128 -#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 -#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 +#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 +#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 -#define VECTOR_INSERT_L_PARAM_ID 8 -#define VECTOR_INSERT_L_DEFAULT 70 +#define VECTOR_INSERT_L_PARAM_ID 8 +#define VECTOR_INSERT_L_DEFAULT 70 -#define VECTOR_SEARCH_L_PARAM_ID 9 -#define VECTOR_SEARCH_L_DEFAULT 200 +#define VECTOR_SEARCH_L_PARAM_ID 9 +#define VECTOR_SEARCH_L_DEFAULT 200 -#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 +#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 + +#define VECTOR_COMPRESS_NEIGHBORS_PARAM_ID 11 /* total amount of vector index parameters */ -#define VECTOR_PARAM_IDS_COUNT 9 +#define VECTOR_PARAM_IDS_COUNT 11 /* * Vector index parameters are stored in simple binary format 
(1 byte tag + 8 byte u64 integer / f64 float) @@ -218,7 +220,7 @@ int vectorOutRowsPut(VectorOutRows *, int, int, const u64 *, sqlite3_value *); void vectorOutRowsGet(sqlite3_context *, const VectorOutRows *, int, int); void vectorOutRowsFree(sqlite3 *, VectorOutRows *); -int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *); +int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *, const char **); int diskAnnClearIndex(sqlite3 *, const char *, const char *); int diskAnnDropIndex(sqlite3 *, const char *, const char *); int diskAnnOpenIndex(sqlite3 *, const char *, const char *, const VectorIdxParams *, DiskAnnIndex **); diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index 64703b447f..d6f9f36635 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -24,6 +24,7 @@ typedef u32 VectorDims; */ #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_1BIT 3 #define VECTOR_FLAGS_STATIC 1 @@ -48,8 +49,9 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); * Dumps vector on the console (used only for debugging) */ void vectorDump (const Vector *v); -void vectorF32Dump(const Vector *v); -void vectorF64Dump(const Vector *v); +void vectorF32Dump (const Vector *v); +void vectorF64Dump (const Vector *v); +void vector1BitDump(const Vector *v); /* * Converts vector to the text representation and write the result to the sqlite3_context @@ -61,9 +63,10 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF32SerializeToBlob(const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob(const Vector *, unsigned char *, size_t); +size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) @@ -72,6 +75,11 @@ float vectorDistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); +/* + * Calculates hamming distance between two 1-bit vectors (vector must have same dimensions) +*/ +int vector1BitDistanceHamming(const Vector *, const Vector *); + /* * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 54f3fba0b3..cce2f090ff 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -437,10 +437,11 @@ int diskAnnCreateIndex( const char *zDbSName, const char *zIdxName, const VectorIdxKey *pKey, - VectorIdxParams *pParams + VectorIdxParams *pParams, + const char **pzErrMsg ){ int rc; - int type, dims; + int type, dims, metric, neighbours; u64 maxNeighborsParam, blockSizeBytes; char *zSql; char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) 
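+  /* A usage sketch (an editorial illustration; the exact parameter-string syntax is
+  ** an assumption based on how the other vector-index parameters such as 'type' and
+  ** 'metric' are passed): edge compression would be requested per index at creation
+  ** time, e.g.
+  **
+  **   CREATE INDEX t_idx ON t( libsql_vector_idx(v, 'compress_neighbors=1bit') );
+  **
+  ** keeping full-precision node vectors while storing 1-bit quantized copies of the
+  ** neighbour vectors inside each node block. */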
@@ -477,11 +478,19 @@ int diskAnnCreateIndex(
   if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){
     return SQLITE_ERROR;
   }
-  if( vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID) == 0 ){
-    if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, VECTOR_METRIC_TYPE_COS) != 0 ){
+  metric = vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID);
+  if( metric == 0 ){
+    metric = VECTOR_METRIC_TYPE_COS;
+    if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, metric) != 0 ){
       return SQLITE_ERROR;
     }
   }
+  neighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID);
+  if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){
+    *pzErrMsg = "1-bit compression available only for cosine metric";
+    return SQLITE_ERROR;
+  }
+
   if( vectorIdxParamsGetF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID) == 0 ){
     if( vectorIdxParamsPutF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID, VECTOR_PRUNING_ALPHA_DEFAULT) != 0 ){
       return SQLITE_ERROR;
@@ -1544,8 +1553,8 @@ int diskAnnOpenIndex(
     pIndex->nEdgeVectorType = pIndex->nNodeVectorType;
     pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize;
   }else if( compressNeighbours == VECTOR_TYPE_1BIT ){
-    pIndex->nEdgeVectorType = VECTOR_TYPE_1BIT;
-    pIndex->nEdgeVectorSize = vectorDataSize(VECTOR_TYPE_1BIT, pIndex->nVectorDims);
+    pIndex->nEdgeVectorType = compressNeighbours;
+    pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims);
   }else{
     return SQLITE_ERROR;
   }
diff --git a/libsql-sqlite3/tool/mksqlite3c.tcl b/libsql-sqlite3/tool/mksqlite3c.tcl
index 31e6d84a57..3a04459e31 100644
--- a/libsql-sqlite3/tool/mksqlite3c.tcl
+++ b/libsql-sqlite3/tool/mksqlite3c.tcl
@@ -468,6 +468,7 @@ set flist {
   json.c

   vector.c
+  vector1bit.c
   vectordiskann.c
   vectorfloat32.c
   vectorfloat64.c

From 2e696fe5d58614559fcf976ab7e526bebc2c1065 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin 
Date: Tue, 6 Aug 2024 11:02:22 +0400
Subject: [PATCH 40/70] restructure search a bit in order to support
 compressed edges

---
 libsql-sqlite3/src/vectordiskann.c | 168 ++++++++++++++++++++---------
 1 file changed, 118 insertions(+), 50 deletions(-)

diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c
index cce2f090ff..9b302a4dda 100644
--- a/libsql-sqlite3/src/vectordiskann.c
+++ b/libsql-sqlite3/src/vectordiskann.c
@@ -98,14 +98,19 @@ struct DiskAnnNode {
 * so the caller which puts nodes in the context can forget about resource management (the context will take care of this)
 */
 struct DiskAnnSearchCtx {
-  const Vector *pQuery;         /* initial query vector; user query for SELECT and row vector for INSERT */
-  DiskAnnNode **aCandidates;    /* array of candidates ordered by distance to the query (ascending) */
-  double *aDistances;           /* array of distances to the query vector */
-  unsigned int nCandidates;     /* current size of aCandidates/aDistances arrays */
-  unsigned int maxCandidates;   /* max size of aCandidates/aDistances arrays */
-  DiskAnnNode *visitedList;     /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */
-  unsigned int nUnvisited;      /* number of unvisited candidates in the aCandidates array */
-  int blobMode;                 /* DISKANN_BLOB_READONLY if we won't modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */
+  const Vector *pNodeQuery;     /* initial query vector; user query for SELECT and row vector for INSERT */
+  const Vector *pEdgeQuery;     /* the query vector in the edge vector format; same as pNodeQuery unless edges are compressed */
+  DiskAnnNode **aCandidates;    /* array of candidates ordered by distance to the query (ascending) */
+  float *aDistances;            /* array of distances to the query vector */
+  unsigned int nCandidates;     /* current size of aCandidates/aDistances arrays */
+  unsigned int maxCandidates;   /* max size of aCandidates/aDistances arrays */
+  DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */
+  float *aTopDistances;         /* top candidates exact distances */
+  int nTopCandidates;           /* current size of aTopCandidates/aTopDistances arrays */
+  int maxTopCandidates;         /* max size of aTopCandidates/aTopDistances arrays */
+  DiskAnnNode *visitedList;     /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */
+  unsigned int nUnvisited;      /* number of unvisited candidates in the aCandidates array */
+  int blobMode;                 /* DISKANN_BLOB_READONLY if we won't modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */
 };

 /**************************************************************************
@@ -805,6 +810,53 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){
   return rc;
 }

+/**************************************************************************
+** Generic utilities
+**************************************************************************/
+
+int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){
+  int i;
+#ifdef SQLITE_DEBUG
+  for(i = 0; i < nSize - 1; i++){
+    assert(aDistances[i] <= aDistances[i + 1]);
+  }
+#endif
+  for(i = 0; i < nSize; i++){
+    if( distance < aDistances[i] ){
+      return i;
+    }
+  }
+  return nSize < nMaxSize ? nSize : -1;
+}
+
+void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) {
+  int itemsToMove;
+
+  assert( nMaxSize > 0 && nItemSize > 0 );
+  assert( nSize <= nMaxSize );
+  assert( 0 <= iInsert && iInsert <= nSize && iInsert < nMaxSize );
+
+  if( nSize == nMaxSize ){
+    if( pLast != NULL ){
+      memcpy(pLast, aBuffer + (nSize - 1) * nItemSize, nItemSize);
+    }
+    nSize--;
+  }
+  itemsToMove = nSize - iInsert;
+  memmove(aBuffer + (iInsert + 1) * nItemSize, aBuffer + iInsert * nItemSize, itemsToMove * nItemSize);
+  memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize);
+}
+
+void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) {
+  int itemsToMove;
+
+  assert( nItemSize > 0 );
+  assert( 0 <= iDelete && iDelete < nSize );
+
+  itemsToMove = nSize - iDelete - 1;
+  memmove(aBuffer + iDelete * nItemSize, aBuffer + (iDelete + 1) * nItemSize, itemsToMove * nItemSize);
+}
+
 /**************************************************************************
 ** DiskANN internals
 **************************************************************************/
@@ -841,16 +893,21 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){
   sqlite3_free(pNode);
 }

-static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, unsigned int maxCandidates, int blobMode){
-  pCtx->pQuery = pQuery;
+static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){
+  pCtx->pNodeQuery = pQuery;
+  pCtx->pEdgeQuery = pQuery;
   pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double));
   pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*));
   pCtx->nCandidates = 0;
   pCtx->maxCandidates = maxCandidates;
+  pCtx->aTopDistances = sqlite3_malloc(topCandidates * sizeof(double));
+  pCtx->aTopCandidates = sqlite3_malloc(topCandidates
* sizeof(DiskAnnNode*)); + pCtx->nTopCandidates = 0; + pCtx->maxTopCandidates = topCandidates; pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL ){ + if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ goto out_oom; } return SQLITE_OK; @@ -861,6 +918,12 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, un if( pCtx->aCandidates != NULL ){ sqlite3_free(pCtx->aCandidates); } + if( pCtx->aTopDistances != NULL ){ + sqlite3_free(pCtx->aTopDistances); + } + if( pCtx->aTopCandidates != NULL ){ + sqlite3_free(pCtx->aTopCandidates); + } return SQLITE_NOMEM_BKPT; } @@ -884,6 +947,8 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){ } sqlite3_free(pCtx->aCandidates); sqlite3_free(pCtx->aDistances); + sqlite3_free(pCtx->aTopCandidates); + sqlite3_free(pCtx->aTopDistances); } // check if we visited this node earlier @@ -925,7 +990,9 @@ static int diskAnnSearchCtxShouldAddCandidate(const DiskAnnIndex *pIndex, const } // mark node as visited and put it in the head of visitedList -static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode){ +static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode, float distance){ + int iInsert; + assert( pCtx->nUnvisited > 0 ); assert( pNode->visited == 0 ); @@ -934,56 +1001,51 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo pNode->pNext = pCtx->visitedList; pCtx->visitedList = pNode; + + iInsert = distanceBufferInsertIdx(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, distance); + if( iInsert < 0 ){ + return; + } + bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); + bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } static int diskAnnSearchCtxHasUnvisited(const DiskAnnSearchCtx *pCtx){ return pCtx->nUnvisited > 0; } -static DiskAnnNode* diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i){ +static void diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i, DiskAnnNode **ppNode, float *pDistance){ assert( 0 <= i && i < pCtx->nCandidates ); - return pCtx->aCandidates[i]; + *ppNode = pCtx->aCandidates[i]; + *pDistance = pCtx->aDistances[i]; } static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete){ int i; - assert( 0 <= iDelete && iDelete < pCtx->nCandidates ); assert( pCtx->nUnvisited > 0 ); assert( !pCtx->aCandidates[iDelete]->visited ); assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); + bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); - for(i = iDelete + 1; i < pCtx->nCandidates; i++){ - pCtx->aCandidates[i - 1] = pCtx->aCandidates[i]; - pCtx->aDistances[i - 1] = pCtx->aDistances[i]; - } pCtx->nCandidates--; pCtx->nUnvisited--; } -static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float candidateDist){ - int i; - assert( 0 <= iInsert && iInsert <= pCtx->nCandidates && iInsert < pCtx->maxCandidates ); - if( pCtx->nCandidates < pCtx->maxCandidates ){ - pCtx->nCandidates++; - } else { - 
DiskAnnNode *pLast = pCtx->aCandidates[pCtx->nCandidates - 1]; - if( !pLast->visited ){ - // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node - assert( pLast->pBlobSpot == NULL ); - pCtx->nUnvisited--; - diskAnnNodeFree(pLast); - } - } - // Shift the candidates to the right to make space for the new one. - for(i = pCtx->nCandidates - 1; i > iInsert; i--){ - pCtx->aCandidates[i] = pCtx->aCandidates[i - 1]; - pCtx->aDistances[i] = pCtx->aDistances[i - 1]; +static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ + DiskAnnNode *pLast = NULL; + bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); + bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); + if( pLast != NULL && !pLast->visited ){ + // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node + assert( pLast->pBlobSpot == NULL ); + pCtx->nUnvisited--; + diskAnnNodeFree(pLast); } - // Insert the new candidate. - pCtx->aCandidates[iInsert] = pCandidate; - pCtx->aDistances[iInsert] = candidateDist; pCtx->nUnvisited++; } @@ -1131,7 +1193,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u } nodeBinVector(pIndex, start->pBlobSpot, &startVector); - startDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &startVector); + startDistance = diskAnnVectorDistance(pIndex, pCtx->pNodeQuery, &startVector); if( pCtx->blobMode == DISKANN_BLOB_READONLY ){ assert( start->pBlobSpot != NULL ); @@ -1148,8 +1210,9 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u Vector vCandidate; DiskAnnNode *pCandidate; BlobSpot *pCandidateBlob; + float distance; int iCandidate = diskAnnSearchCtxFindClosestCandidateIdx(pCtx); - pCandidate = diskAnnSearchCtxGetCandidate(pCtx, iCandidate); + diskAnnSearchCtxGetCandidate(pCtx, iCandidate, &pCandidate, &distance); rc = SQLITE_OK; if( pReusableBlobSpot != NULL ){ @@ -1177,13 +1240,18 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u goto out; } - diskAnnSearchCtxMarkVisited(pCtx, pCandidate); - nVisited += 1; DiskAnnTrace(("visiting candidate(%d): id=%lld\n", nVisited, pCandidate->nRowid)); nodeBinVector(pIndex, pCandidateBlob, &vCandidate); nEdges = nodeBinEdges(pIndex, pCandidateBlob); + // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it + if( pCtx->pNodeQuery != pCtx->pEdgeQuery ){ + distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->pNodeQuery); + } + + diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance); + for(i = 0; i < nEdges; i++){ u64 edgeRowid; Vector edgeVector; @@ -1195,7 +1263,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u continue; } - edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &edgeVector); + edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pEdgeQuery, &edgeVector); iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance); if( iInsert < 0 ){ continue; @@ -1272,7 +1340,7 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search"); return rc; } - rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, DISKANN_BLOB_READONLY); + rc = 
diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context"); goto out; @@ -1281,7 +1349,7 @@ int diskAnnSearch( if( rc != SQLITE_OK ){ goto out; } - nOutRows = MIN(k, ctx.nCandidates); + nOutRows = MIN(k, ctx.nTopCandidates); rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, vectorIdxKeyRowidLike(pKey)); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to allocate output rows"); @@ -1289,9 +1357,9 @@ int diskAnnSearch( } for(i = 0; i < nOutRows; i++){ if( pRows->aIntValues != NULL ){ - rc = vectorOutRowsPut(pRows, i, 0, &ctx.aCandidates[i]->nRowid, NULL); + rc = vectorOutRowsPut(pRows, i, 0, &ctx.aTopCandidates[i]->nRowid, NULL); }else{ - rc = diskAnnGetShadowRowKeys(pIndex, ctx.aCandidates[i]->nRowid, pKey, pRows, i); + rc = diskAnnGetShadowRowKeys(pIndex, ctx.aTopCandidates[i]->nRowid, pKey, pRows, i); } if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to put result in the output row"); @@ -1327,7 +1395,7 @@ int diskAnnInsert( DiskAnnTrace(("diskAnnInset started\n")); - rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, DISKANN_BLOB_WRITABLE); + rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context"); return rc; From 1a9cab9163a53f66136c6ce5aab3fa3bb7c71ae9 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 18:27:39 +0400 Subject: [PATCH 41/70] 1bit quantized embeddings search: somehow working version --- libsql-sqlite3/src/vector.c | 32 +++++- libsql-sqlite3/src/vector1bit.c | 17 ++- libsql-sqlite3/src/vectorIndex.c | 8 +- libsql-sqlite3/src/vectorInt.h | 4 +- libsql-sqlite3/src/vectordiskann.c | 166 +++++++++++++++++++++-------- 5 files changed, 166 insertions(+), 61 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 9d37dbb4a8..6c3c8d83fe 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -74,10 +74,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){ ** Note that the vector object points to the blob so if ** you free the blob, the vector becomes invalid. 
**/ -void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){ - pVector->type = type; +void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){ pVector->flags = VECTOR_FLAGS_STATIC; - vectorInitFromBlob(pVector, pBlob, nBlobSize); + pVector->type = type; + pVector->dims = dims; + pVector->data = pBlob; } /* @@ -479,6 +480,31 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo } } +void vectorConvert(const Vector *pFrom, Vector *pTo){ + int i; + u8 *bitData; + float *floatData; + + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ + floatData = pFrom->data; + bitData = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + bitData[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( floatData[i] < 0 ){ + bitData[i / 8] &= ~(1 << (i & 7)); + }else{ + bitData[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert(0); + } +} + /************************************************************************** ** SQL function implementations ****************************************************************************/ diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index c8da5496d2..76fc964d7c 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -56,17 +56,16 @@ size_t vector1BitSerializeToBlob( unsigned char *pBlob, size_t nBlobSize ){ - float *elems = pVector->data; - unsigned char *pPtr = pBlob; - size_t len = 0; + u8 *elems = pVector->data; + u8 *pPtr = pBlob; unsigned i; assert( pVector->type == VECTOR_TYPE_1BIT ); assert( pVector->dims <= MAX_VECTOR_SZ ); assert( nBlobSize >= (pVector->dims + 7) / 8 ); - for(i = 0; i < pVector->dims; i++){ - elems[i] = pPtr[i]; + for(i = 0; i < (pVector->dims + 7) / 8; i++){ + pPtr[i] = elems[i]; } return (pVector->dims + 7) / 8; } @@ -92,7 +91,7 @@ static int BitsCount[256] = { }; int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ - int sum = 0; + int diff = 0; u8 *e1 = v1->data; u8 *e2 = v2->data; int i; @@ -101,10 +100,10 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ assert( v1->type == VECTOR_TYPE_1BIT ); assert( v2->type == VECTOR_TYPE_1BIT ); - for(i = 0; i < v1->dims; i++){ - sum += BitsCount[e1[i]&e2[i]]; + for(i = 0; i < v1->dims; i += 8){ + diff += BitsCount[e1[i/8] ^ e2[i/8]]; } - return sum; + return diff; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 7ad42a00ba..5a13b4ea60 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -396,10 +396,10 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "compress_neighbors", VECTOR_METRIC_TYPE_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, { "search_l", 
VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index d6f9f36635..84cf9c0d1f 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -102,11 +102,13 @@ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); -void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t); +void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); +void vectorConvert(const Vector *, Vector *); + /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 9b302a4dda..d9bb37ba35 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -80,9 +80,18 @@ #define VECTOR_NODE_METADATA_SIZE (sizeof(u64) + sizeof(u16)) #define VECTOR_EDGE_METADATA_SIZE (sizeof(u64) + sizeof(u64)) +typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +struct VectorPair { + int nodeType; + int edgeType; + Vector *pNode; + Vector *pEdge; +}; + // DiskAnnNode represents single node in the DiskAnn graph struct DiskAnnNode { u64 nRowid; /* node id */ @@ -98,8 +107,7 @@ struct DiskAnnNode { * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this) */ struct DiskAnnSearchCtx { - const Vector *pNodeQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ - const Vector *pEdgeQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ + VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */ DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */ float *aDistances; /* array of distances to the query vector */ unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ @@ -316,7 +324,7 @@ void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Ve void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector) { assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize <= pBlobSpot->nBufferSize ); - vectorInitStatic(pVector, pIndex->nNodeVectorType, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE, pIndex->nNodeVectorSize); + vectorInitStatic(pVector, pIndex->nNodeVectorType, pIndex->nVectorDims, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE); } u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { @@ -337,8 +345,8 @@ void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdg vectorInitStatic( pVector, pIndex->nEdgeVectorType, - pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nNodeVectorSize, - pIndex->nEdgeVectorSize + 
pIndex->nVectorDims, + pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize ); } } @@ -470,19 +478,6 @@ int diskAnnCreateIndex( } assert( 0 < dims && dims <= MAX_VECTOR_SZ ); - maxNeighborsParam = vectorIdxParamsGetU64(pParams, VECTOR_MAX_NEIGHBORS_PARAM_ID); - if( maxNeighborsParam == 0 ){ - // 3 D**(1/2) gives good recall values (90%+) - // we also want to keep disk overhead at moderate level - 50x of the disk size increase is the current upper bound - maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(type, dims)) + 1); - } - blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(type, dims)); - if( blockSizeBytes > DISKANN_MAX_BLOCK_SZ ){ - return SQLITE_ERROR; - } - if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){ - return SQLITE_ERROR; - } metric = vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID); if( metric == 0 ){ metric = VECTOR_METRIC_TYPE_COS; @@ -495,6 +490,23 @@ int diskAnnCreateIndex( *pzErrMsg = "1-bit compression available only for cosine metric"; return SQLITE_ERROR; } + if( neighbours == 0 ){ + neighbours = type; + } + + maxNeighborsParam = vectorIdxParamsGetU64(pParams, VECTOR_MAX_NEIGHBORS_PARAM_ID); + if( maxNeighborsParam == 0 ){ + // 3 D**(1/2) gives good recall values (90%+) + // we also want to keep disk overhead at moderate level - 50x of the disk size increase is the current upper bound + maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(neighbours, dims)) + 1); + } + blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(neighbours, dims)); + if( blockSizeBytes > DISKANN_MAX_BLOCK_SZ ){ + return SQLITE_ERROR; + } + if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){ + return SQLITE_ERROR; + } if( vectorIdxParamsGetF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID) == 0 ){ if( vectorIdxParamsPutF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID, VECTOR_PRUNING_ALPHA_DEFAULT) != 0 ){ @@ -814,6 +826,36 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){ ** Generic utilities **************************************************************************/ +int initVectorPair(int nodeType, int edgeType, int dims, VectorPair *pPair){ + pPair->nodeType = nodeType; + pPair->edgeType = edgeType; + pPair->pNode = NULL; + pPair->pEdge = NULL; + if( pPair->nodeType == pPair->edgeType ){ + return 0; + } + pPair->pEdge = vectorAlloc(edgeType, dims); + if( pPair->pEdge == NULL ){ + return SQLITE_NOMEM_BKPT; + } + return 0; +} + +void loadVectorPair(VectorPair *pPair, const Vector *pVector){ + pPair->pNode = (Vector*)pVector; + if( pPair->edgeType != pPair->nodeType ){ + vectorConvert(pPair->pNode, pPair->pEdge); + }else{ + pPair->pEdge = pPair->pNode; + } +} + +void deinitVectorPair(VectorPair *pPair) { + if( pPair->pEdge != NULL && pPair->pNode != pPair->pEdge ){ + vectorFree(pPair->pEdge); + } +} + int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){ int i; #ifdef SQLITE_DEBUG @@ -893,9 +935,7 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ sqlite3_free(pNode); } -static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int 
blobMode){ - pCtx->pNodeQuery = pQuery; - pCtx->pEdgeQuery = pQuery; +static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; @@ -907,6 +947,11 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, in pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + goto out_oom; + } + loadVectorPair(&pCtx->query, pQuery); + if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ goto out_oom; } @@ -949,6 +994,7 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){ sqlite3_free(pCtx->aDistances); sqlite3_free(pCtx->aTopCandidates); sqlite3_free(pCtx->aTopDistances); + deinitVectorPair(&pCtx->query); } // check if we visited this node earlier @@ -1075,7 +1121,13 @@ static int diskAnnSearchCtxFindClosestCandidateIdx(const DiskAnnSearchCtx *pCtx) // return position for new edge(C) which will replace previous edge on that position or -1 if we should ignore it // we also check that no current edge(B) will "prune" new vertex: i.e. dist(B, C) >= (means worse than) alpha * dist(node, C) for all current edges // if any edge(B) will "prune" new edge(C) we will ignore it (return -1) -static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, u64 newRowid, const Vector *pNewVector) { +static int diskAnnReplaceEdgeIdx( + const DiskAnnIndex *pIndex, + BlobSpot *pNodeBlob, + u64 newRowid, + VectorPair *pNewVector, + VectorPair *pPlaceholder +) { int i, nEdges, nMaxEdges, iReplace = -1; Vector nodeVector, edgeVector; float nodeToNew, nodeToReplace; @@ -1083,7 +1135,10 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob nEdges = nodeBinEdges(pIndex, pNodeBlob); nMaxEdges = nodeEdgesMaxCount(pIndex); nodeBinVector(pIndex, pNodeBlob, &nodeVector); - nodeToNew = diskAnnVectorDistance(pIndex, &nodeVector, pNewVector); + loadVectorPair(pPlaceholder, &nodeVector); + + // we need to evaluate potentially approximate distance here in order to correctly compare it with edge distances + nodeToNew = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, pNewVector->pEdge); for(i = nEdges - 1; i >= 0; i--){ u64 edgeRowid; @@ -1095,8 +1150,8 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob return i; } - edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector); - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); + edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector->pEdge); + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); if( nodeToNew > pIndex->pruningAlpha * edgeToNew ){ return -1; } @@ -1114,12 +1169,14 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob // prune edges after we inserted new edge at position iInserted // we only need to check for edges which will be pruned by new vertex // no need to check for other pairs as we checked them on previous insertions -static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted) { +static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int 
iInserted, VectorPair *pPlaceholder) { int i, s, nEdges; - Vector nodeVector, hintVector; + Vector nodeVector, hintEdgeVector; u64 hintRowid; nodeBinVector(pIndex, pNodeBlob, &nodeVector); + loadVectorPair(pPlaceholder, &nodeVector); + nEdges = nodeBinEdges(pIndex, pNodeBlob); assert( 0 <= iInserted && iInserted < nEdges ); @@ -1129,7 +1186,7 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i nodeBinDebug(pIndex, pNodeBlob); #endif - nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintVector); + nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintEdgeVector); // remove edges which is no longer interesting due to the addition of iInserted i = 0; @@ -1143,8 +1200,8 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i i++; continue; } - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); - hintToEdge = diskAnnVectorDistance(pIndex, &hintVector, &edgeVector); + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + hintToEdge = diskAnnVectorDistance(pIndex, &hintEdgeVector, &edgeVector); if( nodeToEdge > pIndex->pruningAlpha * hintToEdge ){ nodeBinDeleteEdge(pIndex, pNodeBlob, i); nEdges--; @@ -1193,7 +1250,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u } nodeBinVector(pIndex, start->pBlobSpot, &startVector); - startDistance = diskAnnVectorDistance(pIndex, pCtx->pNodeQuery, &startVector); + startDistance = diskAnnVectorDistance(pIndex, pCtx->query.pNode, &startVector); if( pCtx->blobMode == DISKANN_BLOB_READONLY ){ assert( start->pBlobSpot != NULL ); @@ -1246,8 +1303,8 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u nEdges = nodeBinEdges(pIndex, pCandidateBlob); // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it - if( pCtx->pNodeQuery != pCtx->pEdgeQuery ){ - distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->pNodeQuery); + if( pCtx->query.pNode != pCtx->query.pEdge ){ + distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->query.pNode); } diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance); @@ -1263,7 +1320,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u continue; } - edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pEdgeQuery, &edgeVector); + edgeDistance = diskAnnVectorDistance(pIndex, pCtx->query.pEdge, &edgeVector); iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance); if( iInsert < 0 ){ continue; @@ -1340,7 +1397,7 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search"); return rc; } - rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context"); goto out; @@ -1383,6 +1440,9 @@ int diskAnnInsert( BlobSpot *pBlobSpot = NULL; DiskAnnNode *pVisited; DiskAnnSearchCtx ctx; + VectorPair vInsert, vCandidate; + vInsert.pNode = NULL; vInsert.pEdge = NULL; + vCandidate.pNode = NULL; vCandidate.pEdge = NULL; if( pVectorInRow->pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); @@ -1395,12 +1455,24 @@ int diskAnnInsert( DiskAnnTrace(("diskAnnInset started\n")); - rc = 
diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context"); return rc; } + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vInsert) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for node VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vCandidate) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for candidate VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + // note: we must select random row before we will insert new row in the shadow table rc = diskAnnSelectRandomShadowRow(pIndex, &nStartRowid); if( rc == SQLITE_DONE ){ @@ -1438,28 +1510,31 @@ int diskAnnInsert( } // first pass - add all visited nodes as a potential neighbours of new node for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ - Vector vector; + Vector nodeVector; int iReplace; - nodeBinVector(pIndex, pVisited->pBlobSpot, &vector); - iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vector); + nodeBinVector(pIndex, pVisited->pBlobSpot, &nodeVector); + loadVectorPair(&vCandidate, &nodeVector); + + iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, &vector); - diskAnnPruneEdges(pIndex, pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, vCandidate.pEdge); + diskAnnPruneEdges(pIndex, pBlobSpot, iReplace, &vInsert); } // second pass - add new node as a potential neighbour of all visited nodes + loadVectorPair(&vInsert, pVectorInRow->pVector); for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ int iReplace; - iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, pVectorInRow->pVector); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, pVectorInRow->pVector); - diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, vInsert.pEdge); + diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace, &vCandidate); rc = blobSpotFlush(pIndex, pVisited->pBlobSpot); if( rc != SQLITE_OK ){ @@ -1470,6 +1545,8 @@ int diskAnnInsert( rc = SQLITE_OK; out: + deinitVectorPair(&vInsert); + deinitVectorPair(&vCandidate); if( rc == SQLITE_OK ){ rc = blobSpotFlush(pIndex, pBlobSpot); if( rc != SQLITE_OK ){ @@ -1628,6 +1705,7 @@ int diskAnnOpenIndex( } *ppIndex = pIndex; + DiskAnnTrace(("opened index %s: max edges %d\n", zIdxName, nodeEdgesMaxCount(pIndex))); return SQLITE_OK; } From cd9cea34ec7944880bd4bbe3a54f39fbc5f3cc07 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Tue, 6 Aug 2024 19:17:48 +0400 Subject: [PATCH 42/70] move intrinsics to the utils --- libsql-sqlite3/src/sqliteInt.h | 2 ++ libsql-sqlite3/src/util.c | 37 ++++++++++++++++++++++++++++ libsql-sqlite3/src/vector1bit.c | 39 +++++++++++------------------- libsql-sqlite3/src/vectordiskann.c | 2 +- 4 files changed, 54 
insertions(+), 26 deletions(-) diff --git a/libsql-sqlite3/src/sqliteInt.h b/libsql-sqlite3/src/sqliteInt.h index e2fd32d3c4..891cf79fe4 100644 --- a/libsql-sqlite3/src/sqliteInt.h +++ b/libsql-sqlite3/src/sqliteInt.h @@ -5321,6 +5321,8 @@ int sqlite3AddInt64(i64*,i64); int sqlite3SubInt64(i64*,i64); int sqlite3MulInt64(i64*,i64); int sqlite3AbsInt32(int); +int sqlite3PopCount32(u32); +int sqlite3PopCount64(u64); #ifdef SQLITE_ENABLE_8_3_NAMES void sqlite3FileSuffix3(const char*, char*); #else diff --git a/libsql-sqlite3/src/util.c b/libsql-sqlite3/src/util.c index 207b901bad..0bce77e65a 100644 --- a/libsql-sqlite3/src/util.c +++ b/libsql-sqlite3/src/util.c @@ -1542,6 +1542,43 @@ int sqlite3SafetyCheckSickOrOk(sqlite3 *db){ } } + +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + +int sqlite3PopCount64(u64 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcountll(a); +#else + return sqlite3PopCount32(a >> 32) + sqlite3PopCount32(a & 0xffffffff); +#endif +} + /* ** Attempt to add, subtract, or multiply the 64-bit signed value iB against ** the other 64-bit signed integer at *pA and store the result in *pA. 
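[Editor's note, illustrative] The sqlite3PopCount32/64 helpers added above are consumed by the rewritten Hamming loop in the next file, which XORs full 32-bit words and falls back to individual bytes for the tail. A rough standalone rendering of that word-at-a-time pattern (hypothetical names, not the patch code; memcpy stands in for the in-tree cast of the byte buffer to u32, sidestepping alignment concerns in the sketch):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

static int popcount32(uint32_t x){
#if defined(__GNUC__)
  return __builtin_popcount(x);
#else
  int n = 0;
  while( x ){ n += x & 1; x >>= 1; }
  return n;
#endif
}

/* Hamming distance over packed bit vectors: whole 32-bit words first,
** then the remaining tail bytes. */
static int hamming_packed(const uint8_t *a, const uint8_t *b, int dims){
  int i, diff = 0;
  int len8 = (dims + 7) / 8;   /* total bytes */
  int len32 = dims / 32;       /* full 32-bit words */
  uint32_t wa, wb;
  for(i = 0; i < len32; i++){
    memcpy(&wa, a + 4 * i, 4);
    memcpy(&wb, b + 4 * i, 4);
    diff += popcount32(wa ^ wb);
  }
  for(i = len32 * 4; i < len8; i++){
    diff += popcount32(a[i] ^ b[i]);
  }
  return diff;
}

int main(void){
  uint8_t a[5] = { 0xFF, 0x00, 0xAA, 0x0F, 0x03 };
  uint8_t b[5] = { 0x0F, 0x00, 0x55, 0x0F, 0x01 };
  printf("%d\n", hamming_packed(a, b, 40)); /* 4 + 0 + 8 + 0 + 1 = 13 */
  return 0;
}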
diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index 76fc964d7c..89b0d09d90 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -70,38 +70,27 @@ size_t vector1BitSerializeToBlob( return (pVector->dims + 7) / 8; } -// [sum(map(int, bin(i)[2:])) for i in range(256)] -static int BitsCount[256] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, -}; - int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ int diff = 0; - u8 *e1 = v1->data; - u8 *e2 = v2->data; - int i; + u8 *e1U8 = v1->data; + u32 *e1U32 = v1->data; + u8 *e2U8 = v2->data; + u32 *e2U32 = v2->data; + int i, len8, len32, offset8; assert( v1->dims == v2->dims ); assert( v1->type == VECTOR_TYPE_1BIT ); assert( v2->type == VECTOR_TYPE_1BIT ); - for(i = 0; i < v1->dims; i += 8){ - diff += BitsCount[e1[i/8] ^ e2[i/8]]; + len8 = (v1->dims + 7) / 8; + len32 = v1->dims / 32; + offset8 = len32 * 4; + + for(i = 0; i < len32; i++){ + diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); + } + for(i = offset8; i < len8; i++){ + diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); } return diff; } diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index d9bb37ba35..c29b16e35d 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -54,7 +54,7 @@ #include "sqliteInt.h" #include "vectorIndexInt.h" -#define SQLITE_VECTOR_TRACE +// #define SQLITE_VECTOR_TRACE #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) #define DiskAnnTrace(X) sqlite3DebugPrintf X; #else From eac5d90807181de75d7e8f8c0e5d3b26b0633d52 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:15:39 +0400 Subject: [PATCH 43/70] move utils back to the vector1bit to simplify inlining for compiler --- libsql-sqlite3/src/sqliteInt.h | 2 -- libsql-sqlite3/src/util.c | 37 --------------------------------- libsql-sqlite3/src/vector.c | 4 +--- libsql-sqlite3/src/vector1bit.c | 28 +++++++++++++++++++++++++ 4 files changed, 29 insertions(+), 42 deletions(-) diff --git a/libsql-sqlite3/src/sqliteInt.h b/libsql-sqlite3/src/sqliteInt.h index 891cf79fe4..e2fd32d3c4 100644 --- a/libsql-sqlite3/src/sqliteInt.h +++ b/libsql-sqlite3/src/sqliteInt.h @@ -5321,8 +5321,6 @@ int sqlite3AddInt64(i64*,i64); int sqlite3SubInt64(i64*,i64); int sqlite3MulInt64(i64*,i64); int sqlite3AbsInt32(int); -int sqlite3PopCount32(u32); -int sqlite3PopCount64(u64); #ifdef SQLITE_ENABLE_8_3_NAMES void sqlite3FileSuffix3(const char*, char*); #else diff --git a/libsql-sqlite3/src/util.c b/libsql-sqlite3/src/util.c index 0bce77e65a..207b901bad 100644 --- a/libsql-sqlite3/src/util.c +++ b/libsql-sqlite3/src/util.c @@ -1542,43 +1542,6 @@ int sqlite3SafetyCheckSickOrOk(sqlite3 *db){ } } - -// [sum(map(int, bin(i)[2:])) for i in 
range(256)] -static int BitsCount[256] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, -}; - -int sqlite3PopCount32(u32 a){ -#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) - return __builtin_popcount(a); -#else - return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; -#endif -} - -int sqlite3PopCount64(u64 a){ -#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) - return __builtin_popcountll(a); -#else - return sqlite3PopCount32(a >> 32) + sqlite3PopCount32(a & 0xffffffff); -#endif -} - /* ** Attempt to add, subtract, or multiply the 64-bit signed value iB against ** the other 64-bit signed integer at *pA and store the result in *pA. diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 6c3c8d83fe..73b84b047e 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -494,9 +494,7 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){ bitData[i / 8] = 0; } for(i = 0; i < pFrom->dims; i++){ - if( floatData[i] < 0 ){ - bitData[i / 8] &= ~(1 << (i & 7)); - }else{ + if( floatData[i] > 0 ){ bitData[i / 8] |= (1 << (i & 7)); } } diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index 89b0d09d90..66da59f76a 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -70,6 +70,34 @@ size_t vector1BitSerializeToBlob( return (pVector->dims + 7) / 8; } +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +static inline int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ int diff = 0; u8 *e1U8 = v1->data; From 0ec0147979434c4712a2f6c225bec7335f6af973 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 
13:18:34 +0400 Subject: [PATCH 44/70] extend binary format and store distance to edges in node blocks --- libsql-sqlite3/src/vectorIndexInt.h | 10 +++-- libsql-sqlite3/src/vectordiskann.c | 62 +++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/libsql-sqlite3/src/vectorIndexInt.h b/libsql-sqlite3/src/vectorIndexInt.h index bb6c085e94..e65df4d515 100644 --- a/libsql-sqlite3/src/vectorIndexInt.h +++ b/libsql-sqlite3/src/vectorIndexInt.h @@ -73,10 +73,10 @@ int nodeEdgesMetadataOffset(const DiskAnnIndex *pIndex); void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Vector *pVector); void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector); u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector); +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *distance, Vector *pVector); int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u64 nRowid); void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPruned); -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector); +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector); void nodeBinDeleteEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iDelete); void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); @@ -102,9 +102,11 @@ typedef u8 MetricType; /* format version which can help to upgrade vector on-disk format without breaking older version of the db */ #define VECTOR_FORMAT_PARAM_ID 1 /* - * 1 - initial version + * 1 - v1 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u64 unused ] [u64 edge rowid] ] ... + * 2 - v2 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u32 unused] [f32 distance] [u64 edge rowid] ] ... 
*/ -#define VECTOR_FORMAT_DEFAULT 1 +#define VECTOR_FORMAT_V1 1 +#define VECTOR_FORMAT_DEFAULT 2 /* type of the vector index */ #define VECTOR_INDEX_TYPE_PARAM_ID 2 diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index c29b16e35d..0853120797 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -129,6 +129,10 @@ static inline u16 readLE16(const unsigned char *p){ return (u16)p[0] | (u16)p[1] << 8; } +static inline u32 readLE32(const unsigned char *p){ + return (u32)p[0] | (u32)p[1] << 8 | (u32)p[2] << 16 | (u32)p[3] << 24; +} + static inline u64 readLE64(const unsigned char *p){ return (u64)p[0] | (u64)p[1] << 8 @@ -145,6 +149,13 @@ static inline void writeLE16(unsigned char *p, u16 v){ p[1] = v >> 8; } +static inline void writeLE32(unsigned char *p, u32 v){ + p[0] = v; + p[1] = v >> 8; + p[2] = v >> 16; + p[3] = v >> 24; +} + static inline void writeLE64(unsigned char *p, u64 v){ p[0] = v; p[1] = v >> 8; @@ -333,13 +344,18 @@ u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { return readLE16(pBlobSpot->pBuffer + sizeof(u64)); } -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector) { +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *pDistance, Vector *pVector) { + u32 distance; int offset = nodeEdgesMetadataOffset(pIndex); if( pRowid != NULL ){ assert( offset + (iEdge + 1) * VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); *pRowid = readLE64(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u64)); } + if( pIndex->nFormatVersion != VECTOR_FORMAT_V1 && pDistance != NULL ){ + distance = readLE32(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u32)); + *pDistance = *((float*)&distance); + } if( pVector != NULL ){ assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize < offset ); vectorInitStatic( @@ -356,7 +372,7 @@ int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u6 // todo: if edges will be sorted by identifiers we can use binary search here (although speed up will be visible only on pretty loaded nodes: >128 edges) for(i = 0; i < nEdges; i++){ u64 edgeId; - nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL); + nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL, NULL); if( edgeId == nRowid ){ return i; } @@ -371,7 +387,7 @@ void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPru } // replace edge at position iReplace or add new one if iReplace == nEdges -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector) { +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector) { int nMaxEdges = nodeEdgesMaxCount(pIndex); int nEdges = nodeBinEdges(pIndex, pBlobSpot); int edgeVectorOffset, edgeMetaOffset, itemsToMove; @@ -390,6 +406,7 @@ void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iRe assert( edgeMetaOffset + VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); vectorSerializeToBlob(pVector, pBlobSpot->pBuffer + edgeVectorOffset, pIndex->nEdgeVectorSize); + writeLE32(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u32), *((u32*)&distance)); writeLE64(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u64), nRowid); writeLE16(pBlobSpot->pBuffer + sizeof(u64), nEdges); @@ -424,6 +441,7 @@ void 
nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) int nEdges, nMaxEdges, i; u64 nRowid; + float distance = 0; Vector vector; nEdges = nodeBinEdges(pIndex, pBlobSpot); @@ -434,8 +452,8 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { DiskAnnTrace((" nEdges=%d, nMaxEdges=%d, vector=", nEdges, nMaxEdges)); vectorDump(&vector); for(i = 0; i < nEdges; i++){ - nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &vector); - DiskAnnTrace((" to=%lld, vector=", nRowid, nRowid)); + nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &distance, &vector); + DiskAnnTrace((" to=%lld, distance=%f, vector=", nRowid, distance)); vectorDump(&vector); } #endif @@ -1126,7 +1144,8 @@ static int diskAnnReplaceEdgeIdx( BlobSpot *pNodeBlob, u64 newRowid, VectorPair *pNewVector, - VectorPair *pPlaceholder + VectorPair *pPlaceholder, + float *pNodeToNew ) { int i, nEdges, nMaxEdges, iReplace = -1; Vector nodeVector, edgeVector; @@ -1139,19 +1158,23 @@ static int diskAnnReplaceEdgeIdx( // we need to evaluate potentially approximate distance here in order to correctly compare it with edge distances nodeToNew = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, pNewVector->pEdge); + *pNodeToNew = nodeToNew; for(i = nEdges - 1; i >= 0; i--){ u64 edgeRowid; float edgeToNew, nodeToEdge; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( edgeRowid == newRowid ){ // deletes can leave "zombie" edges in the graph and we must override them and not store duplicate edges in the node return i; } + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector->pEdge); - nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); if( nodeToNew > pIndex->pruningAlpha * edgeToNew ){ return -1; } @@ -1186,7 +1209,7 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i nodeBinDebug(pIndex, pNodeBlob); #endif - nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintEdgeVector); + nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, NULL, &hintEdgeVector); // remove edges which is no longer interesting due to the addition of iInserted i = 0; @@ -1194,13 +1217,16 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i Vector edgeVector; float nodeToEdge, hintToEdge; u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( hintRowid == edgeRowid ){ i++; continue; } - nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + hintToEdge = diskAnnVectorDistance(pIndex, &hintEdgeVector, &edgeVector); if( nodeToEdge > pIndex->pruningAlpha * hintToEdge ){ nodeBinDeleteEdge(pIndex, pNodeBlob, i); @@ -1315,7 +1341,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u float edgeDistance; int iInsert; DiskAnnNode *pNewCandidate; - nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, NULL, &edgeVector); if( diskAnnSearchCtxIsVisited(pCtx, edgeRowid) || diskAnnSearchCtxHasCandidate(pCtx, edgeRowid) ){ continue; 
} @@ -1512,15 +1538,16 @@ int diskAnnInsert( for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ Vector nodeVector; int iReplace; + float nodeToNew; nodeBinVector(pIndex, pVisited->pBlobSpot, &nodeVector); loadVectorPair(&vCandidate, &nodeVector); - iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, vCandidate.pEdge); + nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, nodeToNew, vCandidate.pEdge); diskAnnPruneEdges(pIndex, pBlobSpot, iReplace, &vInsert); } @@ -1528,12 +1555,13 @@ int diskAnnInsert( loadVectorPair(&vInsert, pVectorInRow->pVector); for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ int iReplace; + float nodeToNew; - iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, vInsert.pEdge); + nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, nodeToNew, vInsert.pEdge); diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace, &vCandidate); rc = blobSpotFlush(pIndex, pVisited->pBlobSpot); @@ -1598,7 +1626,7 @@ int diskAnnDelete( nNeighbours = nodeBinEdges(pIndex, pNodeBlob); for(i = 0; i < nNeighbours; i++){ u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL, NULL); rc = blobSpotReload(pIndex, pEdgeBlob, edgeRowid, pIndex->nBlockSize); if( rc == DISKANN_ROW_NOT_FOUND ){ continue; From 657ce07ee40e62a50945d2e927d113c993908b19 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 13:38:18 +0400 Subject: [PATCH 45/70] fix comment a little bit --- libsql-sqlite3/src/vectordiskann.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 0853120797..88151bf1df 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -107,9 +107,9 @@ struct DiskAnnNode { * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this) */ struct DiskAnnSearchCtx { - VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */ - DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */ - float *aDistances; /* array of distances to the query vector */ + VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */ + DiskAnnNode **aCandidates; /* array of unvisited candidates ordered by distance (possibly approximate) to the query (ascending) */ + float *aDistances; /* array of distances (possible approximate) to the query vector */ unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */ From 11e6a9326449af84a3e940d89753ec349e8f2cd9 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 14:24:38 +0400 Subject: [PATCH 46/70] add simple test --- 
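[Editor's note, illustrative] The test below drives the new option end-to-end through SQL. For reference, a minimal C harness of the same shape, assuming a libsql build with vector-index support (table and index names are made up; only stock sqlite3 API calls are used):

#include <stdio.h>
#include "sqlite3.h"

static int print_row(void *pArg, int nCol, char **azVal, char **azCol){
  int i;
  (void)pArg; (void)azCol;
  for(i = 0; i < nCol; i++) printf("%s ", azVal[i] ? azVal[i] : "NULL");
  printf("\n");
  return 0;
}

int main(void){
  sqlite3 *db;
  if( sqlite3_open(":memory:", &db) != SQLITE_OK ) return 1;
  sqlite3_exec(db,
    "CREATE TABLE t_demo( v FLOAT32(3) );"
    "CREATE INDEX t_demo_idx ON t_demo( libsql_vector_idx(v, 'compress_neighbors=1bit') );"
    "INSERT INTO t_demo VALUES (vector('[1,0,0]')), (vector('[0,1,0]'));",
    0, 0, 0);
  /* expect rowid 1 first: its sign bits match the query exactly */
  sqlite3_exec(db,
    "SELECT rowid FROM vector_top_k('t_demo_idx', vector('[1,0,0]'), 2);",
    print_row, 0, 0);
  sqlite3_close(db);
  return 0;
}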
libsql-sqlite3/test/libsql_vector_index.test | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 7308b2d93f..8054dfaf0b 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -268,9 +268,19 @@ do_execsql_test vector-transaction { SELECT * FROM vector_top_k('t_transaction_idx', vector('[1,2]'), 2); } {3 4 1 2} +do_execsql_test vector-1bit { + CREATE TABLE t_1bit( v FLOAT32(3) ); + CREATE INDEX t_1bit_idx ON t_1bit( libsql_vector_idx(v, 'compress_neighbors=1bit') ); + INSERT INTO t_1bit VALUES (vector('[-1,-1,1]')); + INSERT INTO t_1bit VALUES (vector('[-1,1,-1.5]')); + INSERT INTO t_1bit VALUES (vector('[1,-1,-1]')); + INSERT INTO t_1bit VALUES (vector('[-1,-1,-1]')); + SELECT rowid FROM vector_top_k('t_1bit_idx', vector('[1,-1,-1]'), 4); +} {3 4 2 1} + do_execsql_test vector-all-params { CREATE TABLE t_all_params ( emb FLOAT32(2) ); - CREATE INDEX t_all_params_idx ON t_all_params(libsql_vector_idx(emb, 'type=diskann', 'metric=cos', 'alpha=1.2', 'search_l=200', 'insert_l=70', 'max_neighbors=6')); + CREATE INDEX t_all_params_idx ON t_all_params(libsql_vector_idx(emb, 'type=diskann', 'metric=cos', 'alpha=1.2', 'search_l=200', 'insert_l=70', 'max_neighbors=6', 'compress_neighbors=1bit')); INSERT INTO t_all_params VALUES (vector('[1,2]')), (vector('[3,4]')); SELECT * FROM vector_top_k('t_all_params_idx', vector('[1,2]'), 2); } {1 2} From 93caa2736702714e1df07152e550ca52202feadc Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 14:55:51 +0400 Subject: [PATCH 47/70] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 836 ++++++++++++------ libsql-ffi/bundled/src/sqlite3.c | 836 ++++++++++++------ 2 files changed, 1144 insertions(+), 528 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index d7587cc38b..eff02cf87a 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -85246,6 +85246,7 @@ typedef u32 VectorDims; */ #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_1BIT 3 #define VECTOR_FLAGS_STATIC 1 @@ -85270,8 +85271,9 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); * Dumps vector on the console (used only for debugging) */ void vectorDump (const Vector *v); -void vectorF32Dump(const Vector *v); -void vectorF64Dump(const Vector *v); +void vectorF32Dump (const Vector *v); +void vectorF64Dump (const Vector *v); +void vector1BitDump(const Vector *v); /* * Converts vector to the text representation and write the result to the sqlite3_context @@ -85283,16 +85285,10 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); -size_t vectorF32SerializeToBlob(const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob(const Vector *, unsigned char *, size_t); - -/* - * Deserializes vector from the blob in little-endian format according to the IEEE-754 standard -*/ -size_t vectorDeserializeFromBlob (Vector *, const unsigned char *, size_t); -size_t vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -size_t vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); 
+size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) @@ -85301,6 +85297,11 @@ float vectorDistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); +/* + * Calculates hamming distance between two 1-bit vectors (vector must have same dimensions) +*/ +int vector1BitDistanceHamming(const Vector *, const Vector *); + /* * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ @@ -85313,22 +85314,23 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerialize (sqlite3_context *, const Vector *); -void vectorF32Serialize(sqlite3_context *, const Vector *); -void vectorF64Serialize(sqlite3_context *, const Vector *); +void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); -int vectorF32ParseSqliteBlob(sqlite3_value *, Vector *, char **); -int vectorF64ParseSqliteBlob(sqlite3_value *, Vector *, char **); -void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t); +void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); + +void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); +void vectorConvert(const Vector *, Vector *); + /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); @@ -85410,10 +85412,10 @@ int nodeEdgesMetadataOffset(const DiskAnnIndex *pIndex); void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Vector *pVector); void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector); u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector); +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *distance, Vector *pVector); int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u64 nRowid); void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPruned); -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector); +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector); void nodeBinDeleteEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iDelete); void nodeBinDebug(const 
DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); @@ -85437,43 +85439,47 @@ typedef u8 MetricType; */ /* format version which can help to upgrade vector on-disk format without breaking older version of the db */ -#define VECTOR_FORMAT_PARAM_ID 1 +#define VECTOR_FORMAT_PARAM_ID 1 /* - * 1 - initial version + * 1 - v1 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u64 unused ] [u64 edge rowid] ] ... + * 2 - v2 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u32 unused] [f32 distance] [u64 edge rowid] ] ... */ -#define VECTOR_FORMAT_DEFAULT 1 +#define VECTOR_FORMAT_V1 1 +#define VECTOR_FORMAT_DEFAULT 2 /* type of the vector index */ -#define VECTOR_INDEX_TYPE_PARAM_ID 2 -#define VECTOR_INDEX_TYPE_DISKANN 1 +#define VECTOR_INDEX_TYPE_PARAM_ID 2 +#define VECTOR_INDEX_TYPE_DISKANN 1 /* type of the underlying vector for the vector index */ -#define VECTOR_TYPE_PARAM_ID 3 +#define VECTOR_TYPE_PARAM_ID 3 /* dimension of the underlying vector for the vector index */ -#define VECTOR_DIM_PARAM_ID 4 +#define VECTOR_DIM_PARAM_ID 4 /* metric type used for comparing two vectors */ -#define VECTOR_METRIC_TYPE_PARAM_ID 5 -#define VECTOR_METRIC_TYPE_COS 1 -#define VECTOR_METRIC_TYPE_L2 2 +#define VECTOR_METRIC_TYPE_PARAM_ID 5 +#define VECTOR_METRIC_TYPE_COS 1 +#define VECTOR_METRIC_TYPE_L2 2 /* block size */ -#define VECTOR_BLOCK_SIZE_PARAM_ID 6 -#define VECTOR_BLOCK_SIZE_DEFAULT 128 +#define VECTOR_BLOCK_SIZE_PARAM_ID 6 +#define VECTOR_BLOCK_SIZE_DEFAULT 128 + +#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 +#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 -#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 -#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 +#define VECTOR_INSERT_L_PARAM_ID 8 +#define VECTOR_INSERT_L_DEFAULT 70 -#define VECTOR_INSERT_L_PARAM_ID 8 -#define VECTOR_INSERT_L_DEFAULT 70 +#define VECTOR_SEARCH_L_PARAM_ID 9 +#define VECTOR_SEARCH_L_DEFAULT 200 -#define VECTOR_SEARCH_L_PARAM_ID 9 -#define VECTOR_SEARCH_L_DEFAULT 200 +#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 -#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 +#define VECTOR_COMPRESS_NEIGHBORS_PARAM_ID 11 /* total amount of vector index parameters */ -#define VECTOR_PARAM_IDS_COUNT 9 +#define VECTOR_PARAM_IDS_COUNT 11 /* * Vector index parameters are stored in simple binary format (1 byte tag + 8 byte u64 integer / f64 float) @@ -85555,7 +85561,7 @@ int vectorOutRowsPut(VectorOutRows *, int, int, const u64 *, sqlite3_value *); void vectorOutRowsGet(sqlite3_context *, const VectorOutRows *, int, int); void vectorOutRowsFree(sqlite3 *, VectorOutRows *); -int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *); +int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *, const char **); int diskAnnClearIndex(sqlite3 *, const char *, const char *); int diskAnnDropIndex(sqlite3 *, const char *, const char *); int diskAnnOpenIndex(sqlite3 *, const char *, const char *, const VectorIdxParams *, DiskAnnIndex **); @@ -210981,6 +210987,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); + case VECTOR_TYPE_1BIT: + return (dims + 7) / 8; default: assert(0); } @@ -211012,10 +211020,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){ ** Note that the vector object points to the blob so if ** you free the blob, the vector becomes invalid. 
**/
-void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){
-  pVector->type = type;
+void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){
   pVector->flags = VECTOR_FLAGS_STATIC;
-  vectorInitFromBlob(pVector, pBlob, nBlobSize);
+  pVector->type = type;
+  pVector->dims = dims;
+  pVector->data = pBlob;
 }

 /*
@@ -211051,6 +211060,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
       return vectorF32DistanceCos(pVector1, pVector2);
     case VECTOR_TYPE_FLOAT64:
       return vectorF64DistanceCos(pVector1, pVector2);
+    case VECTOR_TYPE_1BIT:
+      return vector1BitDistanceHamming(pVector1, pVector2);
     default:
       assert(0);
   }
@@ -211192,11 +211203,29 @@ int vectorParseSqliteBlob(
   Vector *pVector,
   char **pzErrMsg
 ){
+  const unsigned char *pBlob;
+  size_t nBlobSize;
+
+  assert( sqlite3_value_type(arg) == SQLITE_BLOB );
+
+  pBlob = sqlite3_value_blob(arg);
+  nBlobSize = sqlite3_value_bytes(arg);
+  if( nBlobSize % 2 == 1 ){
+    nBlobSize--;
+  }
+
+  if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){
+    *pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%llu", pVector->type, pVector->dims, nBlobSize);
+    return SQLITE_ERROR;
+  }
+
   switch (pVector->type) {
     case VECTOR_TYPE_FLOAT32:
-      return vectorF32ParseSqliteBlob(arg, pVector, pzErrMsg);
+      vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize);
+      return 0;
     case VECTOR_TYPE_FLOAT64:
-      return vectorF64ParseSqliteBlob(arg, pVector, pzErrMsg);
+      vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize);
+      return 0;
     default:
       assert(0);
   }
@@ -211303,6 +211332,9 @@ void vectorDump(const Vector *pVector){
     case VECTOR_TYPE_FLOAT64:
       vectorF64Dump(pVector);
       break;
+    case VECTOR_TYPE_1BIT:
+      vector1BitDump(pVector);
+      break;
     default:
       assert(0);
   }
@@ -211324,20 +211356,47 @@ void vectorMarshalToText(
   }
 }

-void vectorSerialize(
+void vectorSerializeWithType(
   sqlite3_context *context,
   const Vector *pVector
 ){
+  unsigned char *pBlob;
+  size_t nBlobSize, nDataSize;
+
+  assert( pVector->dims <= MAX_VECTOR_SZ );
+
+  nDataSize = vectorDataSize(pVector->type, pVector->dims);
+  nBlobSize = nDataSize;
+  if( pVector->type != VECTOR_TYPE_FLOAT32 ){
+    nBlobSize += (nBlobSize % 2 == 0 ? 
1 : 2); + } + + if( nBlobSize == 0 ){ + sqlite3_result_zeroblob(context, 0); + return; + } + + pBlob = sqlite3_malloc64(nBlobSize); + if( pBlob == NULL ){ + sqlite3_result_error_nomem(context); + return; + } + + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + pBlob[nBlobSize - 1] = pVector->type; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - vectorF32Serialize(context, pVector); + vectorF32SerializeToBlob(pVector, pBlob, nDataSize); break; case VECTOR_TYPE_FLOAT64: - vectorF64Serialize(context, pVector); + vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; default: assert(0); } + sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ @@ -211346,18 +211405,8 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); - default: - assert(0); - } - return 0; -} - -size_t vectorDeserializeFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - return vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); - case VECTOR_TYPE_FLOAT64: - return vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + case VECTOR_TYPE_1BIT: + return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); } @@ -211377,6 +211426,29 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo } } +void vectorConvert(const Vector *pFrom, Vector *pTo){ + int i; + u8 *bitData; + float *floatData; + + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ + floatData = pFrom->data; + bitData = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + bitData[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( floatData[i] > 0 ){ + bitData[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert(0); + } +} + /************************************************************************** ** SQL function implementations ****************************************************************************/ @@ -211410,7 +211482,7 @@ static void vectorFuncHintedType( sqlite3_free(pzErrMsg); goto out_free_vec; } - vectorSerialize(context, pVector); + vectorSerializeWithType(context, pVector); out_free_vec: vectorFree(pVector); } @@ -211557,6 +211629,135 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vector.c **********************************************/ +/************** Begin file vector1bit.c **************************************/ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. 
+** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +** +****************************************************************************** +** +** 1-bit vector format utilities. +*/ +#ifndef SQLITE_OMIT_VECTOR +/* #include "sqliteInt.h" */ + +/* #include "vectorInt.h" */ + +/* #include */ + +/************************************************************************** +** Utility routines for debugging +**************************************************************************/ + +void vector1BitDump(const Vector *pVec){ + u8 *elems = pVec->data; + unsigned i; + + assert( pVec->type == VECTOR_TYPE_1BIT ); + + for(i = 0; i < pVec->dims; i++){ + printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + } + printf("\n"); +} + +/************************************************************************** +** Utility routines for vector serialization and deserialization +**************************************************************************/ + +size_t vector1BitSerializeToBlob( + const Vector *pVector, + unsigned char *pBlob, + size_t nBlobSize +){ + u8 *elems = pVector->data; + u8 *pPtr = pBlob; + unsigned i; + + assert( pVector->type == VECTOR_TYPE_1BIT ); + assert( pVector->dims <= MAX_VECTOR_SZ ); + assert( nBlobSize >= (pVector->dims + 7) / 8 ); + + for(i = 0; i < (pVector->dims + 7) / 8; i++){ + pPtr[i] = elems[i]; + } + return (pVector->dims + 7) / 8; +} + +// [sum(map(int, bin(i)[2:])) for i in range(256)] +static int BitsCount[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, +}; + +static inline int sqlite3PopCount32(u32 a){ +#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER) + return __builtin_popcount(a); +#else + return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff]; +#endif +} + +int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ + int diff = 0; + u8 *e1U8 = v1->data; + u32 *e1U32 = v1->data; + u8 *e2U8 = v2->data; + u32 *e2U32 = v2->data; + int i, len8, len32, offset8; + + assert( v1->dims == v2->dims ); + assert( v1->type == VECTOR_TYPE_1BIT ); + assert( v2->type == VECTOR_TYPE_1BIT ); + + len8 = (v1->dims + 7) / 8; + len32 = v1->dims / 32; + offset8 = len32 * 4; + + for(i = 0; i < len32; i++){ + diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]); + } + for(i = offset8; i < len8; i++){ + diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]); + } + 
return diff;
+}
+
+#endif /* !defined(SQLITE_OMIT_VECTOR) */
+
+/************** End of vector1bit.c ******************************************/
 /************** Begin file vectordiskann.c ***********************************/
 /*
 ** 2024-03-23
@@ -211607,13 +211808,14 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){
 ** diskAnnInsert() Insert single new(!) vector in an opened index
 ** diskAnnDelete() Delete row by key from an opened index
 */
+/* #include "vectorInt.h" */
 #ifndef SQLITE_OMIT_VECTOR
 /* #include "math.h" */
 /* #include "sqliteInt.h" */
 /* #include "vectorIndexInt.h" */

-#define SQLITE_VECTOR_TRACE
+// #define SQLITE_VECTOR_TRACE
 #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE)
 #define DiskAnnTrace(X) sqlite3DebugPrintf X;
 #else
@@ -211639,9 +211841,18 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){
 #define VECTOR_NODE_METADATA_SIZE (sizeof(u64) + sizeof(u16))
 #define VECTOR_EDGE_METADATA_SIZE (sizeof(u64) + sizeof(u64))

+typedef struct VectorPair VectorPair;
 typedef struct DiskAnnSearchCtx DiskAnnSearchCtx;
 typedef struct DiskAnnNode DiskAnnNode;

+// VectorPair represents a single vector where pNode is the exact representation and pEdge is the compressed representation (when nodeType == edgeType no conversion is needed and pEdge simply aliases pNode after loadVectorPair)
+struct VectorPair {
+  int nodeType;
+  int edgeType;
+  Vector *pNode;
+  Vector *pEdge;
+};
+
 // DiskAnnNode represents single node in the DiskAnn graph
 struct DiskAnnNode {
   u64 nRowid; /* node id */
@@ -211657,14 +211868,18 @@ struct DiskAnnNode {
  * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this)
  */
 struct DiskAnnSearchCtx {
-  const Vector *pQuery; /* initial query vector; user query for SELECT and row vector for INSERT */
-  DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */
-  double *aDistances; /* array of distances to the query vector */
-  unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */
-  unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */
-  DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */
-  unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */
-  int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */
+  VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */
+  DiskAnnNode **aCandidates; /* array of unvisited candidates ordered by distance (possibly approximate) to the query (ascending) */
+  float *aDistances; /* array of distances (possibly approximate) to the query vector */
+  unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */
+  unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */
+  DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */
+  float *aTopDistances; /* exact distances of the top candidates */
+  int nTopCandidates; /* current size of aTopCandidates/aTopDistances arrays */
+  int maxTopCandidates; /* max size of aTopCandidates/aTopDistances arrays */
+  DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */
+  unsigned int nUnvisited; /* number of unvisited candidates in the aCandidates array */
+  int blobMode; /* DISKANN_BLOB_READONLY if we won't modify node blobs; DISKANN_BLOB_WRITABLE - 
otherwise */ }; /************************************************************************** @@ -211675,6 +211890,10 @@ static inline u16 readLE16(const unsigned char *p){ return (u16)p[0] | (u16)p[1] << 8; } +static inline u32 readLE32(const unsigned char *p){ + return (u32)p[0] | (u32)p[1] << 8 | (u32)p[2] << 16 | (u32)p[3] << 24; +} + static inline u64 readLE64(const unsigned char *p){ return (u64)p[0] | (u64)p[1] << 8 @@ -211691,6 +211910,13 @@ static inline void writeLE16(unsigned char *p, u16 v){ p[1] = v >> 8; } +static inline void writeLE32(unsigned char *p, u32 v){ + p[0] = v; + p[1] = v >> 8; + p[2] = v >> 16; + p[3] = v >> 24; +} + static inline void writeLE64(unsigned char *p, u64 v){ p[0] = v; p[1] = v >> 8; @@ -211870,7 +212096,7 @@ void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Ve void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector) { assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize <= pBlobSpot->nBufferSize ); - vectorInitStatic(pVector, pIndex->nNodeVectorType, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE, pIndex->nNodeVectorSize); + vectorInitStatic(pVector, pIndex->nNodeVectorType, pIndex->nVectorDims, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE); } u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { @@ -211879,20 +212105,25 @@ u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { return readLE16(pBlobSpot->pBuffer + sizeof(u64)); } -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector) { +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *pDistance, Vector *pVector) { + u32 distance; int offset = nodeEdgesMetadataOffset(pIndex); if( pRowid != NULL ){ assert( offset + (iEdge + 1) * VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); *pRowid = readLE64(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u64)); } + if( pIndex->nFormatVersion != VECTOR_FORMAT_V1 && pDistance != NULL ){ + distance = readLE32(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u32)); + *pDistance = *((float*)&distance); + } if( pVector != NULL ){ assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize < offset ); vectorInitStatic( pVector, pIndex->nEdgeVectorType, - pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nNodeVectorSize, - pIndex->nEdgeVectorSize + pIndex->nVectorDims, + pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize ); } } @@ -211902,7 +212133,7 @@ int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u6 // todo: if edges will be sorted by identifiers we can use binary search here (although speed up will be visible only on pretty loaded nodes: >128 edges) for(i = 0; i < nEdges; i++){ u64 edgeId; - nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL); + nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL, NULL); if( edgeId == nRowid ){ return i; } @@ -211917,7 +212148,7 @@ void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPru } // replace edge at position iReplace or add new one if iReplace == nEdges -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector) { +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float 
distance, Vector *pVector) { int nMaxEdges = nodeEdgesMaxCount(pIndex); int nEdges = nodeBinEdges(pIndex, pBlobSpot); int edgeVectorOffset, edgeMetaOffset, itemsToMove; @@ -211936,6 +212167,7 @@ void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iRe assert( edgeMetaOffset + VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); vectorSerializeToBlob(pVector, pBlobSpot->pBuffer + edgeVectorOffset, pIndex->nEdgeVectorSize); + writeLE32(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u32), *((u32*)&distance)); writeLE64(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u64), nRowid); writeLE16(pBlobSpot->pBuffer + sizeof(u64), nEdges); @@ -211970,6 +212202,7 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) int nEdges, nMaxEdges, i; u64 nRowid; + float distance = 0; Vector vector; nEdges = nodeBinEdges(pIndex, pBlobSpot); @@ -211980,8 +212213,8 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { DiskAnnTrace((" nEdges=%d, nMaxEdges=%d, vector=", nEdges, nMaxEdges)); vectorDump(&vector); for(i = 0; i < nEdges; i++){ - nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &vector); - DiskAnnTrace((" to=%lld, vector=", nRowid, nRowid)); + nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &distance, &vector); + DiskAnnTrace((" to=%lld, distance=%f, vector=", nRowid, distance)); vectorDump(&vector); } #endif @@ -211996,10 +212229,11 @@ int diskAnnCreateIndex( const char *zDbSName, const char *zIdxName, const VectorIdxKey *pKey, - VectorIdxParams *pParams + VectorIdxParams *pParams, + const char **pzErrMsg ){ int rc; - int type, dims; + int type, dims, metric, neighbours; u64 maxNeighborsParam, blockSizeBytes; char *zSql; char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...) 
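Note: the v2 block format above stores each edge's f32 distance by its raw bit pattern inside the [u32 unused][f32 distance][u64 rowid] metadata slot (written by nodeBinReplaceEdge, read back by nodeBinEdge). Below is a minimal standalone sketch of that round trip, not part of the bundled source: the LE helpers are local re-implementations, and memcpy stands in for the *((u32*)&distance) pointer cast used in the patch.

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    static void writeLE32(unsigned char *p, uint32_t v){
      p[0] = v; p[1] = v >> 8; p[2] = v >> 16; p[3] = v >> 24;
    }
    static uint32_t readLE32(const unsigned char *p){
      return (uint32_t)p[0] | (uint32_t)p[1] << 8 | (uint32_t)p[2] << 16 | (uint32_t)p[3] << 24;
    }

    int main(void){
      unsigned char meta[16] = {0};  /* one edge slot: [u32 unused][f32 distance][u64 rowid] */
      float distance = 0.125f, out;
      uint32_t bits;

      memcpy(&bits, &distance, sizeof bits);     /* f32 -> u32 bit pattern */
      writeLE32(meta + sizeof(uint32_t), bits);  /* distance lives at offset sizeof(u32) */

      bits = readLE32(meta + sizeof(uint32_t));
      memcpy(&out, &bits, sizeof out);           /* u32 -> f32 */
      assert(out == distance);
      printf("edge distance round-trip: %f\n", out);
      return 0;
    }

memcpy avoids the strict-aliasing hazard of the pointer cast while producing the same on-disk bytes.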
@@ -212023,24 +212257,36 @@ int diskAnnCreateIndex( } assert( 0 < dims && dims <= MAX_VECTOR_SZ ); + metric = vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID); + if( metric == 0 ){ + metric = VECTOR_METRIC_TYPE_COS; + if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, metric) != 0 ){ + return SQLITE_ERROR; + } + } + neighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){ + *pzErrMsg = "1-bit compression available only for cosine metric"; + return SQLITE_ERROR; + } + if( neighbours == 0 ){ + neighbours = type; + } + maxNeighborsParam = vectorIdxParamsGetU64(pParams, VECTOR_MAX_NEIGHBORS_PARAM_ID); if( maxNeighborsParam == 0 ){ // 3 D**(1/2) gives good recall values (90%+) // we also want to keep disk overhead at moderate level - 50x of the disk size increase is the current upper bound - maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(type, dims)) + 1); + maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(neighbours, dims)) + 1); } - blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(type, dims)); + blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(neighbours, dims)); if( blockSizeBytes > DISKANN_MAX_BLOCK_SZ ){ return SQLITE_ERROR; } if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){ return SQLITE_ERROR; } - if( vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID) == 0 ){ - if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, VECTOR_METRIC_TYPE_COS) != 0 ){ - return SQLITE_ERROR; - } - } + if( vectorIdxParamsGetF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID) == 0 ){ if( vectorIdxParamsPutF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID, VECTOR_PRUNING_ALPHA_DEFAULT) != 0 ){ return SQLITE_ERROR; @@ -212355,6 +212601,83 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){ return rc; } +/************************************************************************** +** Generic utilities +**************************************************************************/ + +int initVectorPair(int nodeType, int edgeType, int dims, VectorPair *pPair){ + pPair->nodeType = nodeType; + pPair->edgeType = edgeType; + pPair->pNode = NULL; + pPair->pEdge = NULL; + if( pPair->nodeType == pPair->edgeType ){ + return 0; + } + pPair->pEdge = vectorAlloc(edgeType, dims); + if( pPair->pEdge == NULL ){ + return SQLITE_NOMEM_BKPT; + } + return 0; +} + +void loadVectorPair(VectorPair *pPair, const Vector *pVector){ + pPair->pNode = (Vector*)pVector; + if( pPair->edgeType != pPair->nodeType ){ + vectorConvert(pPair->pNode, pPair->pEdge); + }else{ + pPair->pEdge = pPair->pNode; + } +} + +void deinitVectorPair(VectorPair *pPair) { + if( pPair->pEdge != NULL && pPair->pNode != pPair->pEdge ){ + vectorFree(pPair->pEdge); + } +} + +int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){ + int i; +#ifdef SQLITE_DEBUG + for(i = 0; i < nSize - 1; i++){ + assert(aDistances[i] <= aDistances[i + 1]); + } +#endif + for(i = 0; i < nSize; i++){ + if( distance < aDistances[i] ){ + return i; + } + } + return nSize < nMaxSize ? 
nSize : -1; +} + +void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { + int itemsToMove; + + assert( nMaxSize > 0 && nItemSize > 0 ); + assert( nSize <= nMaxSize ); + assert( 0 <= iInsert && iInsert <= nSize && iInsert < nMaxSize ); + + if( nSize == nMaxSize ){ + if( pLast != NULL ){ + memcpy(pLast, aBuffer + (nSize - 1) * nItemSize, nItemSize); + } + nSize--; + } + itemsToMove = nSize - iInsert; + memmove(aBuffer + (iInsert + 1) * nItemSize, aBuffer + iInsert * nItemSize, itemsToMove * nItemSize); + memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); +} + +void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { + int itemsToMove; + + assert( nItemSize > 0 ); + assert( 0 <= iDelete && iDelete < nSize ); + + itemsToMove = nSize - iDelete - 1; + memmove(aBuffer + iDelete * nItemSize, aBuffer + (iDelete + 1) * nItemSize, itemsToMove * nItemSize); +} + /************************************************************************** ** DiskANN internals **************************************************************************/ @@ -212391,16 +212714,24 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ sqlite3_free(pNode); } -static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, unsigned int maxCandidates, int blobMode){ - pCtx->pQuery = pQuery; +static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; pCtx->maxCandidates = maxCandidates; + pCtx->aTopDistances = sqlite3_malloc(topCandidates * sizeof(double)); + pCtx->aTopCandidates = sqlite3_malloc(topCandidates * sizeof(DiskAnnNode*)); + pCtx->nTopCandidates = 0; + pCtx->maxTopCandidates = topCandidates; pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL ){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + goto out_oom; + } + loadVectorPair(&pCtx->query, pQuery); + + if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ goto out_oom; } return SQLITE_OK; @@ -212411,6 +212742,12 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, un if( pCtx->aCandidates != NULL ){ sqlite3_free(pCtx->aCandidates); } + if( pCtx->aTopDistances != NULL ){ + sqlite3_free(pCtx->aTopDistances); + } + if( pCtx->aTopCandidates != NULL ){ + sqlite3_free(pCtx->aTopCandidates); + } return SQLITE_NOMEM_BKPT; } @@ -212434,6 +212771,9 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){ } sqlite3_free(pCtx->aCandidates); sqlite3_free(pCtx->aDistances); + sqlite3_free(pCtx->aTopCandidates); + sqlite3_free(pCtx->aTopDistances); + deinitVectorPair(&pCtx->query); } // check if we visited this node earlier @@ -212475,7 +212815,9 @@ static int diskAnnSearchCtxShouldAddCandidate(const DiskAnnIndex *pIndex, const } // mark node as visited and put it in the head of visitedList -static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode){ +static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode, float distance){ + int iInsert; + assert( pCtx->nUnvisited > 0 ); assert( pNode->visited == 0 ); @@ -212484,56 
+212826,51 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo pNode->pNext = pCtx->visitedList; pCtx->visitedList = pNode; + + iInsert = distanceBufferInsertIdx(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, distance); + if( iInsert < 0 ){ + return; + } + bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); + bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } static int diskAnnSearchCtxHasUnvisited(const DiskAnnSearchCtx *pCtx){ return pCtx->nUnvisited > 0; } -static DiskAnnNode* diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i){ +static void diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i, DiskAnnNode **ppNode, float *pDistance){ assert( 0 <= i && i < pCtx->nCandidates ); - return pCtx->aCandidates[i]; + *ppNode = pCtx->aCandidates[i]; + *pDistance = pCtx->aDistances[i]; } static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete){ int i; - assert( 0 <= iDelete && iDelete < pCtx->nCandidates ); assert( pCtx->nUnvisited > 0 ); assert( !pCtx->aCandidates[iDelete]->visited ); assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); + bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); - for(i = iDelete + 1; i < pCtx->nCandidates; i++){ - pCtx->aCandidates[i - 1] = pCtx->aCandidates[i]; - pCtx->aDistances[i - 1] = pCtx->aDistances[i]; - } pCtx->nCandidates--; pCtx->nUnvisited--; } -static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float candidateDist){ - int i; - assert( 0 <= iInsert && iInsert <= pCtx->nCandidates && iInsert < pCtx->maxCandidates ); - if( pCtx->nCandidates < pCtx->maxCandidates ){ - pCtx->nCandidates++; - } else { - DiskAnnNode *pLast = pCtx->aCandidates[pCtx->nCandidates - 1]; - if( !pLast->visited ){ - // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node - assert( pLast->pBlobSpot == NULL ); - pCtx->nUnvisited--; - diskAnnNodeFree(pLast); - } - } - // Shift the candidates to the right to make space for the new one. - for(i = pCtx->nCandidates - 1; i > iInsert; i--){ - pCtx->aCandidates[i] = pCtx->aCandidates[i - 1]; - pCtx->aDistances[i] = pCtx->aDistances[i - 1]; - } - // Insert the new candidate. 
- pCtx->aCandidates[iInsert] = pCandidate; - pCtx->aDistances[iInsert] = candidateDist; +static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ + DiskAnnNode *pLast = NULL; + bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); + bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); + if( pLast != NULL && !pLast->visited ){ + // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node + assert( pLast->pBlobSpot == NULL ); + pCtx->nUnvisited--; + diskAnnNodeFree(pLast); + } pCtx->nUnvisited++; } @@ -212563,7 +212900,14 @@ static int diskAnnSearchCtxFindClosestCandidateIdx(const DiskAnnSearchCtx *pCtx) // return position for new edge(C) which will replace previous edge on that position or -1 if we should ignore it // we also check that no current edge(B) will "prune" new vertex: i.e. dist(B, C) >= (means worse than) alpha * dist(node, C) for all current edges // if any edge(B) will "prune" new edge(C) we will ignore it (return -1) -static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, u64 newRowid, const Vector *pNewVector) { +static int diskAnnReplaceEdgeIdx( + const DiskAnnIndex *pIndex, + BlobSpot *pNodeBlob, + u64 newRowid, + VectorPair *pNewVector, + VectorPair *pPlaceholder, + float *pNodeToNew +) { int i, nEdges, nMaxEdges, iReplace = -1; Vector nodeVector, edgeVector; float nodeToNew, nodeToReplace; @@ -212571,20 +212915,27 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob nEdges = nodeBinEdges(pIndex, pNodeBlob); nMaxEdges = nodeEdgesMaxCount(pIndex); nodeBinVector(pIndex, pNodeBlob, &nodeVector); - nodeToNew = diskAnnVectorDistance(pIndex, &nodeVector, pNewVector); + loadVectorPair(pPlaceholder, &nodeVector); + + // we need to evaluate potentially approximate distance here in order to correctly compare it with edge distances + nodeToNew = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, pNewVector->pEdge); + *pNodeToNew = nodeToNew; for(i = nEdges - 1; i >= 0; i--){ u64 edgeRowid; float edgeToNew, nodeToEdge; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( edgeRowid == newRowid ){ // deletes can leave "zombie" edges in the graph and we must override them and not store duplicate edges in the node return i; } - edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector); - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + + edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector->pEdge); if( nodeToNew > pIndex->pruningAlpha * edgeToNew ){ return -1; } @@ -212602,12 +212953,14 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob // prune edges after we inserted new edge at position iInserted // we only need to check for edges which will be pruned by new vertex // no need to check for other pairs as we checked them on previous insertions -static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted) { +static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, 
int iInserted, VectorPair *pPlaceholder) { int i, s, nEdges; - Vector nodeVector, hintVector; + Vector nodeVector, hintEdgeVector; u64 hintRowid; nodeBinVector(pIndex, pNodeBlob, &nodeVector); + loadVectorPair(pPlaceholder, &nodeVector); + nEdges = nodeBinEdges(pIndex, pNodeBlob); assert( 0 <= iInserted && iInserted < nEdges ); @@ -212617,7 +212970,7 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i nodeBinDebug(pIndex, pNodeBlob); #endif - nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintVector); + nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, NULL, &hintEdgeVector); // remove edges which is no longer interesting due to the addition of iInserted i = 0; @@ -212625,14 +212978,17 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i Vector edgeVector; float nodeToEdge, hintToEdge; u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( hintRowid == edgeRowid ){ i++; continue; } - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); - hintToEdge = diskAnnVectorDistance(pIndex, &hintVector, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + + hintToEdge = diskAnnVectorDistance(pIndex, &hintEdgeVector, &edgeVector); if( nodeToEdge > pIndex->pruningAlpha * hintToEdge ){ nodeBinDeleteEdge(pIndex, pNodeBlob, i); nEdges--; @@ -212681,7 +213037,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u } nodeBinVector(pIndex, start->pBlobSpot, &startVector); - startDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &startVector); + startDistance = diskAnnVectorDistance(pIndex, pCtx->query.pNode, &startVector); if( pCtx->blobMode == DISKANN_BLOB_READONLY ){ assert( start->pBlobSpot != NULL ); @@ -212698,8 +213054,9 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u Vector vCandidate; DiskAnnNode *pCandidate; BlobSpot *pCandidateBlob; + float distance; int iCandidate = diskAnnSearchCtxFindClosestCandidateIdx(pCtx); - pCandidate = diskAnnSearchCtxGetCandidate(pCtx, iCandidate); + diskAnnSearchCtxGetCandidate(pCtx, iCandidate, &pCandidate, &distance); rc = SQLITE_OK; if( pReusableBlobSpot != NULL ){ @@ -212727,25 +213084,30 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u goto out; } - diskAnnSearchCtxMarkVisited(pCtx, pCandidate); - nVisited += 1; DiskAnnTrace(("visiting candidate(%d): id=%lld\n", nVisited, pCandidate->nRowid)); nodeBinVector(pIndex, pCandidateBlob, &vCandidate); nEdges = nodeBinEdges(pIndex, pCandidateBlob); + // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it + if( pCtx->query.pNode != pCtx->query.pEdge ){ + distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->query.pNode); + } + + diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance); + for(i = 0; i < nEdges; i++){ u64 edgeRowid; Vector edgeVector; float edgeDistance; int iInsert; DiskAnnNode *pNewCandidate; - nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, NULL, &edgeVector); if( diskAnnSearchCtxIsVisited(pCtx, edgeRowid) || diskAnnSearchCtxHasCandidate(pCtx, edgeRowid) ){ continue; } - edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &edgeVector); + edgeDistance = diskAnnVectorDistance(pIndex, 
pCtx->query.pEdge, &edgeVector); iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance); if( iInsert < 0 ){ continue; @@ -212822,7 +213184,7 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search"); return rc; } - rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, DISKANN_BLOB_READONLY); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context"); goto out; @@ -212831,7 +213193,7 @@ int diskAnnSearch( if( rc != SQLITE_OK ){ goto out; } - nOutRows = MIN(k, ctx.nCandidates); + nOutRows = MIN(k, ctx.nTopCandidates); rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, vectorIdxKeyRowidLike(pKey)); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to allocate output rows"); @@ -212839,9 +213201,9 @@ int diskAnnSearch( } for(i = 0; i < nOutRows; i++){ if( pRows->aIntValues != NULL ){ - rc = vectorOutRowsPut(pRows, i, 0, &ctx.aCandidates[i]->nRowid, NULL); + rc = vectorOutRowsPut(pRows, i, 0, &ctx.aTopCandidates[i]->nRowid, NULL); }else{ - rc = diskAnnGetShadowRowKeys(pIndex, ctx.aCandidates[i]->nRowid, pKey, pRows, i); + rc = diskAnnGetShadowRowKeys(pIndex, ctx.aTopCandidates[i]->nRowid, pKey, pRows, i); } if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to put result in the output row"); @@ -212865,6 +213227,9 @@ int diskAnnInsert( BlobSpot *pBlobSpot = NULL; DiskAnnNode *pVisited; DiskAnnSearchCtx ctx; + VectorPair vInsert, vCandidate; + vInsert.pNode = NULL; vInsert.pEdge = NULL; + vCandidate.pNode = NULL; vCandidate.pEdge = NULL; if( pVectorInRow->pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); @@ -212877,12 +213242,24 @@ int diskAnnInsert( DiskAnnTrace(("diskAnnInset started\n")); - rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, DISKANN_BLOB_WRITABLE); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context"); return rc; } + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vInsert) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for node VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vCandidate) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for candidate VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + // note: we must select random row before we will insert new row in the shadow table rc = diskAnnSelectRandomShadowRow(pIndex, &nStartRowid); if( rc == SQLITE_DONE ){ @@ -212920,28 +213297,33 @@ int diskAnnInsert( } // first pass - add all visited nodes as a potential neighbours of new node for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ - Vector vector; + Vector nodeVector; int iReplace; + float nodeToNew; - nodeBinVector(pIndex, pVisited->pBlobSpot, &vector); - iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vector); + nodeBinVector(pIndex, pVisited->pBlobSpot, &nodeVector); + 
loadVectorPair(&vCandidate, &nodeVector); + + iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, &vector); - diskAnnPruneEdges(pIndex, pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, nodeToNew, vCandidate.pEdge); + diskAnnPruneEdges(pIndex, pBlobSpot, iReplace, &vInsert); } // second pass - add new node as a potential neighbour of all visited nodes + loadVectorPair(&vInsert, pVectorInRow->pVector); for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ int iReplace; + float nodeToNew; - iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, pVectorInRow->pVector); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, pVectorInRow->pVector); - diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, nodeToNew, vInsert.pEdge); + diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace, &vCandidate); rc = blobSpotFlush(pIndex, pVisited->pBlobSpot); if( rc != SQLITE_OK ){ @@ -212952,6 +213334,8 @@ int diskAnnInsert( rc = SQLITE_OK; out: + deinitVectorPair(&vInsert); + deinitVectorPair(&vCandidate); if( rc == SQLITE_OK ){ rc = blobSpotFlush(pIndex, pBlobSpot); if( rc != SQLITE_OK ){ @@ -213003,7 +213387,7 @@ int diskAnnDelete( nNeighbours = nodeBinEdges(pIndex, pNodeBlob); for(i = 0; i < nNeighbours; i++){ u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL, NULL); rc = blobSpotReload(pIndex, pEdgeBlob, edgeRowid, pIndex->nBlockSize); if( rc == DISKANN_ROW_NOT_FOUND ){ continue; @@ -213050,6 +213434,7 @@ int diskAnnOpenIndex( ){ DiskAnnIndex *pIndex; u64 nBlockSize; + int compressNeighbours; pIndex = sqlite3DbMallocRaw(db, sizeof(DiskAnnIndex)); if( pIndex == NULL ){ return SQLITE_NOMEM; @@ -213096,11 +213481,20 @@ int diskAnnOpenIndex( pIndex->searchL = VECTOR_SEARCH_L_DEFAULT; } pIndex->nNodeVectorSize = vectorDataSize(pIndex->nNodeVectorType, pIndex->nVectorDims); - // will change in future when we will support compression of edges vectors - pIndex->nEdgeVectorType = pIndex->nNodeVectorType; - pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + + compressNeighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( compressNeighbours == 0 ){ + pIndex->nEdgeVectorType = pIndex->nNodeVectorType; + pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + pIndex->nEdgeVectorType = compressNeighbours; + pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); + }else{ + return SQLITE_ERROR; + } *ppIndex = pIndex; + DiskAnnTrace(("opened index %s: max edges %d\n", zIdxName, nodeEdgesMaxCount(pIndex))); return SQLITE_OK; } @@ -213216,26 +213610,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -size_t vectorF32DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - float *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT32; - pVector->dims = nBlobSize / sizeof(float); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 0 || pBlob[nBlobSize - 1] == 
VECTOR_TYPE_FLOAT32 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF32(pBlob); - pBlob += sizeof(float); - } - return vectorDataSize(pVector->type, pVector->dims); -} - void vectorF32Serialize( sqlite3_context *context, const Vector *pVector @@ -213342,32 +213716,22 @@ void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF32ParseSqliteBlob( - sqlite3_value *arg, +void vectorF32DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; float *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(float) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f32 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(float) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); pBlob += sizeof(float); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ @@ -213474,57 +213838,6 @@ size_t vectorF64SerializeToBlob( return sizeof(double) * pVector->dims; } -size_t vectorF64DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - double *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT64; - pVector->dims = nBlobSize / sizeof(double); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 1 && pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT64 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF64(pBlob); - pBlob += sizeof(double); - } - return vectorDataSize(pVector->type, pVector->dims); -} - -void vectorF64Serialize( - sqlite3_context *context, - const Vector *pVector -){ - double *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT64 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - // allocate one extra trailing byte with vector blob type metadata - nBlobSize = vectorDataSize(pVector->type, pVector->dims) + 1; - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF64SerializeToBlob(pVector, pBlob, nBlobSize - 1); - pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_DOUBLE_CHAR_LIMIT 32 void vectorF64MarshalToText( sqlite3_context *context, @@ -213603,32 +213916,22 @@ void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF64ParseSqliteBlob( - sqlite3_value *arg, +void vectorF64DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; double *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(double) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f64 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * 
sizeof(double) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); pBlob += sizeof(double); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ @@ -214033,13 +214336,14 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, - { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, - { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, - { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, + { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, + { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, + { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, }; static int parseVectorIdxParam(const char *zParam, VectorIdxParams *pParams, const char **pErrMsg) { @@ -214439,7 +214743,7 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int i, rc = SQLITE_OK; int dims, type; int hasLibsqlVectorIdxFn = 0, hasCollation = 0; - const char *pzErrMsg; + const char *pzErrMsg = NULL; assert( zDbSName != NULL ); @@ -214551,9 +214855,13 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: unsupported for tables without ROWID and composite primary key"); return CREATE_FAIL; } - rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams); + rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams, &pzErrMsg); if( rc != SQLITE_OK ){ - sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + if( pzErrMsg != NULL ){ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann: %s", pzErrMsg); + }else{ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + } return CREATE_FAIL; } rc = insertIndexParameters(db, zDbSName, pIdx->zName, &idxParams); diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index d7587cc38b..eff02cf87a 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -85246,6 +85246,7 @@ typedef u32 VectorDims; */ #define VECTOR_TYPE_FLOAT32 1 #define VECTOR_TYPE_FLOAT64 2 +#define VECTOR_TYPE_1BIT 3 #define VECTOR_FLAGS_STATIC 1 @@ -85270,8 +85271,9 @@ void vectorInit(Vector *, VectorType, VectorDims, void *); * Dumps vector on the console (used only for debugging) */ void vectorDump (const Vector *v); -void vectorF32Dump(const Vector *v); -void vectorF64Dump(const Vector *v); +void vectorF32Dump (const Vector *v); +void vectorF64Dump (const Vector *v); +void vector1BitDump(const Vector *v); /* * Converts vector to the text representation and write the result to the sqlite3_context @@ -85283,16 +85285,10 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *); /* * Serializes vector to the blob in little-endian format according to the IEEE-754 standard */ -size_t vectorSerializeToBlob (const Vector *, 
unsigned char *, size_t); -size_t vectorF32SerializeToBlob(const Vector *, unsigned char *, size_t); -size_t vectorF64SerializeToBlob(const Vector *, unsigned char *, size_t); - -/* - * Deserializes vector from the blob in little-endian format according to the IEEE-754 standard -*/ -size_t vectorDeserializeFromBlob (Vector *, const unsigned char *, size_t); -size_t vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); -size_t vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); +size_t vectorSerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t); +size_t vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t); /* * Calculates cosine distance between two vectors (vector must have same type and same dimensions) @@ -85301,6 +85297,11 @@ float vectorDistanceCos (const Vector *, const Vector *); float vectorF32DistanceCos (const Vector *, const Vector *); double vectorF64DistanceCos(const Vector *, const Vector *); +/* + * Calculates hamming distance between two 1-bit vectors (vector must have same dimensions) +*/ +int vector1BitDistanceHamming(const Vector *, const Vector *); + /* * Calculates L2 distance between two vectors (vector must have same type and same dimensions) */ @@ -85313,22 +85314,23 @@ double vectorF64DistanceL2(const Vector *, const Vector *); * LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob * By default, blob with even length will be treated as a f32 blob */ -void vectorSerialize (sqlite3_context *, const Vector *); -void vectorF32Serialize(sqlite3_context *, const Vector *); -void vectorF64Serialize(sqlite3_context *, const Vector *); +void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); -int vectorF32ParseSqliteBlob(sqlite3_value *, Vector *, char **); -int vectorF64ParseSqliteBlob(sqlite3_value *, Vector *, char **); -void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t); +void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); +void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); + +void vectorInitStatic(Vector *, VectorType, VectorDims, void *); void vectorInitFromBlob(Vector *, const unsigned char *, size_t); void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t); void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t); +void vectorConvert(const Vector *, Vector *); + /* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */ int detectVectorParameters(sqlite3_value *, int, int *, int *, char **); @@ -85410,10 +85412,10 @@ int nodeEdgesMetadataOffset(const DiskAnnIndex *pIndex); void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Vector *pVector); void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector); u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector); +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *distance, Vector *pVector); int nodeBinEdgeFindIdx(const 
DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u64 nRowid); void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPruned); -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector); +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float distance, Vector *pVector); void nodeBinDeleteEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iDelete); void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot); @@ -85437,43 +85439,47 @@ typedef u8 MetricType; */ /* format version which can help to upgrade vector on-disk format without breaking older version of the db */ -#define VECTOR_FORMAT_PARAM_ID 1 +#define VECTOR_FORMAT_PARAM_ID 1 /* - * 1 - initial version + * 1 - v1 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u64 unused ] [u64 edge rowid] ] ... + * 2 - v2 version; node block format: [node meta] [node vector] [edge vectors] ... [ [u32 unused] [f32 distance] [u64 edge rowid] ] ... */ -#define VECTOR_FORMAT_DEFAULT 1 +#define VECTOR_FORMAT_V1 1 +#define VECTOR_FORMAT_DEFAULT 2 /* type of the vector index */ -#define VECTOR_INDEX_TYPE_PARAM_ID 2 -#define VECTOR_INDEX_TYPE_DISKANN 1 +#define VECTOR_INDEX_TYPE_PARAM_ID 2 +#define VECTOR_INDEX_TYPE_DISKANN 1 /* type of the underlying vector for the vector index */ -#define VECTOR_TYPE_PARAM_ID 3 +#define VECTOR_TYPE_PARAM_ID 3 /* dimension of the underlying vector for the vector index */ -#define VECTOR_DIM_PARAM_ID 4 +#define VECTOR_DIM_PARAM_ID 4 /* metric type used for comparing two vectors */ -#define VECTOR_METRIC_TYPE_PARAM_ID 5 -#define VECTOR_METRIC_TYPE_COS 1 -#define VECTOR_METRIC_TYPE_L2 2 +#define VECTOR_METRIC_TYPE_PARAM_ID 5 +#define VECTOR_METRIC_TYPE_COS 1 +#define VECTOR_METRIC_TYPE_L2 2 /* block size */ -#define VECTOR_BLOCK_SIZE_PARAM_ID 6 -#define VECTOR_BLOCK_SIZE_DEFAULT 128 +#define VECTOR_BLOCK_SIZE_PARAM_ID 6 +#define VECTOR_BLOCK_SIZE_DEFAULT 128 + +#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 +#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 -#define VECTOR_PRUNING_ALPHA_PARAM_ID 7 -#define VECTOR_PRUNING_ALPHA_DEFAULT 1.2 +#define VECTOR_INSERT_L_PARAM_ID 8 +#define VECTOR_INSERT_L_DEFAULT 70 -#define VECTOR_INSERT_L_PARAM_ID 8 -#define VECTOR_INSERT_L_DEFAULT 70 +#define VECTOR_SEARCH_L_PARAM_ID 9 +#define VECTOR_SEARCH_L_DEFAULT 200 -#define VECTOR_SEARCH_L_PARAM_ID 9 -#define VECTOR_SEARCH_L_DEFAULT 200 +#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 -#define VECTOR_MAX_NEIGHBORS_PARAM_ID 10 +#define VECTOR_COMPRESS_NEIGHBORS_PARAM_ID 11 /* total amount of vector index parameters */ -#define VECTOR_PARAM_IDS_COUNT 9 +#define VECTOR_PARAM_IDS_COUNT 11 /* * Vector index parameters are stored in simple binary format (1 byte tag + 8 byte u64 integer / f64 float) @@ -85555,7 +85561,7 @@ int vectorOutRowsPut(VectorOutRows *, int, int, const u64 *, sqlite3_value *); void vectorOutRowsGet(sqlite3_context *, const VectorOutRows *, int, int); void vectorOutRowsFree(sqlite3 *, VectorOutRows *); -int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *); +int diskAnnCreateIndex(sqlite3 *, const char *, const char *, const VectorIdxKey *, VectorIdxParams *, const char **); int diskAnnClearIndex(sqlite3 *, const char *, const char *); int diskAnnDropIndex(sqlite3 *, const char *, const char *); int diskAnnOpenIndex(sqlite3 *, const char *, const char *, const VectorIdxParams *, DiskAnnIndex **); @@ -210981,6 +210987,8 @@ 
size_t vectorDataSize(VectorType type, VectorDims dims){ return dims * sizeof(float); case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); + case VECTOR_TYPE_1BIT: + return (dims + 7) / 8; default: assert(0); } @@ -211012,10 +211020,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){ ** Note that the vector object points to the blob so if ** you free the blob, the vector becomes invalid. **/ -void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){ - pVector->type = type; +void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){ pVector->flags = VECTOR_FLAGS_STATIC; - vectorInitFromBlob(pVector, pBlob, nBlobSize); + pVector->type = type; + pVector->dims = dims; + pVector->data = pBlob; } /* @@ -211051,6 +211060,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){ return vectorF32DistanceCos(pVector1, pVector2); case VECTOR_TYPE_FLOAT64: return vectorF64DistanceCos(pVector1, pVector2); + case VECTOR_TYPE_1BIT: + return vector1BitDistanceHamming(pVector1, pVector2); default: assert(0); } @@ -211192,11 +211203,29 @@ int vectorParseSqliteBlob( Vector *pVector, char **pzErrMsg ){ + const unsigned char *pBlob; + size_t nBlobSize; + + assert( sqlite3_value_type(arg) == SQLITE_BLOB ); + + pBlob = sqlite3_value_blob(arg); + nBlobSize = sqlite3_value_bytes(arg); + if( nBlobSize % 2 == 1 ){ + nBlobSize--; + } + + if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){ + *pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize); + return SQLITE_ERROR; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - return vectorF32ParseSqliteBlob(arg, pVector, pzErrMsg); + vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); + return 0; case VECTOR_TYPE_FLOAT64: - return vectorF64ParseSqliteBlob(arg, pVector, pzErrMsg); + vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + return 0; default: assert(0); } @@ -211303,6 +211332,9 @@ void vectorDump(const Vector *pVector){ case VECTOR_TYPE_FLOAT64: vectorF64Dump(pVector); break; + case VECTOR_TYPE_1BIT: + vector1BitDump(pVector); + break; default: assert(0); } @@ -211324,20 +211356,47 @@ void vectorMarshalToText( } } -void vectorSerialize( +void vectorSerializeWithType( sqlite3_context *context, const Vector *pVector ){ + unsigned char *pBlob; + size_t nBlobSize, nDataSize; + + assert( pVector->dims <= MAX_VECTOR_SZ ); + + nDataSize = vectorDataSize(pVector->type, pVector->dims); + nBlobSize = nDataSize; + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + nBlobSize += (nBlobSize % 2 == 0 ? 
1 : 2); + } + + if( nBlobSize == 0 ){ + sqlite3_result_zeroblob(context, 0); + return; + } + + pBlob = sqlite3_malloc64(nBlobSize); + if( pBlob == NULL ){ + sqlite3_result_error_nomem(context); + return; + } + + if( pVector->type != VECTOR_TYPE_FLOAT32 ){ + pBlob[nBlobSize - 1] = pVector->type; + } + switch (pVector->type) { case VECTOR_TYPE_FLOAT32: - vectorF32Serialize(context, pVector); + vectorF32SerializeToBlob(pVector, pBlob, nDataSize); break; case VECTOR_TYPE_FLOAT64: - vectorF64Serialize(context, pVector); + vectorF64SerializeToBlob(pVector, pBlob, nDataSize); break; default: assert(0); } + sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); } size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){ @@ -211346,18 +211405,8 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); case VECTOR_TYPE_FLOAT64: return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize); - default: - assert(0); - } - return 0; -} - -size_t vectorDeserializeFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){ - switch (pVector->type) { - case VECTOR_TYPE_FLOAT32: - return vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize); - case VECTOR_TYPE_FLOAT64: - return vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize); + case VECTOR_TYPE_1BIT: + return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize); default: assert(0); } @@ -211377,6 +211426,29 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo } } +void vectorConvert(const Vector *pFrom, Vector *pTo){ + int i; + u8 *bitData; + float *floatData; + + assert( pFrom->dims == pTo->dims ); + + if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){ + floatData = pFrom->data; + bitData = pTo->data; + for(i = 0; i < pFrom->dims; i += 8){ + bitData[i / 8] = 0; + } + for(i = 0; i < pFrom->dims; i++){ + if( floatData[i] > 0 ){ + bitData[i / 8] |= (1 << (i & 7)); + } + } + }else{ + assert(0); + } +} + /************************************************************************** ** SQL function implementations ****************************************************************************/ @@ -211410,7 +211482,7 @@ static void vectorFuncHintedType( sqlite3_free(pzErrMsg); goto out_free_vec; } - vectorSerialize(context, pVector); + vectorSerializeWithType(context, pVector); out_free_vec: vectorFree(pVector); } @@ -211557,6 +211629,135 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #endif /* !defined(SQLITE_OMIT_VECTOR) */ /************** End of vector.c **********************************************/ +/************** Begin file vector1bit.c **************************************/ +/* +** 2024-07-04 +** +** Copyright 2024 the libSQL authors +** +** Permission is hereby granted, free of charge, to any person obtaining a copy of +** this software and associated documentation files (the "Software"), to deal in +** the Software without restriction, including without limitation the rights to +** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +** the Software, and to permit persons to whom the Software is furnished to do so, +** subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in all +** copies or substantial portions of the Software. 
+**
+** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+**
+******************************************************************************
+**
+** 1-bit vector format utilities.
+*/
+#ifndef SQLITE_OMIT_VECTOR
+/* #include "sqliteInt.h" */
+
+/* #include "vectorInt.h" */
+
+/* #include */
+
+/**************************************************************************
+** Utility routines for debugging
+**************************************************************************/
+
+void vector1BitDump(const Vector *pVec){
+  u8 *elems = pVec->data;
+  unsigned i;
+
+  assert( pVec->type == VECTOR_TYPE_1BIT );
+
+  for(i = 0; i < pVec->dims; i++){
+    printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1);
+  }
+  printf("\n");
+}
+
+/**************************************************************************
+** Utility routines for vector serialization and deserialization
+**************************************************************************/
+
+size_t vector1BitSerializeToBlob(
+  const Vector *pVector,
+  unsigned char *pBlob,
+  size_t nBlobSize
+){
+  u8 *elems = pVector->data;
+  u8 *pPtr = pBlob;
+  unsigned i;
+
+  assert( pVector->type == VECTOR_TYPE_1BIT );
+  assert( pVector->dims <= MAX_VECTOR_SZ );
+  assert( nBlobSize >= (pVector->dims + 7) / 8 );
+
+  for(i = 0; i < (pVector->dims + 7) / 8; i++){
+    pPtr[i] = elems[i];
+  }
+  return (pVector->dims + 7) / 8;
+}
+
+// [sum(map(int, bin(i)[2:])) for i in range(256)]
+static int BitsCount[256] = {
+  0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
+  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
+  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
+  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
+  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
+  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
+  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
+  4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8,
+};
+
+static inline int sqlite3PopCount32(u32 a){
+#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER)
+  return __builtin_popcount(a);
+#else
+  return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff];
+#endif
+}
+
+int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){
+  int diff = 0;
+  u8 *e1U8 = v1->data;
+  u32 *e1U32 = v1->data;
+  u8 *e2U8 = v2->data;
+  u32 *e2U32 = v2->data;
+  int i, len8, len32, offset8;
+
+  assert( v1->dims == v2->dims );
+  assert( v1->type == VECTOR_TYPE_1BIT );
+  assert( v2->type == VECTOR_TYPE_1BIT );
+
+  len8 = (v1->dims + 7) / 8;
+  len32 = v1->dims / 32;
+  offset8 = len32 * 4;
+
+  for(i = 0; i < len32; i++){
+    diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]);
+  }
+  for(i = offset8; i < len8; i++){
+    diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]);
+  }
+
return diff; +} + +#endif /* !defined(SQLITE_OMIT_VECTOR) */ + +/************** End of vector1bit.c ******************************************/ /************** Begin file vectordiskann.c ***********************************/ /* ** 2024-03-23 @@ -211607,13 +211808,14 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ ** diskAnnInsert() Insert single new(!) vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ +/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "math.h" */ /* #include "sqliteInt.h" */ /* #include "vectorIndexInt.h" */ -#define SQLITE_VECTOR_TRACE +// #define SQLITE_VECTOR_TRACE #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE) #define DiskAnnTrace(X) sqlite3DebugPrintf X; #else @@ -211639,9 +211841,18 @@ SQLITE_PRIVATE void sqlite3RegisterVectorFunctions(void){ #define VECTOR_NODE_METADATA_SIZE (sizeof(u64) + sizeof(u16)) #define VECTOR_EDGE_METADATA_SIZE (sizeof(u64) + sizeof(u64)) +typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +struct VectorPair { + int nodeType; + int edgeType; + Vector *pNode; + Vector *pEdge; +}; + // DiskAnnNode represents single node in the DiskAnn graph struct DiskAnnNode { u64 nRowid; /* node id */ @@ -211657,14 +211868,18 @@ struct DiskAnnNode { * so caller which puts nodes in the context can forget about resource managmenet (context will take care of this) */ struct DiskAnnSearchCtx { - const Vector *pQuery; /* initial query vector; user query for SELECT and row vector for INSERT */ - DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */ - double *aDistances; /* array of distances to the query vector */ - unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ - unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ - DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */ - unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */ - int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */ + VectorPair query; /* initial query vector; user query for SELECT and row vector for INSERT */ + DiskAnnNode **aCandidates; /* array of unvisited candidates ordered by distance (possibly approximate) to the query (ascending) */ + float *aDistances; /* array of distances (possible approximate) to the query vector */ + unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */ + unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */ + DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */ + float *aTopDistances; /* top candidates exact distances */ + int nTopCandidates; /* current size of aTopCandidates/aTopDistances arrays */ + int maxTopCandidates; /* max size of aTopCandidates/aTopDistances arrays */ + DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */ + unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */ + int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - 
otherwise */ }; /************************************************************************** @@ -211675,6 +211890,10 @@ static inline u16 readLE16(const unsigned char *p){ return (u16)p[0] | (u16)p[1] << 8; } +static inline u32 readLE32(const unsigned char *p){ + return (u32)p[0] | (u32)p[1] << 8 | (u32)p[2] << 16 | (u32)p[3] << 24; +} + static inline u64 readLE64(const unsigned char *p){ return (u64)p[0] | (u64)p[1] << 8 @@ -211691,6 +211910,13 @@ static inline void writeLE16(unsigned char *p, u16 v){ p[1] = v >> 8; } +static inline void writeLE32(unsigned char *p, u32 v){ + p[0] = v; + p[1] = v >> 8; + p[2] = v >> 16; + p[3] = v >> 24; +} + static inline void writeLE64(unsigned char *p, u64 v){ p[0] = v; p[1] = v >> 8; @@ -211870,7 +212096,7 @@ void nodeBinInit(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, u64 nRowid, Ve void nodeBinVector(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, Vector *pVector) { assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize <= pBlobSpot->nBufferSize ); - vectorInitStatic(pVector, pIndex->nNodeVectorType, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE, pIndex->nNodeVectorSize); + vectorInitStatic(pVector, pIndex->nNodeVectorType, pIndex->nVectorDims, pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE); } u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { @@ -211879,20 +212105,25 @@ u16 nodeBinEdges(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) { return readLE16(pBlobSpot->pBuffer + sizeof(u64)); } -void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, Vector *pVector) { +void nodeBinEdge(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, int iEdge, u64 *pRowid, float *pDistance, Vector *pVector) { + u32 distance; int offset = nodeEdgesMetadataOffset(pIndex); if( pRowid != NULL ){ assert( offset + (iEdge + 1) * VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize ); *pRowid = readLE64(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u64)); } + if( pIndex->nFormatVersion != VECTOR_FORMAT_V1 && pDistance != NULL ){ + distance = readLE32(pBlobSpot->pBuffer + offset + iEdge * VECTOR_EDGE_METADATA_SIZE + sizeof(u32)); + *pDistance = *((float*)&distance); + } if( pVector != NULL ){ assert( VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize < offset ); vectorInitStatic( pVector, pIndex->nEdgeVectorType, - pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nNodeVectorSize, - pIndex->nEdgeVectorSize + pIndex->nVectorDims, + pBlobSpot->pBuffer + VECTOR_NODE_METADATA_SIZE + pIndex->nNodeVectorSize + iEdge * pIndex->nEdgeVectorSize ); } } @@ -211902,7 +212133,7 @@ int nodeBinEdgeFindIdx(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot, u6 // todo: if edges will be sorted by identifiers we can use binary search here (although speed up will be visible only on pretty loaded nodes: >128 edges) for(i = 0; i < nEdges; i++){ u64 edgeId; - nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL); + nodeBinEdge(pIndex, pBlobSpot, i, &edgeId, NULL, NULL); if( edgeId == nRowid ){ return i; } @@ -211917,7 +212148,7 @@ void nodeBinPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int nPru } // replace edge at position iReplace or add new one if iReplace == nEdges -void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, Vector *pVector) { +void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iReplace, u64 nRowid, float 
distance, Vector *pVector) {
   int nMaxEdges = nodeEdgesMaxCount(pIndex);
   int nEdges = nodeBinEdges(pIndex, pBlobSpot);
   int edgeVectorOffset, edgeMetaOffset, itemsToMove;
@@ -211936,6 +212167,7 @@ void nodeBinReplaceEdge(const DiskAnnIndex *pIndex, BlobSpot *pBlobSpot, int iRe
   assert( edgeMetaOffset + VECTOR_EDGE_METADATA_SIZE <= pBlobSpot->nBufferSize );
 
   vectorSerializeToBlob(pVector, pBlobSpot->pBuffer + edgeVectorOffset, pIndex->nEdgeVectorSize);
+  writeLE32(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u32), *((u32*)&distance));
   writeLE64(pBlobSpot->pBuffer + edgeMetaOffset + sizeof(u64), nRowid);
   writeLE16(pBlobSpot->pBuffer + sizeof(u64), nEdges);
@@ -211970,6 +212202,7 @@ void nodeBinDebug(const DiskAnnIndex *pIndex, const BlobSpot *pBlobSpot) {
 #if defined(SQLITE_DEBUG) && defined(SQLITE_VECTOR_TRACE)
   int nEdges, nMaxEdges, i;
   u64 nRowid;
+  float distance = 0;
   Vector vector;
 
   nEdges = nodeBinEdges(pIndex, pBlobSpot);
@@ -211980,8 +212213,8 @@
   DiskAnnTrace((" nEdges=%d, nMaxEdges=%d, vector=", nEdges, nMaxEdges));
   vectorDump(&vector);
   for(i = 0; i < nEdges; i++){
-    nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &vector);
-    DiskAnnTrace((" to=%lld, vector=", nRowid, nRowid));
+    nodeBinEdge(pIndex, pBlobSpot, i, &nRowid, &distance, &vector);
+    DiskAnnTrace((" to=%lld, distance=%f, vector=", nRowid, distance));
     vectorDump(&vector);
   }
 #endif
@@ -211996,10 +212229,11 @@ int diskAnnCreateIndex(
   const char *zDbSName,
   const char *zIdxName,
   const VectorIdxKey *pKey,
-  VectorIdxParams *pParams
+  VectorIdxParams *pParams,
+  const char **pzErrMsg
 ){
   int rc;
-  int type, dims;
+  int type, dims, metric, neighbours;
   u64 maxNeighborsParam, blockSizeBytes;
   char *zSql;
   char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...)
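The hunk below sizes each DiskAnn node block from the node vector type and the (possibly compressed) edge vector type. As a worked illustration of that arithmetic, here is a minimal self-contained sketch; it assumes vectorDataSize() behaves as defined earlier (dims * sizeof(float) for f32, (dims + 7) / 8 for 1-bit) and folds the node and edge metadata sizes in directly, so the helper name and the numbers are illustrative, not part of the patch:

#include <stdint.h>
#include <stdio.h>

/* Illustrative stand-in for vectorDataSize(): payload bytes for f32 vs 1-bit. */
static size_t dataSize(int is1bit, size_t dims){
  return is1bit ? (dims + 7) / 8 : dims * sizeof(float);
}

int main(void){
  size_t dims = 512;
  size_t nodeMeta = sizeof(uint64_t) + sizeof(uint16_t); /* rowid + edge count   */
  size_t edgeMeta = sizeof(uint64_t) + sizeof(uint64_t); /* per-edge metadata    */
  size_t maxNeighbors = 69;                              /* 3*((int)sqrt(512)+1) */
  /* f32 node vector plus 1-bit compressed edge vectors: */
  size_t blockSize = nodeMeta + dataSize(0, dims)
                   + maxNeighbors * (dataSize(1, dims) + edgeMeta);
  printf("approx block size: %zu bytes\n", blockSize); /* 64-byte edges, not 2048 */
  return 0;
}

With uncompressed f32 edges the same formula gives roughly 144 KB per block, which is why 1-bit edge compression also raises the affordable max_neighbors in the heuristic below.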
@@ -212023,24 +212257,36 @@ int diskAnnCreateIndex(
   }
   assert( 0 < dims && dims <= MAX_VECTOR_SZ );
 
+  metric = vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID);
+  if( metric == 0 ){
+    metric = VECTOR_METRIC_TYPE_COS;
+    if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, metric) != 0 ){
+      return SQLITE_ERROR;
+    }
+  }
+  neighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID);
+  if( neighbours == VECTOR_TYPE_1BIT && metric != VECTOR_METRIC_TYPE_COS ){
+    *pzErrMsg = "1-bit compression available only for cosine metric";
+    return SQLITE_ERROR;
+  }
+  if( neighbours == 0 ){
+    neighbours = type;
+  }
+
   maxNeighborsParam = vectorIdxParamsGetU64(pParams, VECTOR_MAX_NEIGHBORS_PARAM_ID);
   if( maxNeighborsParam == 0 ){
     // 3 D**(1/2) gives good recall values (90%+)
     // we also want to keep disk overhead at moderate level - 50x of the disk size increase is the current upper bound
-    maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(type, dims)) + 1);
+    maxNeighborsParam = MIN(3 * ((int)(sqrt(dims)) + 1), (50 * nodeOverhead(vectorDataSize(type, dims))) / nodeEdgeOverhead(vectorDataSize(neighbours, dims)) + 1);
   }
-  blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(type, dims));
+  blockSizeBytes = nodeOverhead(vectorDataSize(type, dims)) + maxNeighborsParam * (u64)nodeEdgeOverhead(vectorDataSize(neighbours, dims));
   if( blockSizeBytes > DISKANN_MAX_BLOCK_SZ ){
     return SQLITE_ERROR;
   }
   if( vectorIdxParamsPutU64(pParams, VECTOR_BLOCK_SIZE_PARAM_ID, MAX(256, blockSizeBytes)) != 0 ){
     return SQLITE_ERROR;
   }
-  if( vectorIdxParamsGetU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID) == 0 ){
-    if( vectorIdxParamsPutU64(pParams, VECTOR_METRIC_TYPE_PARAM_ID, VECTOR_METRIC_TYPE_COS) != 0 ){
-      return SQLITE_ERROR;
-    }
-  }
+
   if( vectorIdxParamsGetF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID) == 0 ){
     if( vectorIdxParamsPutF64(pParams, VECTOR_PRUNING_ALPHA_PARAM_ID, VECTOR_PRUNING_ALPHA_DEFAULT) != 0 ){
       return SQLITE_ERROR;
@@ -212355,6 +212601,83 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){
   return rc;
 }
 
+/**************************************************************************
+** Generic utilities
+**************************************************************************/
+
+int initVectorPair(int nodeType, int edgeType, int dims, VectorPair *pPair){
+  pPair->nodeType = nodeType;
+  pPair->edgeType = edgeType;
+  pPair->pNode = NULL;
+  pPair->pEdge = NULL;
+  if( pPair->nodeType == pPair->edgeType ){
+    return 0;
+  }
+  pPair->pEdge = vectorAlloc(edgeType, dims);
+  if( pPair->pEdge == NULL ){
+    return SQLITE_NOMEM_BKPT;
+  }
+  return 0;
+}
+
+void loadVectorPair(VectorPair *pPair, const Vector *pVector){
+  pPair->pNode = (Vector*)pVector;
+  if( pPair->edgeType != pPair->nodeType ){
+    vectorConvert(pPair->pNode, pPair->pEdge);
+  }else{
+    pPair->pEdge = pPair->pNode;
+  }
+}
+
+void deinitVectorPair(VectorPair *pPair) {
+  if( pPair->pEdge != NULL && pPair->pNode != pPair->pEdge ){
+    vectorFree(pPair->pEdge);
+  }
+}
+
+int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){
+  int i;
+#ifdef SQLITE_DEBUG
+  for(i = 0; i < nSize - 1; i++){
+    assert(aDistances[i] <= aDistances[i + 1]);
+  }
+#endif
+  for(i = 0; i < nSize; i++){
+    if( distance < aDistances[i] ){
+      return i;
+    }
+  }
+  return nSize < nMaxSize ? 
nSize : -1; +} + +void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { + int itemsToMove; + + assert( nMaxSize > 0 && nItemSize > 0 ); + assert( nSize <= nMaxSize ); + assert( 0 <= iInsert && iInsert <= nSize && iInsert < nMaxSize ); + + if( nSize == nMaxSize ){ + if( pLast != NULL ){ + memcpy(pLast, aBuffer + (nSize - 1) * nItemSize, nItemSize); + } + nSize--; + } + itemsToMove = nSize - iInsert; + memmove(aBuffer + (iInsert + 1) * nItemSize, aBuffer + iInsert * nItemSize, itemsToMove * nItemSize); + memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); +} + +void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { + int itemsToMove; + + assert( nItemSize > 0 ); + assert( 0 <= iDelete && iDelete < nSize ); + + itemsToMove = nSize - iDelete - 1; + memmove(aBuffer + iDelete * nItemSize, aBuffer + (iDelete + 1) * nItemSize, itemsToMove * nItemSize); +} + /************************************************************************** ** DiskANN internals **************************************************************************/ @@ -212391,16 +212714,24 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ sqlite3_free(pNode); } -static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, unsigned int maxCandidates, int blobMode){ - pCtx->pQuery = pQuery; +static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; pCtx->maxCandidates = maxCandidates; + pCtx->aTopDistances = sqlite3_malloc(topCandidates * sizeof(double)); + pCtx->aTopCandidates = sqlite3_malloc(topCandidates * sizeof(DiskAnnNode*)); + pCtx->nTopCandidates = 0; + pCtx->maxTopCandidates = topCandidates; pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL ){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + goto out_oom; + } + loadVectorPair(&pCtx->query, pQuery); + + if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ goto out_oom; } return SQLITE_OK; @@ -212411,6 +212742,12 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, un if( pCtx->aCandidates != NULL ){ sqlite3_free(pCtx->aCandidates); } + if( pCtx->aTopDistances != NULL ){ + sqlite3_free(pCtx->aTopDistances); + } + if( pCtx->aTopCandidates != NULL ){ + sqlite3_free(pCtx->aTopCandidates); + } return SQLITE_NOMEM_BKPT; } @@ -212434,6 +212771,9 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){ } sqlite3_free(pCtx->aCandidates); sqlite3_free(pCtx->aDistances); + sqlite3_free(pCtx->aTopCandidates); + sqlite3_free(pCtx->aTopDistances); + deinitVectorPair(&pCtx->query); } // check if we visited this node earlier @@ -212475,7 +212815,9 @@ static int diskAnnSearchCtxShouldAddCandidate(const DiskAnnIndex *pIndex, const } // mark node as visited and put it in the head of visitedList -static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode){ +static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode, float distance){ + int iInsert; + assert( pCtx->nUnvisited > 0 ); assert( pNode->visited == 0 ); @@ -212484,56 
+212826,51 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo pNode->pNext = pCtx->visitedList; pCtx->visitedList = pNode; + + iInsert = distanceBufferInsertIdx(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, distance); + if( iInsert < 0 ){ + return; + } + bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); + bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } static int diskAnnSearchCtxHasUnvisited(const DiskAnnSearchCtx *pCtx){ return pCtx->nUnvisited > 0; } -static DiskAnnNode* diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i){ +static void diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i, DiskAnnNode **ppNode, float *pDistance){ assert( 0 <= i && i < pCtx->nCandidates ); - return pCtx->aCandidates[i]; + *ppNode = pCtx->aCandidates[i]; + *pDistance = pCtx->aDistances[i]; } static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete){ int i; - assert( 0 <= iDelete && iDelete < pCtx->nCandidates ); assert( pCtx->nUnvisited > 0 ); assert( !pCtx->aCandidates[iDelete]->visited ); assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); + bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); - for(i = iDelete + 1; i < pCtx->nCandidates; i++){ - pCtx->aCandidates[i - 1] = pCtx->aCandidates[i]; - pCtx->aDistances[i - 1] = pCtx->aDistances[i]; - } pCtx->nCandidates--; pCtx->nUnvisited--; } -static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float candidateDist){ - int i; - assert( 0 <= iInsert && iInsert <= pCtx->nCandidates && iInsert < pCtx->maxCandidates ); - if( pCtx->nCandidates < pCtx->maxCandidates ){ - pCtx->nCandidates++; - } else { - DiskAnnNode *pLast = pCtx->aCandidates[pCtx->nCandidates - 1]; - if( !pLast->visited ){ - // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node - assert( pLast->pBlobSpot == NULL ); - pCtx->nUnvisited--; - diskAnnNodeFree(pLast); - } - } - // Shift the candidates to the right to make space for the new one. - for(i = pCtx->nCandidates - 1; i > iInsert; i--){ - pCtx->aCandidates[i] = pCtx->aCandidates[i - 1]; - pCtx->aDistances[i] = pCtx->aDistances[i - 1]; - } - // Insert the new candidate. 
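/*
** The bufferInsert()/bufferDelete() helpers above generalize the manual
** shift loops being deleted in this hunk: both candidate arrays stay sorted
** ascending by distance, and the worst entry is evicted once the buffer is
** full. A compact sketch of the same idea on a plain float array, assuming
** <string.h> for memmove (the helper name is illustrative, not from the
** patch):
**
**   static void sortedBoundedInsert(float *a, int *pn, int nMax, float v){
**     int i = 0, n = *pn;
**     while( i < n && a[i] <= v ) i++;  // first position where v belongs
**     if( i >= nMax ) return;           // worse than every entry of a full buffer
**     if( n == nMax ) n--;              // evict the current last element
**     memmove(&a[i+1], &a[i], (n - i)*sizeof(float));
**     a[i] = v;
**     *pn = n + 1;                      // size never exceeds nMax
**   }
*/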
- pCtx->aCandidates[iInsert] = pCandidate; - pCtx->aDistances[iInsert] = candidateDist; +static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ + DiskAnnNode *pLast = NULL; + bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); + bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); + if( pLast != NULL && !pLast->visited ){ + // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node + assert( pLast->pBlobSpot == NULL ); + pCtx->nUnvisited--; + diskAnnNodeFree(pLast); + } pCtx->nUnvisited++; } @@ -212563,7 +212900,14 @@ static int diskAnnSearchCtxFindClosestCandidateIdx(const DiskAnnSearchCtx *pCtx) // return position for new edge(C) which will replace previous edge on that position or -1 if we should ignore it // we also check that no current edge(B) will "prune" new vertex: i.e. dist(B, C) >= (means worse than) alpha * dist(node, C) for all current edges // if any edge(B) will "prune" new edge(C) we will ignore it (return -1) -static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, u64 newRowid, const Vector *pNewVector) { +static int diskAnnReplaceEdgeIdx( + const DiskAnnIndex *pIndex, + BlobSpot *pNodeBlob, + u64 newRowid, + VectorPair *pNewVector, + VectorPair *pPlaceholder, + float *pNodeToNew +) { int i, nEdges, nMaxEdges, iReplace = -1; Vector nodeVector, edgeVector; float nodeToNew, nodeToReplace; @@ -212571,20 +212915,27 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob nEdges = nodeBinEdges(pIndex, pNodeBlob); nMaxEdges = nodeEdgesMaxCount(pIndex); nodeBinVector(pIndex, pNodeBlob, &nodeVector); - nodeToNew = diskAnnVectorDistance(pIndex, &nodeVector, pNewVector); + loadVectorPair(pPlaceholder, &nodeVector); + + // we need to evaluate potentially approximate distance here in order to correctly compare it with edge distances + nodeToNew = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, pNewVector->pEdge); + *pNodeToNew = nodeToNew; for(i = nEdges - 1; i >= 0; i--){ u64 edgeRowid; float edgeToNew, nodeToEdge; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( edgeRowid == newRowid ){ // deletes can leave "zombie" edges in the graph and we must override them and not store duplicate edges in the node return i; } - edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector); - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + + edgeToNew = diskAnnVectorDistance(pIndex, &edgeVector, pNewVector->pEdge); if( nodeToNew > pIndex->pruningAlpha * edgeToNew ){ return -1; } @@ -212602,12 +212953,14 @@ static int diskAnnReplaceEdgeIdx(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob // prune edges after we inserted new edge at position iInserted // we only need to check for edges which will be pruned by new vertex // no need to check for other pairs as we checked them on previous insertions -static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, int iInserted) { +static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, 
int iInserted, VectorPair *pPlaceholder) { int i, s, nEdges; - Vector nodeVector, hintVector; + Vector nodeVector, hintEdgeVector; u64 hintRowid; nodeBinVector(pIndex, pNodeBlob, &nodeVector); + loadVectorPair(pPlaceholder, &nodeVector); + nEdges = nodeBinEdges(pIndex, pNodeBlob); assert( 0 <= iInserted && iInserted < nEdges ); @@ -212617,7 +212970,7 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i nodeBinDebug(pIndex, pNodeBlob); #endif - nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, &hintVector); + nodeBinEdge(pIndex, pNodeBlob, iInserted, &hintRowid, NULL, &hintEdgeVector); // remove edges which is no longer interesting due to the addition of iInserted i = 0; @@ -212625,14 +212978,17 @@ static void diskAnnPruneEdges(const DiskAnnIndex *pIndex, BlobSpot *pNodeBlob, i Vector edgeVector; float nodeToEdge, hintToEdge; u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, &nodeToEdge, &edgeVector); if( hintRowid == edgeRowid ){ i++; continue; } - nodeToEdge = diskAnnVectorDistance(pIndex, &nodeVector, &edgeVector); - hintToEdge = diskAnnVectorDistance(pIndex, &hintVector, &edgeVector); + if( pIndex->nFormatVersion == VECTOR_FORMAT_V1 ){ + nodeToEdge = diskAnnVectorDistance(pIndex, pPlaceholder->pEdge, &edgeVector); + } + + hintToEdge = diskAnnVectorDistance(pIndex, &hintEdgeVector, &edgeVector); if( nodeToEdge > pIndex->pruningAlpha * hintToEdge ){ nodeBinDeleteEdge(pIndex, pNodeBlob, i); nEdges--; @@ -212681,7 +213037,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u } nodeBinVector(pIndex, start->pBlobSpot, &startVector); - startDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &startVector); + startDistance = diskAnnVectorDistance(pIndex, pCtx->query.pNode, &startVector); if( pCtx->blobMode == DISKANN_BLOB_READONLY ){ assert( start->pBlobSpot != NULL ); @@ -212698,8 +213054,9 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u Vector vCandidate; DiskAnnNode *pCandidate; BlobSpot *pCandidateBlob; + float distance; int iCandidate = diskAnnSearchCtxFindClosestCandidateIdx(pCtx); - pCandidate = diskAnnSearchCtxGetCandidate(pCtx, iCandidate); + diskAnnSearchCtxGetCandidate(pCtx, iCandidate, &pCandidate, &distance); rc = SQLITE_OK; if( pReusableBlobSpot != NULL ){ @@ -212727,25 +213084,30 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u goto out; } - diskAnnSearchCtxMarkVisited(pCtx, pCandidate); - nVisited += 1; DiskAnnTrace(("visiting candidate(%d): id=%lld\n", nVisited, pCandidate->nRowid)); nodeBinVector(pIndex, pCandidateBlob, &vCandidate); nEdges = nodeBinEdges(pIndex, pCandidateBlob); + // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it + if( pCtx->query.pNode != pCtx->query.pEdge ){ + distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->query.pNode); + } + + diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance); + for(i = 0; i < nEdges; i++){ u64 edgeRowid; Vector edgeVector; float edgeDistance; int iInsert; DiskAnnNode *pNewCandidate; - nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, &edgeVector); + nodeBinEdge(pIndex, pCandidateBlob, i, &edgeRowid, NULL, &edgeVector); if( diskAnnSearchCtxIsVisited(pCtx, edgeRowid) || diskAnnSearchCtxHasCandidate(pCtx, edgeRowid) ){ continue; } - edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &edgeVector); + edgeDistance = diskAnnVectorDistance(pIndex, 
pCtx->query.pEdge, &edgeVector); iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance); if( iInsert < 0 ){ continue; @@ -212822,7 +213184,7 @@ int diskAnnSearch( *pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search"); return rc; } - rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, DISKANN_BLOB_READONLY); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context"); goto out; @@ -212831,7 +213193,7 @@ int diskAnnSearch( if( rc != SQLITE_OK ){ goto out; } - nOutRows = MIN(k, ctx.nCandidates); + nOutRows = MIN(k, ctx.nTopCandidates); rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, vectorIdxKeyRowidLike(pKey)); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to allocate output rows"); @@ -212839,9 +213201,9 @@ int diskAnnSearch( } for(i = 0; i < nOutRows; i++){ if( pRows->aIntValues != NULL ){ - rc = vectorOutRowsPut(pRows, i, 0, &ctx.aCandidates[i]->nRowid, NULL); + rc = vectorOutRowsPut(pRows, i, 0, &ctx.aTopCandidates[i]->nRowid, NULL); }else{ - rc = diskAnnGetShadowRowKeys(pIndex, ctx.aCandidates[i]->nRowid, pKey, pRows, i); + rc = diskAnnGetShadowRowKeys(pIndex, ctx.aTopCandidates[i]->nRowid, pKey, pRows, i); } if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(search): failed to put result in the output row"); @@ -212865,6 +213227,9 @@ int diskAnnInsert( BlobSpot *pBlobSpot = NULL; DiskAnnNode *pVisited; DiskAnnSearchCtx ctx; + VectorPair vInsert, vCandidate; + vInsert.pNode = NULL; vInsert.pEdge = NULL; + vCandidate.pNode = NULL; vCandidate.pEdge = NULL; if( pVectorInRow->pVector->dims != pIndex->nVectorDims ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): dimensions are different: %d != %d", pVectorInRow->pVector->dims, pIndex->nVectorDims); @@ -212877,12 +213242,24 @@ int diskAnnInsert( DiskAnnTrace(("diskAnnInset started\n")); - rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, DISKANN_BLOB_WRITABLE); + rc = diskAnnSearchCtxInit(pIndex, &ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE); if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context"); return rc; } + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vInsert) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for node VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &vCandidate) != 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(insert): unable to allocate mem for candidate VectorPair"); + rc = SQLITE_NOMEM_BKPT; + goto out; + } + // note: we must select random row before we will insert new row in the shadow table rc = diskAnnSelectRandomShadowRow(pIndex, &nStartRowid); if( rc == SQLITE_DONE ){ @@ -212920,28 +213297,33 @@ int diskAnnInsert( } // first pass - add all visited nodes as a potential neighbours of new node for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ - Vector vector; + Vector nodeVector; int iReplace; + float nodeToNew; - nodeBinVector(pIndex, pVisited->pBlobSpot, &vector); - iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vector); + nodeBinVector(pIndex, pVisited->pBlobSpot, &nodeVector); + 
loadVectorPair(&vCandidate, &nodeVector); + + iReplace = diskAnnReplaceEdgeIdx(pIndex, pBlobSpot, pVisited->nRowid, &vCandidate, &vInsert, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, &vector); - diskAnnPruneEdges(pIndex, pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pBlobSpot, iReplace, pVisited->nRowid, nodeToNew, vCandidate.pEdge); + diskAnnPruneEdges(pIndex, pBlobSpot, iReplace, &vInsert); } // second pass - add new node as a potential neighbour of all visited nodes + loadVectorPair(&vInsert, pVectorInRow->pVector); for(pVisited = ctx.visitedList; pVisited != NULL; pVisited = pVisited->pNext){ int iReplace; + float nodeToNew; - iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, pVectorInRow->pVector); + iReplace = diskAnnReplaceEdgeIdx(pIndex, pVisited->pBlobSpot, nNewRowid, &vInsert, &vCandidate, &nodeToNew); if( iReplace == -1 ){ continue; } - nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, pVectorInRow->pVector); - diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace); + nodeBinReplaceEdge(pIndex, pVisited->pBlobSpot, iReplace, nNewRowid, nodeToNew, vInsert.pEdge); + diskAnnPruneEdges(pIndex, pVisited->pBlobSpot, iReplace, &vCandidate); rc = blobSpotFlush(pIndex, pVisited->pBlobSpot); if( rc != SQLITE_OK ){ @@ -212952,6 +213334,8 @@ int diskAnnInsert( rc = SQLITE_OK; out: + deinitVectorPair(&vInsert); + deinitVectorPair(&vCandidate); if( rc == SQLITE_OK ){ rc = blobSpotFlush(pIndex, pBlobSpot); if( rc != SQLITE_OK ){ @@ -213003,7 +213387,7 @@ int diskAnnDelete( nNeighbours = nodeBinEdges(pIndex, pNodeBlob); for(i = 0; i < nNeighbours; i++){ u64 edgeRowid; - nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL); + nodeBinEdge(pIndex, pNodeBlob, i, &edgeRowid, NULL, NULL); rc = blobSpotReload(pIndex, pEdgeBlob, edgeRowid, pIndex->nBlockSize); if( rc == DISKANN_ROW_NOT_FOUND ){ continue; @@ -213050,6 +213434,7 @@ int diskAnnOpenIndex( ){ DiskAnnIndex *pIndex; u64 nBlockSize; + int compressNeighbours; pIndex = sqlite3DbMallocRaw(db, sizeof(DiskAnnIndex)); if( pIndex == NULL ){ return SQLITE_NOMEM; @@ -213096,11 +213481,20 @@ int diskAnnOpenIndex( pIndex->searchL = VECTOR_SEARCH_L_DEFAULT; } pIndex->nNodeVectorSize = vectorDataSize(pIndex->nNodeVectorType, pIndex->nVectorDims); - // will change in future when we will support compression of edges vectors - pIndex->nEdgeVectorType = pIndex->nNodeVectorType; - pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + + compressNeighbours = vectorIdxParamsGetU64(pParams, VECTOR_COMPRESS_NEIGHBORS_PARAM_ID); + if( compressNeighbours == 0 ){ + pIndex->nEdgeVectorType = pIndex->nNodeVectorType; + pIndex->nEdgeVectorSize = pIndex->nNodeVectorSize; + }else if( compressNeighbours == VECTOR_TYPE_1BIT ){ + pIndex->nEdgeVectorType = compressNeighbours; + pIndex->nEdgeVectorSize = vectorDataSize(compressNeighbours, pIndex->nVectorDims); + }else{ + return SQLITE_ERROR; + } *ppIndex = pIndex; + DiskAnnTrace(("opened index %s: max edges %d\n", zIdxName, nodeEdgesMaxCount(pIndex))); return SQLITE_OK; } @@ -213216,26 +213610,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -size_t vectorF32DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - float *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT32; - pVector->dims = nBlobSize / sizeof(float); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 0 || pBlob[nBlobSize - 1] == 
VECTOR_TYPE_FLOAT32 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF32(pBlob); - pBlob += sizeof(float); - } - return vectorDataSize(pVector->type, pVector->dims); -} - void vectorF32Serialize( sqlite3_context *context, const Vector *pVector @@ -213342,32 +213716,22 @@ void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF32ParseSqliteBlob( - sqlite3_value *arg, +void vectorF32DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; float *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT32 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(float) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f32 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * sizeof(float) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF32(pBlob); pBlob += sizeof(float); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ @@ -213474,57 +213838,6 @@ size_t vectorF64SerializeToBlob( return sizeof(double) * pVector->dims; } -size_t vectorF64DeserializeFromBlob( - Vector *pVector, - const unsigned char *pBlob, - size_t nBlobSize -){ - double *elems = pVector->data; - unsigned i; - pVector->type = VECTOR_TYPE_FLOAT64; - pVector->dims = nBlobSize / sizeof(double); - - assert( pVector->dims <= MAX_VECTOR_SZ ); - assert( nBlobSize % 2 == 1 && pBlob[nBlobSize - 1] == VECTOR_TYPE_FLOAT64 ); - - for(i = 0; i < pVector->dims; i++){ - elems[i] = deserializeF64(pBlob); - pBlob += sizeof(double); - } - return vectorDataSize(pVector->type, pVector->dims); -} - -void vectorF64Serialize( - sqlite3_context *context, - const Vector *pVector -){ - double *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT64 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - // allocate one extra trailing byte with vector blob type metadata - nBlobSize = vectorDataSize(pVector->type, pVector->dims) + 1; - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF64SerializeToBlob(pVector, pBlob, nBlobSize - 1); - pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64; - - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_DOUBLE_CHAR_LIMIT 32 void vectorF64MarshalToText( sqlite3_context *context, @@ -213603,32 +213916,22 @@ void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t n pVector->data = (void*)pBlob; } -int vectorF64ParseSqliteBlob( - sqlite3_value *arg, +void vectorF64DeserializeFromBlob( Vector *pVector, - char **pzErr + const unsigned char *pBlob, + size_t nBlobSize ){ - const unsigned char *pBlob; double *elems = pVector->data; unsigned i; assert( pVector->type == VECTOR_TYPE_FLOAT64 ); assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ ); - assert( sqlite3_value_type(arg) == SQLITE_BLOB ); - - pBlob = sqlite3_value_blob(arg); - if( sqlite3_value_bytes(arg) < sizeof(double) * pVector->dims ){ - *pzErr = sqlite3_mprintf("invalid f64 vector: not enough bytes for all dimensions"); - goto error; - } + assert( nBlobSize >= pVector->dims * 
sizeof(double) ); for(i = 0; i < pVector->dims; i++){ elems[i] = deserializeF64(pBlob); pBlob += sizeof(double); } - return 0; -error: - return -1; } #endif /* !defined(SQLITE_OMIT_VECTOR) */ @@ -214033,13 +214336,14 @@ struct VectorParamName { }; static struct VectorParamName VECTOR_PARAM_NAMES[] = { - { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, - { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, - { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, - { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, - { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, - { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, + { "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS }, + { "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 }, + { "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT }, + { "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 }, + { "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 }, + { "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 }, + { "max_neighbors", VECTOR_MAX_NEIGHBORS_PARAM_ID, 1, 0, 0 }, }; static int parseVectorIdxParam(const char *zParam, VectorIdxParams *pParams, const char **pErrMsg) { @@ -214439,7 +214743,7 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co int i, rc = SQLITE_OK; int dims, type; int hasLibsqlVectorIdxFn = 0, hasCollation = 0; - const char *pzErrMsg; + const char *pzErrMsg = NULL; assert( zDbSName != NULL ); @@ -214551,9 +214855,13 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: unsupported for tables without ROWID and composite primary key"); return CREATE_FAIL; } - rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams); + rc = diskAnnCreateIndex(db, zDbSName, pIdx->zName, &idxKey, &idxParams, &pzErrMsg); if( rc != SQLITE_OK ){ - sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + if( pzErrMsg != NULL ){ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann: %s", pzErrMsg); + }else{ + sqlite3ErrorMsg(pParse, "vector index: unable to initialize diskann"); + } return CREATE_FAIL; } rc = insertIndexParameters(db, zDbSName, pIdx->zName, &idxParams); From 8441108e713d2aa67606cf91385fe82c066045d7 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 18:31:38 +0400 Subject: [PATCH 48/70] allow vector index to be partial --- libsql-sqlite3/src/vectorIndex.c | 5 --- libsql-sqlite3/test/libsql_vector_index.test | 33 ++++++++++++++++++-- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index d8b3497781..001a1aae10 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -862,11 +862,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: must contain exactly one column wrapped into the " VECTOR_INDEX_MARKER_FUNCTION " function"); return CREATE_FAIL; } - // we are able to support this but I doubt this works for now - more polishing required to make this work - if( pIdx->pPartIdxWhere != NULL ) { - sqlite3ErrorMsg(pParse, "vector index: where condition is forbidden"); - return CREATE_FAIL; - } pArgsList = 
pIdx->aColExpr->a[0].pExpr->x.pList; pListItem = pArgsList->a; diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index c1a270e4da..a173c773d3 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -275,6 +275,36 @@ do_execsql_test vector-all-params { SELECT * FROM vector_top_k('t_all_params_idx', vector('[1,2]'), 2); } {1 2} +do_execsql_test vector-partial { + CREATE TABLE t_partial( name TEXT, type INT, v FLOAT32(3)); + INSERT INTO t_partial VALUES ( 'a', 0, vector('[1,2,3]') ); + INSERT INTO t_partial VALUES ( 'b', 1, vector('[3,4,5]') ); + INSERT INTO t_partial VALUES ( 'c', 2, vector('[4,5,6]') ); + INSERT INTO t_partial VALUES ( 'd', 0, vector('[5,6,7]') ); + INSERT INTO t_partial VALUES ( 'e', 1, vector('[6,7,8]') ); + INSERT INTO t_partial VALUES ( 'f', 2, vector('[7,8,9]') ); + CREATE INDEX t_partial_idx_0 ON t_partial( libsql_vector_idx(v) ) WHERE type = 0; + CREATE INDEX t_partial_idx_1 ON t_partial( libsql_vector_idx(v) ) WHERE type = 1; + CREATE INDEX t_partial_idx_not_0 ON t_partial( libsql_vector_idx(v) ) WHERE type != 0; + SELECT id FROM vector_top_k('t_partial_idx_0', vector('[1,2,3]'), 10); + SELECT id FROM vector_top_k('t_partial_idx_1', vector('[1,2,3]'), 10); + SELECT id FROM vector_top_k('t_partial_idx_not_0', vector('[1,2,3]'), 10); + INSERT INTO t_partial VALUES ( 'g', 0, vector('[8,9,10]') ); + INSERT INTO t_partial VALUES ( 'h', 1, vector('[9,10,11]') ); + INSERT INTO t_partial VALUES ( 'i', 2, vector('[10,11,12]') ); + SELECT id FROM vector_top_k('t_partial_idx_0', vector('[1,2,3]'), 10); + SELECT id FROM vector_top_k('t_partial_idx_1', vector('[1,2,3]'), 10); + SELECT id FROM vector_top_k('t_partial_idx_not_0', vector('[1,2,3]'), 10); +} { + 1 4 + 2 5 + 2 3 5 6 + + 1 4 7 + 2 5 8 + 2 3 5 6 8 9 +} + proc error_messages {sql} { set ret "" catch { @@ -309,8 +339,6 @@ do_test vector-errors { sqlite3_exec db { CREATE TABLE t_mixed_t( v FLOAT32(3)); } sqlite3_exec db { INSERT INTO t_mixed_t VALUES('[1]'); } lappend ret [error_messages {CREATE INDEX t_mixed_t_idx ON t_mixed_t( libsql_vector_idx(v) )}] - sqlite3_exec db { CREATE TABLE t_partial( name TEXT, type INT, v FLOAT32(3)); } - lappend ret [error_messages {CREATE INDEX t_partial_idx ON t_partial( libsql_vector_idx(v) ) WHERE type = 0}] } [list {*}{ {no such table: main.t_no} {no such column: v} @@ -328,5 +356,4 @@ do_test vector-errors { {vector index(insert): only f32 vectors are supported} {vector index(search): dimensions are different: 2 != 4} {vector index(insert): dimensions are different: 1 != 3} - {vector index: where condition is forbidden} }] From ec996fabf462d841b9256f032c102db7d2a95e79 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Thu, 8 Aug 2024 19:40:55 +0400 Subject: [PATCH 49/70] build bundles --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 5 ----- libsql-ffi/bundled/src/sqlite3.c | 5 ----- 2 files changed, 10 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 15d09606fb..3a76f9cff3 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -214523,11 +214523,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: must contain exactly one column wrapped into the " VECTOR_INDEX_MARKER_FUNCTION " function"); return CREATE_FAIL; } - // 
we are able to support this but I doubt this works for now - more polishing required to make this work - if( pIdx->pPartIdxWhere != NULL ) { - sqlite3ErrorMsg(pParse, "vector index: where condition is forbidden"); - return CREATE_FAIL; - } pArgsList = pIdx->aColExpr->a[0].pExpr->x.pList; pListItem = pArgsList->a; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 15d09606fb..3a76f9cff3 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -214523,11 +214523,6 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co sqlite3ErrorMsg(pParse, "vector index: must contain exactly one column wrapped into the " VECTOR_INDEX_MARKER_FUNCTION " function"); return CREATE_FAIL; } - // we are able to support this but I doubt this works for now - more polishing required to make this work - if( pIdx->pPartIdxWhere != NULL ) { - sqlite3ErrorMsg(pParse, "vector index: where condition is forbidden"); - return CREATE_FAIL; - } pArgsList = pIdx->aColExpr->a[0].pExpr->x.pList; pListItem = pArgsList->a; From ac6a89bf87e484d2acfffdb5a68753d50e9b5e6b Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 9 Aug 2024 13:55:56 +0400 Subject: [PATCH 50/70] small fixes --- libsql-sqlite3/src/vector.c | 15 +++++++------- libsql-sqlite3/src/vector1bit.c | 5 +++-- libsql-sqlite3/src/vectorIndex.c | 4 ++-- libsql-sqlite3/src/vectorInt.h | 4 ++-- libsql-sqlite3/src/vectordiskann.c | 13 ++++++------ libsql-sqlite3/src/vectorfloat32.c | 33 +++--------------------------- libsql-sqlite3/src/vectorfloat64.c | 8 ++++++-- 7 files changed, 30 insertions(+), 52 deletions(-) diff --git a/libsql-sqlite3/src/vector.c b/libsql-sqlite3/src/vector.c index 73b84b047e..c622d977e1 100644 --- a/libsql-sqlite3/src/vector.c +++ b/libsql-sqlite3/src/vector.c @@ -42,6 +42,7 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: + assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -252,7 +253,7 @@ static int vectorParseSqliteText( return -1; } -int vectorParseSqliteBlob( +int vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg @@ -362,14 +363,14 @@ int detectVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pD } } -int vectorParse( +int vectorParseWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ switch( sqlite3_value_type(arg) ){ case SQLITE_BLOB: - return vectorParseSqliteBlob(arg, pVector, pzErrMsg); + return vectorParseSqliteBlobWithType(arg, pVector, pzErrMsg); case SQLITE_TEXT: return vectorParseSqliteText(arg, pVector, pzErrMsg); default: @@ -531,7 +532,7 @@ static void vectorFuncHintedType( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free_vec; @@ -581,7 +582,7 @@ static void vectorExtractFunc( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -636,12 +637,12 @@ static void vectorDistanceCosFunc( if( pVector2==NULL ){ goto out_free; } - if( vectorParse(argv[0], pVector1, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto 
out_free; } - if( vectorParse(argv[1], pVector2, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; diff --git a/libsql-sqlite3/src/vector1bit.c b/libsql-sqlite3/src/vector1bit.c index 66da59f76a..f4fd5f9100 100644 --- a/libsql-sqlite3/src/vector1bit.c +++ b/libsql-sqlite3/src/vector1bit.c @@ -41,10 +41,11 @@ void vector1BitDump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_1BIT ); + printf("f1bit: ["); for(i = 0; i < pVec->dims; i++){ - printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); } - printf("\n"); + printf("]\n"); } /************************************************************************** diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index 5a13b4ea60..983e94cb9f 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -261,7 +261,7 @@ int vectorInRowAlloc(sqlite3 *db, const UnpackedRecord *pRecord, VectorInRow *pV vectorInitFromBlob(pVectorInRow->pVector, sqlite3_value_blob(pVectorValue), sqlite3_value_bytes(pVectorValue)); } else if( sqlite3_value_type(pVectorValue) == SQLITE_TEXT ){ // users can put strings (e.g. '[1,2,3]') in the table and we should process them correctly - if( vectorParse(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } @@ -982,7 +982,7 @@ int vectorIndexSearch( rc = SQLITE_NOMEM_BKPT; goto out; } - if( vectorParse(argv[1], pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[1], pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } diff --git a/libsql-sqlite3/src/vectorInt.h b/libsql-sqlite3/src/vectorInt.h index 84cf9c0d1f..efe8f3cf38 100644 --- a/libsql-sqlite3/src/vectorInt.h +++ b/libsql-sqlite3/src/vectorInt.h @@ -42,7 +42,7 @@ struct Vector { size_t vectorDataSize(VectorType, VectorDims); Vector *vectorAlloc(VectorType, VectorDims); void vectorFree(Vector *v); -int vectorParse(sqlite3_value *, Vector *, char **); +int vectorParseWithType(sqlite3_value *, Vector *, char **); void vectorInit(Vector *, VectorType, VectorDims, void *); /* @@ -97,7 +97,7 @@ void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ -int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); +int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 88151bf1df..ae832f4400 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -47,7 +47,6 @@ ** diskAnnInsert() Insert single new(!) 
vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ -#include "vectorInt.h" #ifndef SQLITE_OMIT_VECTOR #include "math.h" @@ -84,7 +83,8 @@ typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; -// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation +// (pEdge pointer always equals to pNode if pNodeType == pEdgeType) struct VectorPair { int nodeType; int edgeType; @@ -966,15 +966,13 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - goto out_oom; + return SQLITE_NOMEM_BKPT; } loadVectorPair(&pCtx->query, pQuery); - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ - goto out_oom; + if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ + return SQLITE_OK; } - return SQLITE_OK; -out_oom: if( pCtx->aDistances != NULL ){ sqlite3_free(pCtx->aDistances); } @@ -987,6 +985,7 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC if( pCtx->aTopCandidates != NULL ){ sqlite3_free(pCtx->aTopCandidates); } + deinitVectorPair(&pCtx->query); return SQLITE_NOMEM_BKPT; } diff --git a/libsql-sqlite3/src/vectorfloat32.c b/libsql-sqlite3/src/vectorfloat32.c index 5d6641991c..d53d10d593 100644 --- a/libsql-sqlite3/src/vectorfloat32.c +++ b/libsql-sqlite3/src/vectorfloat32.c @@ -41,10 +41,11 @@ void vectorF32Dump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_FLOAT32 ); + printf("f32: ["); for(i = 0; i < pVec->dims; i++){ - printf("%f ", elems[i]); + printf("%s%f", i == 0 ? "" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -94,34 +95,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -void vectorF32Serialize( - sqlite3_context *context, - const Vector *pVector -){ - float *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT32 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - nBlobSize = vectorDataSize(pVector->type, pVector->dims); - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_FLOAT_CHAR_LIMIT 32 void vectorF32MarshalToText( sqlite3_context *context, diff --git a/libsql-sqlite3/src/vectorfloat64.c b/libsql-sqlite3/src/vectorfloat64.c index 1d29c9c3d6..885306c8c6 100644 --- a/libsql-sqlite3/src/vectorfloat64.c +++ b/libsql-sqlite3/src/vectorfloat64.c @@ -38,10 +38,14 @@ void vectorF64Dump(const Vector *pVec){ double *elems = pVec->data; unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT64 ); + + printf("f64: ["); for(i = 0; i < pVec->dims; i++){ - printf("%lf ", elems[i]); + printf("%s%lf", i == 0 ? 
"" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** From 2aca12b6e13c0d9da3e10afede0842cf38917dc3 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Fri, 9 Aug 2024 13:57:36 +0400 Subject: [PATCH 51/70] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 82 +++++++------------ libsql-ffi/bundled/src/sqlite3.c | 82 +++++++------------ 2 files changed, 60 insertions(+), 104 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index eff02cf87a..6c60bccaab 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -85264,7 +85264,7 @@ struct Vector { size_t vectorDataSize(VectorType, VectorDims); Vector *vectorAlloc(VectorType, VectorDims); void vectorFree(Vector *v); -int vectorParse(sqlite3_value *, Vector *, char **); +int vectorParseWithType(sqlite3_value *, Vector *, char **); void vectorInit(Vector *, VectorType, VectorDims, void *); /* @@ -85319,7 +85319,7 @@ void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ -int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); +int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); @@ -210988,6 +210988,7 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: + assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -211198,7 +211199,7 @@ static int vectorParseSqliteText( return -1; } -int vectorParseSqliteBlob( +int vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg @@ -211308,14 +211309,14 @@ int detectVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pD } } -int vectorParse( +int vectorParseWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ switch( sqlite3_value_type(arg) ){ case SQLITE_BLOB: - return vectorParseSqliteBlob(arg, pVector, pzErrMsg); + return vectorParseSqliteBlobWithType(arg, pVector, pzErrMsg); case SQLITE_TEXT: return vectorParseSqliteText(arg, pVector, pzErrMsg); default: @@ -211477,7 +211478,7 @@ static void vectorFuncHintedType( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free_vec; @@ -211527,7 +211528,7 @@ static void vectorExtractFunc( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211582,12 +211583,12 @@ static void vectorDistanceCosFunc( if( pVector2==NULL ){ goto out_free; } - if( vectorParse(argv[0], pVector1, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; } - if( vectorParse(argv[1], pVector2, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ 
-211673,10 +211674,11 @@ void vector1BitDump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_1BIT ); + printf("f1bit: ["); for(i = 0; i < pVec->dims; i++){ - printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -211808,7 +211810,6 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ ** diskAnnInsert() Insert single new(!) vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ -/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "math.h" */ @@ -211845,7 +211846,8 @@ typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; -// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation +// (pEdge pointer always equals to pNode if pNodeType == pEdgeType) struct VectorPair { int nodeType; int edgeType; @@ -212727,15 +212729,13 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - goto out_oom; + return SQLITE_NOMEM_BKPT; } loadVectorPair(&pCtx->query, pQuery); - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ - goto out_oom; + if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ + return SQLITE_OK; } - return SQLITE_OK; -out_oom: if( pCtx->aDistances != NULL ){ sqlite3_free(pCtx->aDistances); } @@ -212748,6 +212748,7 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC if( pCtx->aTopCandidates != NULL ){ sqlite3_free(pCtx->aTopCandidates); } + deinitVectorPair(&pCtx->query); return SQLITE_NOMEM_BKPT; } @@ -213557,10 +213558,11 @@ void vectorF32Dump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_FLOAT32 ); + printf("f32: ["); for(i = 0; i < pVec->dims; i++){ - printf("%f ", elems[i]); + printf("%s%f", i == 0 ? 
"" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -213610,34 +213612,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -void vectorF32Serialize( - sqlite3_context *context, - const Vector *pVector -){ - float *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT32 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - nBlobSize = vectorDataSize(pVector->type, pVector->dims); - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_FLOAT_CHAR_LIMIT 32 void vectorF32MarshalToText( sqlite3_context *context, @@ -213778,10 +213752,14 @@ void vectorF32DeserializeFromBlob( void vectorF64Dump(const Vector *pVec){ double *elems = pVec->data; unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT64 ); + + printf("f64: ["); for(i = 0; i < pVec->dims; i++){ - printf("%lf ", elems[i]); + printf("%s%lf", i == 0 ? "" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -214201,7 +214179,7 @@ int vectorInRowAlloc(sqlite3 *db, const UnpackedRecord *pRecord, VectorInRow *pV vectorInitFromBlob(pVectorInRow->pVector, sqlite3_value_blob(pVectorValue), sqlite3_value_bytes(pVectorValue)); } else if( sqlite3_value_type(pVectorValue) == SQLITE_TEXT ){ // users can put strings (e.g. '[1,2,3]') in the table and we should process them correctly - if( vectorParse(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } @@ -214922,7 +214900,7 @@ int vectorIndexSearch( rc = SQLITE_NOMEM_BKPT; goto out; } - if( vectorParse(argv[1], pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[1], pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index eff02cf87a..6c60bccaab 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -85264,7 +85264,7 @@ struct Vector { size_t vectorDataSize(VectorType, VectorDims); Vector *vectorAlloc(VectorType, VectorDims); void vectorFree(Vector *v); -int vectorParse(sqlite3_value *, Vector *, char **); +int vectorParseWithType(sqlite3_value *, Vector *, char **); void vectorInit(Vector *, VectorType, VectorDims, void *); /* @@ -85319,7 +85319,7 @@ void vectorSerializeWithType(sqlite3_context *, const Vector *); /* * Parses Vector content from the blob; vector type and dimensions must be filled already */ -int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **); +int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **); void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t); void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t); @@ -210988,6 +210988,7 @@ size_t vectorDataSize(VectorType type, VectorDims dims){ case VECTOR_TYPE_FLOAT64: return dims * sizeof(double); case VECTOR_TYPE_1BIT: + assert( dims > 0 ); return (dims + 7) / 8; default: assert(0); @@ -211198,7 +211199,7 @@ static int vectorParseSqliteText( return -1; } -int vectorParseSqliteBlob( +int 
vectorParseSqliteBlobWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg @@ -211308,14 +211309,14 @@ int detectVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pD } } -int vectorParse( +int vectorParseWithType( sqlite3_value *arg, Vector *pVector, char **pzErrMsg ){ switch( sqlite3_value_type(arg) ){ case SQLITE_BLOB: - return vectorParseSqliteBlob(arg, pVector, pzErrMsg); + return vectorParseSqliteBlobWithType(arg, pVector, pzErrMsg); case SQLITE_TEXT: return vectorParseSqliteText(arg, pVector, pzErrMsg); default: @@ -211477,7 +211478,7 @@ static void vectorFuncHintedType( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free_vec; @@ -211527,7 +211528,7 @@ static void vectorExtractFunc( if( pVector==NULL ){ return; } - if( vectorParse(argv[0], pVector, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211582,12 +211583,12 @@ static void vectorDistanceCosFunc( if( pVector2==NULL ){ goto out_free; } - if( vectorParse(argv[0], pVector1, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; } - if( vectorParse(argv[1], pVector2, &pzErrMsg)<0 ){ + if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){ sqlite3_result_error(context, pzErrMsg, -1); sqlite3_free(pzErrMsg); goto out_free; @@ -211673,10 +211674,11 @@ void vector1BitDump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_1BIT ); + printf("f1bit: ["); for(i = 0; i < pVec->dims; i++){ - printf("%d ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); + printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -211808,7 +211810,6 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){ ** diskAnnInsert() Insert single new(!) 
vector in an opened index ** diskAnnDelete() Delete row by key from an opened index */ -/* #include "vectorInt.h" */ #ifndef SQLITE_OMIT_VECTOR /* #include "math.h" */ @@ -211845,7 +211846,8 @@ typedef struct VectorPair VectorPair; typedef struct DiskAnnSearchCtx DiskAnnSearchCtx; typedef struct DiskAnnNode DiskAnnNode; -// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation (always NULL if pNodeType == pEdgeType) +// VectorPair represents single vector where pNode is an exact representation and pEdge - compressed representation +// (pEdge pointer always equals to pNode if pNodeType == pEdgeType) struct VectorPair { int nodeType; int edgeType; @@ -212727,15 +212729,13 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - goto out_oom; + return SQLITE_NOMEM_BKPT; } loadVectorPair(&pCtx->query, pQuery); - if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){ - goto out_oom; + if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ + return SQLITE_OK; } - return SQLITE_OK; -out_oom: if( pCtx->aDistances != NULL ){ sqlite3_free(pCtx->aDistances); } @@ -212748,6 +212748,7 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC if( pCtx->aTopCandidates != NULL ){ sqlite3_free(pCtx->aTopCandidates); } + deinitVectorPair(&pCtx->query); return SQLITE_NOMEM_BKPT; } @@ -213557,10 +213558,11 @@ void vectorF32Dump(const Vector *pVec){ assert( pVec->type == VECTOR_TYPE_FLOAT32 ); + printf("f32: ["); for(i = 0; i < pVec->dims; i++){ - printf("%f ", elems[i]); + printf("%s%f", i == 0 ? "" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -213610,34 +213612,6 @@ size_t vectorF32SerializeToBlob( return sizeof(float) * pVector->dims; } -void vectorF32Serialize( - sqlite3_context *context, - const Vector *pVector -){ - float *elems = pVector->data; - unsigned char *pBlob; - size_t nBlobSize; - - assert( pVector->type == VECTOR_TYPE_FLOAT32 ); - assert( pVector->dims <= MAX_VECTOR_SZ ); - - nBlobSize = vectorDataSize(pVector->type, pVector->dims); - - if( nBlobSize == 0 ){ - sqlite3_result_zeroblob(context, 0); - return; - } - - pBlob = sqlite3_malloc64(nBlobSize); - if( pBlob == NULL ){ - sqlite3_result_error_nomem(context); - return; - } - - vectorF32SerializeToBlob(pVector, pBlob, nBlobSize); - sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free); -} - #define SINGLE_FLOAT_CHAR_LIMIT 32 void vectorF32MarshalToText( sqlite3_context *context, @@ -213778,10 +213752,14 @@ void vectorF32DeserializeFromBlob( void vectorF64Dump(const Vector *pVec){ double *elems = pVec->data; unsigned i; + + assert( pVec->type == VECTOR_TYPE_FLOAT64 ); + + printf("f64: ["); for(i = 0; i < pVec->dims; i++){ - printf("%lf ", elems[i]); + printf("%s%lf", i == 0 ? 
"" : ", ", elems[i]); } - printf("\n"); + printf("]\n"); } /************************************************************************** @@ -214201,7 +214179,7 @@ int vectorInRowAlloc(sqlite3 *db, const UnpackedRecord *pRecord, VectorInRow *pV vectorInitFromBlob(pVectorInRow->pVector, sqlite3_value_blob(pVectorValue), sqlite3_value_bytes(pVectorValue)); } else if( sqlite3_value_type(pVectorValue) == SQLITE_TEXT ){ // users can put strings (e.g. '[1,2,3]') in the table and we should process them correctly - if( vectorParse(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(pVectorValue, pVectorInRow->pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } @@ -214922,7 +214900,7 @@ int vectorIndexSearch( rc = SQLITE_NOMEM_BKPT; goto out; } - if( vectorParse(argv[1], pVector, pzErrMsg) != 0 ){ + if( vectorParseWithType(argv[1], pVector, pzErrMsg) != 0 ){ rc = SQLITE_ERROR; goto out; } From 7fe8f96964711dec9c724c32abb5c6665a3c3822 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 2 Aug 2024 13:23:10 -0700 Subject: [PATCH 52/70] c: add replicated data for sync This adds a new replicated struct that is output during sync for the C bindings. --- bindings/c/include/libsql.h | 7 ++++++- bindings/c/src/lib.rs | 9 +++++++-- bindings/c/src/types.rs | 6 ++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/bindings/c/include/libsql.h b/bindings/c/include/libsql.h index 8178980466..c1285d4fff 100644 --- a/bindings/c/include/libsql.h +++ b/bindings/c/include/libsql.h @@ -27,6 +27,11 @@ typedef struct libsql_stmt libsql_stmt; typedef const libsql_database *libsql_database_t; +typedef struct { + uintptr_t frame_no; + uintptr_t frames_synced; +} replicated; + typedef struct { const char *db_path; const char *primary_url; @@ -56,7 +61,7 @@ typedef struct { extern "C" { #endif // __cplusplus -int libsql_sync(libsql_database_t db, const char **out_err_msg); +int libsql_sync(libsql_database_t db, replicated *out_replicated, const char **out_err_msg); int libsql_open_sync(const char *db_path, const char *primary_url, diff --git a/bindings/c/src/lib.rs b/bindings/c/src/lib.rs index 96e4effd2a..3376c4e78b 100644 --- a/bindings/c/src/lib.rs +++ b/bindings/c/src/lib.rs @@ -11,7 +11,7 @@ use tokio::runtime::Runtime; use types::{ blob, libsql_connection, libsql_connection_t, libsql_database, libsql_database_t, libsql_row, libsql_row_t, libsql_rows, libsql_rows_future_t, libsql_rows_t, libsql_stmt, libsql_stmt_t, - stmt, + replicated, stmt, }; lazy_static! 
{
@@ -34,11 +34,16 @@ unsafe fn set_err_msg(msg: String, output: *mut *const std::ffi::c_char) {
 
 #[no_mangle]
 pub unsafe extern "C" fn libsql_sync(
     db: libsql_database_t,
+    out_replicated: *mut replicated,
     out_err_msg: *mut *const std::ffi::c_char,
 ) -> std::ffi::c_int {
     let db = db.get_ref();
     match RT.block_on(db.sync()) {
-        Ok(_) => 0,
+        Ok(replicated) => {
+            (*out_replicated).frame_no = replicated.frame_no().unwrap_or(0) as usize;
+            (*out_replicated).frames_synced = replicated.frames_synced() as usize;
+            0
+        }
         Err(e) => {
             set_err_msg(format!("Error syncing database: {e}"), out_err_msg);
             1
diff --git a/bindings/c/src/types.rs b/bindings/c/src/types.rs
index 9f818e28d4..2ec399973a 100644
--- a/bindings/c/src/types.rs
+++ b/bindings/c/src/types.rs
@@ -115,6 +115,12 @@ impl From<&mut libsql_connection> for libsql_connection_t {
     }
 }
 
+#[repr(C)]
+pub struct replicated {
+    pub frame_no: usize,
+    pub frames_synced: usize,
+}
+
 pub struct stmt {
     pub stmt: libsql::Statement,
     pub params: Vec,

From fb852624188cd78babfc1463f631dd870e9b826c Mon Sep 17 00:00:00 2001
From: Lucio Franco
Date: Tue, 6 Aug 2024 15:02:25 -0400
Subject: [PATCH 53/70] c: rename sync method to avoid a breaking change

---
 bindings/c/include/libsql.h |  8 +++++---
 bindings/c/src/lib.rs       | 22 ++++++++++++++++++++--
 bindings/c/src/types.rs     |  4 ++--
 3 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/bindings/c/include/libsql.h b/bindings/c/include/libsql.h
index c1285d4fff..7fdfb4b3c0 100644
--- a/bindings/c/include/libsql.h
+++ b/bindings/c/include/libsql.h
@@ -28,8 +28,8 @@ typedef struct libsql_stmt libsql_stmt;
 typedef const libsql_database *libsql_database_t;
 
 typedef struct {
-  uintptr_t frame_no;
-  uintptr_t frames_synced;
+  int frame_no;
+  int frames_synced;
 } replicated;
 
 typedef struct {
@@ -61,7 +61,9 @@ typedef struct {
 extern "C" {
 #endif // __cplusplus
 
-int libsql_sync(libsql_database_t db, replicated *out_replicated, const char **out_err_msg);
+int libsql_sync(libsql_database_t db, const char **out_err_msg);
+
+int libsql_sync2(libsql_database_t db, replicated *out_replicated, const char **out_err_msg);
 
 int libsql_open_sync(const char *db_path,
                      const char *primary_url,
diff --git a/bindings/c/src/lib.rs b/bindings/c/src/lib.rs
index 3376c4e78b..6cb1dc096e 100644
--- a/bindings/c/src/lib.rs
+++ b/bindings/c/src/lib.rs
@@ -33,6 +33,21 @@ unsafe fn set_err_msg(msg: String, output: *mut *const std::ffi::c_char) {
 
 #[no_mangle]
 pub unsafe extern "C" fn libsql_sync(
+    db: libsql_database_t,
+    out_err_msg: *mut *const std::ffi::c_char,
+) -> std::ffi::c_int {
+    let db = db.get_ref();
+    match RT.block_on(db.sync()) {
+        Ok(_) => 0,
+        Err(e) => {
+            set_err_msg(format!("Error syncing database: {e}"), out_err_msg);
+            1
+        }
+    }
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn libsql_sync2(
     db: libsql_database_t,
     out_replicated: *mut replicated,
     out_err_msg: *mut *const std::ffi::c_char,
@@ -40,8 +55,11 @@ pub unsafe extern "C" fn libsql_sync(
     let db = db.get_ref();
     match RT.block_on(db.sync()) {
         Ok(replicated) => {
-            (*out_replicated).frame_no = replicated.frame_no().unwrap_or(0) as usize;
-            (*out_replicated).frames_synced = replicated.frames_synced() as usize;
+            if !out_replicated.is_null() {
+                (*out_replicated).frame_no = replicated.frame_no().unwrap_or(0) as i32;
+                (*out_replicated).frames_synced = replicated.frames_synced() as i32;
+            }
+
             0
         }
         Err(e) => {
diff --git a/bindings/c/src/types.rs b/bindings/c/src/types.rs
index 2ec399973a..5d9f0b517f 100644
--- a/bindings/c/src/types.rs
+++ b/bindings/c/src/types.rs
@@ -117,8 
+117,8 @@ impl From<&mut libsql_connection> for libsql_connection_t { #[repr(C)] pub struct replicated { - pub frame_no: usize, - pub frames_synced: usize, + pub frame_no: std::ffi::c_int, + pub frames_synced: std::ffi::c_int, } pub struct stmt { From 7c4ea18c75598e7cd2b84d2940d9c6a0f380d571 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 10:22:19 +0200 Subject: [PATCH 54/70] abstract replicator injector and introduce SqliteInjector --- bottomless/src/replicator.rs | 9 +- libsql-replication/src/injector/error.rs | 2 + libsql-replication/src/injector/mod.rs | 298 +-------------- .../injector/{ => sqlite_injector}/headers.rs | 0 .../{ => sqlite_injector}/injector_wal.rs | 0 .../src/injector/sqlite_injector/mod.rs | 345 ++++++++++++++++++ libsql-replication/src/replicator.rs | 95 ++--- libsql/src/replication/mod.rs | 7 +- 8 files changed, 418 insertions(+), 338 deletions(-) rename libsql-replication/src/injector/{ => sqlite_injector}/headers.rs (100%) rename libsql-replication/src/injector/{ => sqlite_injector}/injector_wal.rs (100%) create mode 100644 libsql-replication/src/injector/sqlite_injector/mod.rs diff --git a/bottomless/src/replicator.rs b/bottomless/src/replicator.rs index f2ef812f75..26e190df66 100644 --- a/bottomless/src/replicator.rs +++ b/bottomless/src/replicator.rs @@ -17,6 +17,7 @@ use aws_sdk_s3::primitives::ByteStream; use aws_sdk_s3::{Client, Config}; use bytes::{Buf, Bytes}; use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; +use libsql_replication::injector::Injector as _; use libsql_sys::{Cipher, EncryptionConfig}; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -1449,12 +1450,12 @@ impl Replicator { db_path: &Path, ) -> Result { let encryption_config = self.encryption_config.clone(); - let mut injector = libsql_replication::injector::Injector::new( - db_path, + let mut injector = libsql_replication::injector::SqliteInjector::new( + db_path.to_path_buf(), 4096, libsql_sys::connection::NO_AUTOCHECKPOINT, encryption_config, - )?; + ).await?; let prefix = format!("{}-{}/", self.db_name, generation); let mut page_buf = { let mut v = Vec::with_capacity(page_size); @@ -1552,7 +1553,7 @@ impl Replicator { }, page_buf.as_slice(), ); - injector.inject_frame(frame_to_inject)?; + injector.inject_frame(frame_to_inject).await?; applied_wal_frame = true; } } diff --git a/libsql-replication/src/injector/error.rs b/libsql-replication/src/injector/error.rs index 14899089ea..b1cebfe28b 100644 --- a/libsql-replication/src/injector/error.rs +++ b/libsql-replication/src/injector/error.rs @@ -1,3 +1,5 @@ +pub type Result = std::result::Result; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("IO error: {0}")] diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 80443964fe..1d69ae0aab 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,299 +1,27 @@ -use std::path::Path; -use std::sync::Arc; -use std::{collections::VecDeque, path::PathBuf}; +use std::future::Future; -use parking_lot::Mutex; -use rusqlite::OpenFlags; +pub use sqlite_injector::SqliteInjector; use crate::frame::{Frame, FrameNo}; +use error::Result; pub use error::Error; -use self::injector_wal::{ - InjectorWal, InjectorWalManager, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN, -}; - mod error; -mod headers; -mod injector_wal; - -#[derive(Debug)] -pub enum InjectError {} - -pub type FrameBuffer = Arc>>; - -pub struct Injector { - /// The injector is in a transaction state - is_txn: 
bool,
-    /// Buffer for holding current transaction frames
-    buffer: FrameBuffer,
-    /// Maximum capacity of the frame buffer
-    capacity: usize,
-    /// Injector connection
-    // connection must be dropped before the hook context
-    connection: Arc<Mutex<libsql_sys::Connection<InjectorWal>>>,
-    biggest_uncommitted_seen: FrameNo,
-
-    // Connection config items used to recreate the injection connection
-    path: PathBuf,
-    encryption_config: Option<libsql_sys::EncryptionConfig>,
-    auto_checkpoint: u32,
-}
-
-/// Methods from this trait are called before and after performing a frame injection.
-/// This trait trait is used to record the last committed frame_no to the log.
-/// The implementer can persist the pre and post commit frame no, and compare them in the event of
-/// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted.
-impl Injector {
-    pub fn new(
-        path: impl AsRef<Path>,
-        capacity: usize,
-        auto_checkpoint: u32,
-        encryption_config: Option<libsql_sys::EncryptionConfig>,
-    ) -> Result<Self, Error> {
-        let path = path.as_ref().to_path_buf();
-
-        let buffer = FrameBuffer::default();
-        let wal_manager = InjectorWalManager::new(buffer.clone());
-        let connection = libsql_sys::Connection::open(
-            &path,
-            OpenFlags::SQLITE_OPEN_READ_WRITE
-                | OpenFlags::SQLITE_OPEN_CREATE
-                | OpenFlags::SQLITE_OPEN_URI
-                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
-            wal_manager,
-            auto_checkpoint,
-            encryption_config.clone(),
-        )?;
-
-        Ok(Self {
-            is_txn: false,
-            buffer,
-            capacity,
-            connection: Arc::new(Mutex::new(connection)),
-            biggest_uncommitted_seen: 0,
-
-            path,
-            encryption_config,
-            auto_checkpoint,
-        })
-    }
-
-    /// Inject a frame into the log. If this was a commit frame, returns Ok(Some(FrameNo)).
-    pub fn inject_frame(&mut self, frame: Frame) -> Result<Option<FrameNo>, Error> {
-        let frame_close_txn = frame.header().size_after.get() != 0;
-        self.buffer.lock().push_back(frame);
-        if frame_close_txn || self.buffer.lock().len() >= self.capacity {
-            return self.flush();
-        }
+mod sqlite_injector;
 
-        Ok(None)
-    }
+pub trait Injector {
+    /// Inject a singular frame.
+    fn inject_frame(
+        &mut self,
+        frame: Frame,
+    ) -> impl Future<Output = Result<Option<FrameNo>>> + Send;
 
-    pub fn rollback(&mut self) {
-        let conn = self.connection.lock();
-        let mut rollback = conn.prepare_cached("ROLLBACK").unwrap();
-        let _ = rollback.execute(());
-        self.is_txn = false;
-    }
+    /// Discard any uncommitted frames.
+    fn rollback(&mut self) -> impl Future<Output = ()> + Send;
 
     /// Flush the buffer to libsql WAL.
     /// Trigger a dummy write, and flush the cache to trigger a call to xFrame. The buffer's frame
     /// are then injected into the wal.
-    pub fn flush(&mut self) -> Result<Option<FrameNo>, Error> {
-        match self.try_flush() {
-            Err(e) => {
-                // something went wrong, rollback the connection to make sure we can retry in a
-                // clean state
-                self.biggest_uncommitted_seen = 0;
-                self.rollback();
-                Err(e)
-            }
-            Ok(ret) => Ok(ret),
-        }
-    }
+    fn flush(&mut self) -> impl Future<Output = Result<Option<FrameNo>>> + Send;
+}
 
-    fn try_flush(&mut self) -> Result<Option<FrameNo>, Error> {
-        if !self.is_txn {
-            self.begin_txn()?;
-        }
-
-        let lock = self.buffer.lock();
-        // the frames in the buffer are either monotonically increasing (log) or decreasing
-        // (snapshot). 
Either way, we want to find the biggest frameno we're about to commit, and - // that is either the front or the back of the buffer - let last_frame_no = match lock.back().zip(lock.front()) { - Some((b, f)) => f.header().frame_no.get().max(b.header().frame_no.get()), - None => { - tracing::trace!("nothing to inject"); - return Ok(None); - } - }; - - self.biggest_uncommitted_seen = self.biggest_uncommitted_seen.max(last_frame_no); - - drop(lock); - - let connection = self.connection.lock(); - // use prepare cached to avoid parsing the same statement over and over again. - let mut stmt = - connection.prepare_cached("INSERT INTO libsql_temp_injection VALUES (42)")?; - - // We execute the statement, and then force a call to xframe if necesacary. If the execute - // succeeds, then xframe wasn't called, in this case, we call cache_flush, and then process - // the error. - // It is unexpected that execute flushes, but it is possible, so we handle that case. - match stmt.execute(()).and_then(|_| connection.cache_flush()) { - Ok(_) => panic!("replication hook was not called"), - Err(e) => { - if let Some(e) = e.sqlite_error() { - if e.extended_code == LIBSQL_INJECT_OK { - // refresh schema - connection.pragma_update(None, "writable_schema", "reset")?; - let mut rollback = connection.prepare_cached("ROLLBACK")?; - let _ = rollback.execute(()); - self.is_txn = false; - assert!(self.buffer.lock().is_empty()); - let commit_frame_no = self.biggest_uncommitted_seen; - self.biggest_uncommitted_seen = 0; - return Ok(Some(commit_frame_no)); - } else if e.extended_code == LIBSQL_INJECT_OK_TXN { - self.is_txn = true; - assert!(self.buffer.lock().is_empty()); - return Ok(None); - } else if e.extended_code == LIBSQL_INJECT_FATAL { - return Err(Error::FatalInjectError); - } - } - - Err(Error::FatalInjectError) - } - } - } - - fn begin_txn(&mut self) -> Result<(), Error> { - let mut conn = self.connection.lock(); - - { - let wal_manager = InjectorWalManager::new(self.buffer.clone()); - let new_conn = libsql_sys::Connection::open( - &self.path, - OpenFlags::SQLITE_OPEN_READ_WRITE - | OpenFlags::SQLITE_OPEN_CREATE - | OpenFlags::SQLITE_OPEN_URI - | OpenFlags::SQLITE_OPEN_NO_MUTEX, - wal_manager, - self.auto_checkpoint, - self.encryption_config.clone(), - )?; - - let _ = std::mem::replace(&mut *conn, new_conn); - } - - conn.pragma_update(None, "writable_schema", "true")?; - - let mut stmt = conn.prepare_cached("BEGIN IMMEDIATE")?; - stmt.execute(())?; - // we create a dummy table. This table MUST not be persisted, otherwise the replica schema - // would differ with the primary's. - let mut stmt = - conn.prepare_cached("CREATE TABLE IF NOT EXISTS libsql_temp_injection (x)")?; - stmt.execute(())?; - - Ok(()) - } - - pub fn clear_buffer(&mut self) { - self.buffer.lock().clear() - } - - #[cfg(test)] - pub fn is_txn(&self) -> bool { - self.is_txn - } -} - -#[cfg(test)] -mod test { - use crate::frame::FrameBorrowed; - use std::mem::size_of; - - use super::*; - /// this this is generated by creating a table test, inserting 5 rows into it, and then - /// truncating the wal file of it's header. 
- const WAL: &[u8] = include_bytes!("../../assets/test/test_wallog"); - - fn wal_log() -> impl Iterator { - WAL.chunks(size_of::()) - .map(|b| Frame::try_from(b).unwrap()) - } - - #[test] - fn test_simple_inject_frames() { - let temp = tempfile::tempdir().unwrap(); - - let mut injector = Injector::new(temp.path().join("data"), 10, 10000, None).unwrap(); - let log = wal_log(); - for frame in log { - injector.inject_frame(frame).unwrap(); - } - - let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); - - conn.query_row("SELECT COUNT(*) FROM test", (), |row| { - assert_eq!(row.get::<_, usize>(0).unwrap(), 5); - Ok(()) - }) - .unwrap(); - } - - #[test] - fn test_inject_frames_split_txn() { - let temp = tempfile::tempdir().unwrap(); - - // inject one frame at a time - let mut injector = Injector::new(temp.path().join("data"), 1, 10000, None).unwrap(); - let log = wal_log(); - for frame in log { - injector.inject_frame(frame).unwrap(); - } - - let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); - - conn.query_row("SELECT COUNT(*) FROM test", (), |row| { - assert_eq!(row.get::<_, usize>(0).unwrap(), 5); - Ok(()) - }) - .unwrap(); - } - - #[test] - fn test_inject_partial_txn_isolated() { - let temp = tempfile::tempdir().unwrap(); - - // inject one frame at a time - let mut injector = Injector::new(temp.path().join("data"), 10, 1000, None).unwrap(); - let mut frames = wal_log(); - - assert!(injector - .inject_frame(frames.next().unwrap()) - .unwrap() - .is_none()); - let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); - assert!(conn - .query_row("SELECT COUNT(*) FROM test", (), |_| Ok(())) - .is_err()); - - while injector - .inject_frame(frames.next().unwrap()) - .unwrap() - .is_none() - {} - - // reset schema - conn.pragma_update(None, "writable_schema", "reset") - .unwrap(); - conn.query_row("SELECT COUNT(*) FROM test", (), |_| Ok(())) - .unwrap(); - } + fn flush(&mut self) -> impl Future>> + Send; } diff --git a/libsql-replication/src/injector/headers.rs b/libsql-replication/src/injector/sqlite_injector/headers.rs similarity index 100% rename from libsql-replication/src/injector/headers.rs rename to libsql-replication/src/injector/sqlite_injector/headers.rs diff --git a/libsql-replication/src/injector/injector_wal.rs b/libsql-replication/src/injector/sqlite_injector/injector_wal.rs similarity index 100% rename from libsql-replication/src/injector/injector_wal.rs rename to libsql-replication/src/injector/sqlite_injector/injector_wal.rs diff --git a/libsql-replication/src/injector/sqlite_injector/mod.rs b/libsql-replication/src/injector/sqlite_injector/mod.rs new file mode 100644 index 0000000000..dea78ce4b5 --- /dev/null +++ b/libsql-replication/src/injector/sqlite_injector/mod.rs @@ -0,0 +1,345 @@ +use std::path::Path; +use std::sync::Arc; +use std::{collections::VecDeque, path::PathBuf}; + +use parking_lot::Mutex; +use rusqlite::OpenFlags; +use tokio::task::spawn_blocking; + +use crate::frame::{Frame, FrameNo}; + +use self::injector_wal::{ + InjectorWal, InjectorWalManager, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN, +}; + +use super::error::Result; +use super::{Error, Injector}; + +mod headers; +mod injector_wal; + +pub type FrameBuffer = Arc>>; + +pub struct SqliteInjector { + pub(in super::super) inner: Arc>, +} + +impl Injector for SqliteInjector { + async fn inject_frame( + &mut self, + frame: Frame, + ) -> Result> { + let inner = self.inner.clone(); + spawn_blocking(move || { + inner.lock().inject_frame(frame) 
+        }).await.unwrap()
+    }
+
+    async fn rollback(&mut self) {
+        let inner = self.inner.clone();
+        spawn_blocking(move || {
+            inner.lock().rollback()
+        }).await.unwrap();
+    }
+
+    async fn flush(&mut self) -> Result<Option<FrameNo>> {
+        let inner = self.inner.clone();
+        spawn_blocking(move || {
+            inner.lock().flush()
+        }).await.unwrap()
+    }
+}
+
+impl SqliteInjector {
+    pub async fn new(
+        path: PathBuf,
+        capacity: usize,
+        auto_checkpoint: u32,
+        encryption_config: Option<libsql_sys::EncryptionConfig>,
+    ) ->super::Result<Self> {
+        let inner = spawn_blocking(move || {
+            SqliteInjectorInner::new(path, capacity, auto_checkpoint, encryption_config)
+        }).await.unwrap()?;
+
+        Ok(Self {
+            inner: Arc::new(Mutex::new(inner))
+        })
+    }
+}
+
+pub(in super::super) struct SqliteInjectorInner {
+    /// The injector is in a transaction state
+    is_txn: bool,
+    /// Buffer for holding current transaction frames
+    buffer: FrameBuffer,
+    /// Maximum capacity of the frame buffer
+    capacity: usize,
+    /// Injector connection
+    // connection must be dropped before the hook context
+    connection: Arc<Mutex<libsql_sys::Connection<InjectorWal>>>,
+    biggest_uncommitted_seen: FrameNo,
+
+    // Connection config items used to recreate the injection connection
+    path: PathBuf,
+    encryption_config: Option<libsql_sys::EncryptionConfig>,
+    auto_checkpoint: u32,
+}
+
+/// Methods from this trait are called before and after performing a frame injection.
+/// This trait is used to record the last committed frame_no to the log.
+/// The implementer can persist the pre and post commit frame no, and compare them in the event of
+/// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted.
+impl SqliteInjectorInner {
+    fn new(
+        path: impl AsRef<Path>,
+        capacity: usize,
+        auto_checkpoint: u32,
+        encryption_config: Option<libsql_sys::EncryptionConfig>,
+    ) -> Result<Self, Error> {
+        let path = path.as_ref().to_path_buf();
+
+        let buffer = FrameBuffer::default();
+        let wal_manager = InjectorWalManager::new(buffer.clone());
+        let connection = libsql_sys::Connection::open(
+            &path,
+            OpenFlags::SQLITE_OPEN_READ_WRITE
+                | OpenFlags::SQLITE_OPEN_CREATE
+                | OpenFlags::SQLITE_OPEN_URI
+                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
+            wal_manager,
+            auto_checkpoint,
+            encryption_config.clone(),
+        )?;
+
+        Ok(Self {
+            is_txn: false,
+            buffer,
+            capacity,
+            connection: Arc::new(Mutex::new(connection)),
+            biggest_uncommitted_seen: 0,
+
+            path,
+            encryption_config,
+            auto_checkpoint,
+        })
+    }
+
+    /// Inject a frame into the log. If this was a commit frame, returns Ok(Some(FrameNo)).
+    pub fn inject_frame(&mut self, frame: Frame) -> Result<Option<FrameNo>, Error> {
+        let frame_close_txn = frame.header().size_after.get() != 0;
+        self.buffer.lock().push_back(frame);
+        if frame_close_txn || self.buffer.lock().len() >= self.capacity {
+            return self.flush();
+        }
+
+        Ok(None)
+    }
+
+    pub fn rollback(&mut self) {
+        self.clear_buffer();
+        let conn = self.connection.lock();
+        let mut rollback = conn.prepare_cached("ROLLBACK").unwrap();
+        let _ = rollback.execute(());
+        self.is_txn = false;
+    }
+
+    /// Flush the buffer to libsql WAL.
+    /// Trigger a dummy write, and flush the cache to trigger a call to xFrame. The buffer's frame
+    /// are then injected into the wal. 
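+    /// Returns `Ok(Some(frame_no))` when the flushed frames complete a transaction, and
+    /// `Ok(None)` while the WAL is still inside an open transaction.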
+    pub fn flush(&mut self) -> Result<Option<FrameNo>, Error> {
+        match self.try_flush() {
+            Err(e) => {
+                // something went wrong, rollback the connection to make sure we can retry in a
+                // clean state
+                self.biggest_uncommitted_seen = 0;
+                self.rollback();
+                Err(e)
+            }
+            Ok(ret) => Ok(ret),
+        }
+    }
+
+    fn try_flush(&mut self) -> Result<Option<FrameNo>, Error> {
+        if !self.is_txn {
+            self.begin_txn()?;
+        }
+
+        let lock = self.buffer.lock();
+        // the frames in the buffer are either monotonically increasing (log) or decreasing
+        // (snapshot). Either way, we want to find the biggest frameno we're about to commit, and
+        // that is either the front or the back of the buffer
+        let last_frame_no = match lock.back().zip(lock.front()) {
+            Some((b, f)) => f.header().frame_no.get().max(b.header().frame_no.get()),
+            None => {
+                tracing::trace!("nothing to inject");
+                return Ok(None);
+            }
+        };
+
+        self.biggest_uncommitted_seen = self.biggest_uncommitted_seen.max(last_frame_no);
+
+        drop(lock);
+
+        let connection = self.connection.lock();
+        // use prepare cached to avoid parsing the same statement over and over again.
+        let mut stmt =
+            connection.prepare_cached("INSERT INTO libsql_temp_injection VALUES (42)")?;
+
+        // We execute the statement, and then force a call to xframe if necessary. If the execute
+        // succeeds, then xframe wasn't called, in this case, we call cache_flush, and then process
+        // the error.
+        // It is unexpected that execute flushes, but it is possible, so we handle that case.
+        match stmt.execute(()).and_then(|_| connection.cache_flush()) {
+            Ok(_) => panic!("replication hook was not called"),
+            Err(e) => {
+                if let Some(e) = e.sqlite_error() {
+                    if e.extended_code == LIBSQL_INJECT_OK {
+                        // refresh schema
+                        connection.pragma_update(None, "writable_schema", "reset")?;
+                        let mut rollback = connection.prepare_cached("ROLLBACK")?;
+                        let _ = rollback.execute(());
+                        self.is_txn = false;
+                        assert!(self.buffer.lock().is_empty());
+                        let commit_frame_no = self.biggest_uncommitted_seen;
+                        self.biggest_uncommitted_seen = 0;
+                        return Ok(Some(commit_frame_no));
+                    } else if e.extended_code == LIBSQL_INJECT_OK_TXN {
+                        self.is_txn = true;
+                        assert!(self.buffer.lock().is_empty());
+                        return Ok(None);
+                    } else if e.extended_code == LIBSQL_INJECT_FATAL {
+                        return Err(Error::FatalInjectError);
+                    }
+                }
+
+                Err(Error::FatalInjectError)
+            }
+        }
+    }
+
+    fn begin_txn(&mut self) -> Result<(), Error> {
+        let mut conn = self.connection.lock();
+
+        {
+            let wal_manager = InjectorWalManager::new(self.buffer.clone());
+            let new_conn = libsql_sys::Connection::open(
+                &self.path,
+                OpenFlags::SQLITE_OPEN_READ_WRITE
+                    | OpenFlags::SQLITE_OPEN_CREATE
+                    | OpenFlags::SQLITE_OPEN_URI
+                    | OpenFlags::SQLITE_OPEN_NO_MUTEX,
+                wal_manager,
+                self.auto_checkpoint,
+                self.encryption_config.clone(),
+            )?;
+
+            let _ = std::mem::replace(&mut *conn, new_conn);
+        }
+
+        conn.pragma_update(None, "writable_schema", "true")?;
+
+        let mut stmt = conn.prepare_cached("BEGIN IMMEDIATE")?;
+        stmt.execute(())?;
+        // we create a dummy table. This table MUST not be persisted, otherwise the replica schema
+        // would differ from the primary's. 
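+        // The statement runs inside the injector's own transaction, which try_flush/rollback
+        // always ends with a ROLLBACK, so the table never reaches the replica's persistent schema.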
+        let mut stmt =
+            conn.prepare_cached("CREATE TABLE IF NOT EXISTS libsql_temp_injection (x)")?;
+        stmt.execute(())?;
+
+        Ok(())
+    }
+
+    pub fn clear_buffer(&mut self) {
+        self.buffer.lock().clear()
+    }
+
+    #[cfg(test)]
+    pub fn is_txn(&self) -> bool {
+        self.is_txn
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::frame::FrameBorrowed;
+    use std::mem::size_of;
+
+    use super::*;
+    /// this is generated by creating a table test, inserting 5 rows into it, and then
+    /// truncating the wal file of its header.
+    const WAL: &[u8] = include_bytes!("../../../assets/test/test_wallog");
+
+    fn wal_log() -> impl Iterator<Item = Frame> {
+        WAL.chunks(size_of::<FrameBorrowed>())
+            .map(|b| Frame::try_from(b).unwrap())
+    }
+
+    #[test]
+    fn test_simple_inject_frames() {
+        let temp = tempfile::tempdir().unwrap();
+
+        let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 10, 10000, None).unwrap();
+        let log = wal_log();
+        for frame in log {
+            injector.inject_frame(frame).unwrap();
+        }
+
+        let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap();
+
+        conn.query_row("SELECT COUNT(*) FROM test", (), |row| {
+            assert_eq!(row.get::<_, usize>(0).unwrap(), 5);
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_inject_frames_split_txn() {
+        let temp = tempfile::tempdir().unwrap();
+
+        // inject one frame at a time
+        let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 1, 10000, None).unwrap();
+        let log = wal_log();
+        for frame in log {
+            injector.inject_frame(frame).unwrap();
+        }
+
+        let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap();
+
+        conn.query_row("SELECT COUNT(*) FROM test", (), |row| {
+            assert_eq!(row.get::<_, usize>(0).unwrap(), 5);
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_inject_partial_txn_isolated() {
+        let temp = tempfile::tempdir().unwrap();
+
+        // inject one frame at a time
+        let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 10, 1000, None).unwrap();
+        let mut frames = wal_log();
+
+        assert!(injector
+            .inject_frame(frames.next().unwrap())
+            .unwrap()
+            .is_none());
+        let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap();
+        assert!(conn
+            .query_row("SELECT COUNT(*) FROM test", (), |_| Ok(()))
+            .is_err());
+
+        while injector
+            .inject_frame(frames.next().unwrap())
+            .unwrap()
+            .is_none()
+        {}
+
+        // reset schema
+        conn.pragma_update(None, "writable_schema", "reset")
+            .unwrap();
+        conn.query_row("SELECT COUNT(*) FROM test", (), |_| Ok(()))
+            .unwrap();
+    }
+}
diff --git a/libsql-replication/src/replicator.rs b/libsql-replication/src/replicator.rs
index bc1eada7f8..31c766faad 100644
--- a/libsql-replication/src/replicator.rs
+++ b/libsql-replication/src/replicator.rs
@@ -1,14 +1,11 @@
 use std::path::PathBuf;
-use std::sync::Arc;
 
-use parking_lot::Mutex;
-use tokio::task::spawn_blocking;
 use tokio::time::Duration;
 use tokio_stream::{Stream, StreamExt};
 use tonic::{Code, Status};
 
 use crate::frame::{Frame, FrameNo};
-use crate::injector::Injector;
+use crate::injector::{Injector, SqliteInjector};
 use crate::rpc::replication::{
     Frame as RpcFrame, NAMESPACE_DOESNT_EXIST, NEED_SNAPSHOT_ERROR_MSG, NO_HELLO_ERROR_MSG,
 };
@@ -137,9 +134,9 @@
 /// The `Replicator`'s duty is to download frames from the primary, and pass them to the injector at
 /// transaction boundaries. 
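+/// `Replicator` is generic over the client `C` that fetches frames and, after this change,
+/// over the `Injector` implementation `I` that applies them locally.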
-pub struct Replicator { +pub struct Replicator { client: C, - injector: Arc>, + injector: I, state: ReplicatorState, frames_synced: usize, } @@ -154,33 +151,42 @@ enum ReplicatorState { Exit, } -impl Replicator { +impl Replicator +where + C: ReplicatorClient, +{ /// Creates a replicator for the db file pointed at by `db_path` - pub async fn new( + pub async fn new_sqlite( client: C, db_path: PathBuf, auto_checkpoint: u32, encryption_config: Option, ) -> Result { - let injector = { - let db_path = db_path.clone(); - spawn_blocking(move || { - Injector::new( - db_path, - INJECTOR_BUFFER_CAPACITY, - auto_checkpoint, - encryption_config, - ) - }) - .await?? - }; + let injector = SqliteInjector::new( + db_path.clone(), + INJECTOR_BUFFER_CAPACITY, + auto_checkpoint, + encryption_config, + ) + .await?; + + Ok(Self::new(client, injector)) + } +} - Ok(Self { +impl Replicator +where + C: ReplicatorClient, + I: Injector, +{ + + pub fn new(client: C, injector: I) -> Self { + Self { client, - injector: Arc::new(Mutex::new(injector)), + injector, state: ReplicatorState::NeedHandshake, frames_synced: 0, - }) + } } /// for a handshake on next call to replicate. @@ -250,7 +256,7 @@ impl Replicator { // in case of error we rollback the current injector transaction, and start over. if ret.is_err() { self.client.rollback(); - self.injector.lock().rollback(); + self.injector.rollback().await; } self.state = match ret { @@ -293,7 +299,8 @@ impl Replicator { } async fn load_snapshot(&mut self) -> Result<(), Error> { - self.injector.lock().clear_buffer(); + self.client.rollback(); + self.injector.rollback().await; loop { match self.client.snapshot().await { Ok(mut stream) => { @@ -315,26 +322,22 @@ impl Replicator { async fn inject_frame(&mut self, frame: Frame) -> Result<(), Error> { self.frames_synced += 1; - let injector = self.injector.clone(); - match spawn_blocking(move || injector.lock().inject_frame(frame)).await? { - Ok(Some(commit_fno)) => { + match self.injector.inject_frame(frame).await? { + Some(commit_fno) => { self.client.commit_frame_no(commit_fno).await?; } - Ok(None) => (), - Err(e) => Err(e)?, + None => (), } Ok(()) } pub async fn flush(&mut self) -> Result<(), Error> { - let injector = self.injector.clone(); - match spawn_blocking(move || injector.lock().flush()).await? { - Ok(Some(commit_fno)) => { + match self.injector.flush().await? 
{ + Some(commit_fno) => { self.client.commit_frame_no(commit_fno).await?; } - Ok(None) => (), - Err(e) => Err(e)?, + None => (), } Ok(()) @@ -395,7 +398,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); @@ -438,7 +441,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -482,7 +485,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -526,7 +529,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -568,7 +571,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -610,7 +613,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); replicator.state = ReplicatorState::NeedSnapshot; @@ -653,7 +656,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); // we assume that we already received the handshake and the handshake is not valid anymore @@ -696,7 +699,7 @@ mod test { fn rollback(&mut self) {} } - let mut replicator = Replicator::new(Client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(Client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); replicator.state = ReplicatorState::NeedHandshake; @@ -784,7 +787,7 @@ mod test { committed_frame_no: None, }; - let mut replicator = Replicator::new(client, tmp.path().to_path_buf(), 10000, None) + let mut replicator = Replicator::new_sqlite(client, tmp.path().to_path_buf(), 10000, None) .await .unwrap(); @@ -795,7 +798,7 @@ mod test { replicator.try_replicate_step().await.unwrap_err(), Error::Client(_) )); - assert!(!replicator.injector.lock().is_txn()); + assert!(!replicator.injector.inner.lock().is_txn()); assert!(replicator.client_mut().committed_frame_no.is_none()); assert_eq!(replicator.state, ReplicatorState::NeedHandshake); @@ -805,7 +808,7 @@ mod test { replicator.client_mut().should_error = false; replicator.try_replicate_step().await.unwrap(); - assert!(!replicator.injector.lock().is_txn()); + 
assert!(!replicator.injector.inner.lock().is_txn()); assert_eq!(replicator.state, ReplicatorState::Exit); assert_eq!(replicator.client_mut().committed_frame_no, Some(6)); } diff --git a/libsql/src/replication/mod.rs b/libsql/src/replication/mod.rs index 69cc0b5db2..2f4e9b49c0 100644 --- a/libsql/src/replication/mod.rs +++ b/libsql/src/replication/mod.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use std::time::Duration; pub use libsql_replication::frame::{Frame, FrameNo}; +use libsql_replication::injector::SqliteInjector; use libsql_replication::replicator::{Either, Replicator}; pub use libsql_replication::snapshot::SnapshotFile; @@ -129,7 +130,7 @@ impl Writer { #[derive(Clone)] pub(crate) struct EmbeddedReplicator { - replicator: Arc>>>, + replicator: Arc, SqliteInjector>>>, bg_abort: Option>, last_frames_synced: Arc, } @@ -149,7 +150,7 @@ impl EmbeddedReplicator { perodic_sync: Option, ) -> Result { let replicator = Arc::new(Mutex::new( - Replicator::new( + Replicator::new_sqlite( Either::Left(client), db_path, auto_checkpoint, @@ -193,7 +194,7 @@ impl EmbeddedReplicator { encryption_config: Option, ) -> Result { let replicator = Arc::new(Mutex::new( - Replicator::new( + Replicator::new_sqlite( Either::Right(client), db_path, auto_checkpoint, From 4b5baacad57965b639dcab274a4a97f8b0ed3abd Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 11:57:15 +0200 Subject: [PATCH 55/70] introduce libsql injector --- Cargo.lock | 2 + Cargo.toml | 1 + libsql-replication/Cargo.toml | 1 + .../proto/replication_log.proto | 6 +++ libsql-replication/src/frame.rs | 3 +- libsql-replication/src/generated/wal_log.rs | 42 +++++++++++++++ libsql-replication/src/injector/error.rs | 5 +- .../src/injector/libsql_injector.rs | 44 +++++++++++++++ libsql-replication/src/injector/mod.rs | 2 + .../src/injector/sqlite_injector/mod.rs | 12 ++--- libsql-replication/src/rpc.rs | 5 +- libsql-server/Cargo.toml | 2 +- .../src/replication/replicator_client.rs | 6 ++- libsql-server/src/rpc/replication_log.rs | 5 ++ libsql-wal/Cargo.toml | 3 +- libsql-wal/src/replication/injector.rs | 23 ++++---- libsql-wal/src/segment/current.rs | 4 +- libsql-wal/src/shared_wal.rs | 8 +-- libsql-wal/src/transaction.rs | 53 +++++++++++++++++-- libsql/src/replication/remote_client.rs | 3 +- 20 files changed, 196 insertions(+), 34 deletions(-) create mode 100644 libsql-replication/src/injector/libsql_injector.rs diff --git a/Cargo.lock b/Cargo.lock index 17cfc0e090..7e19e03e9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3702,6 +3702,7 @@ name = "libsql-wal" version = "0.1.0" dependencies = [ "arc-swap", + "async-lock 3.4.0", "async-stream", "aws-config 1.5.4", "aws-credential-types 1.2.0", @@ -3779,6 +3780,7 @@ dependencies = [ "cbc", "libsql-rusqlite", "libsql-sys", + "libsql-wal", "parking_lot", "prost", "prost-build", diff --git a/Cargo.toml b/Cargo.toml index 9381fb83f3..685f14964f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ rusqlite = { package = "libsql-rusqlite", path = "vendored/rusqlite", version = ] } hyper = { version = "0.14" } tower = { version = "0.4.13" } +zerocopy = { version = "0.7.32", features = ["derive", "alloc"] } # Config for 'cargo dist' [workspace.metadata.dist] diff --git a/libsql-replication/Cargo.toml b/libsql-replication/Cargo.toml index d2a9431cba..2a03d362bf 100644 --- a/libsql-replication/Cargo.toml +++ b/libsql-replication/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT" tonic = { version = "0.11", features = ["tls"] } prost = "0.12" libsql-sys = { version = "0.7", path = "../libsql-sys", 
default-features = false, features = ["wal", "rusqlite", "api"] } +libsql-wal = { path = "../libsql-wal/" } rusqlite = { workspace = true } parking_lot = "0.12.1" bytes = { version = "1.5.0", features = ["serde"] } diff --git a/libsql-replication/proto/replication_log.proto b/libsql-replication/proto/replication_log.proto index 6208874609..b3be419319 100644 --- a/libsql-replication/proto/replication_log.proto +++ b/libsql-replication/proto/replication_log.proto @@ -9,6 +9,12 @@ message LogOffset { message HelloRequest { optional uint64 handshake_version = 1; + enum WalFlavor { + Sqlite = 0; + Libsql = 1; + } + // the type of wal that the client is expecting + optional WalFlavor wal_flavor = 2; } message HelloResponse { diff --git a/libsql-replication/src/frame.rs b/libsql-replication/src/frame.rs index a6a2854e52..55b5b778b5 100644 --- a/libsql-replication/src/frame.rs +++ b/libsql-replication/src/frame.rs @@ -13,7 +13,6 @@ use crate::LIBSQL_PAGE_SIZE; pub type FrameNo = u64; /// The file header for the WAL log. All fields are represented in little-endian ordering. -/// See `encode` and `decode` for actual layout. // repr C for stable sizing #[repr(C)] #[derive(Debug, Clone, Copy, zerocopy::FromZeroes, zerocopy::FromBytes, zerocopy::AsBytes)] @@ -22,7 +21,7 @@ pub struct FrameHeader { pub frame_no: lu64, /// Rolling checksum of all the previous frames, including this one. pub checksum: lu64, - /// page number, if frame_type is FrameType::Page + /// page number pub page_no: lu32, /// Size of the database (in page) after committing the transaction. This is passed from sqlite, /// and serves as commit transaction boundary diff --git a/libsql-replication/src/generated/wal_log.rs b/libsql-replication/src/generated/wal_log.rs index 2d7330e732..441881c4a7 100644 --- a/libsql-replication/src/generated/wal_log.rs +++ b/libsql-replication/src/generated/wal_log.rs @@ -10,6 +10,48 @@ pub struct LogOffset { pub struct HelloRequest { #[prost(uint64, optional, tag = "1")] pub handshake_version: ::core::option::Option, + /// the type of wal that the client is expecting + #[prost(enumeration = "hello_request::WalFlavor", optional, tag = "2")] + pub wal_flavor: ::core::option::Option, +} +/// Nested message and enum types in `HelloRequest`. +pub mod hello_request { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum WalFlavor { + Sqlite = 0, + Libsql = 1, + } + impl WalFlavor { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + WalFlavor::Sqlite => "Sqlite", + WalFlavor::Libsql => "Libsql", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "Sqlite" => Some(Self::Sqlite), + "Libsql" => Some(Self::Libsql), + _ => None, + } + } + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/libsql-replication/src/injector/error.rs b/libsql-replication/src/injector/error.rs index b1cebfe28b..225960c4d1 100644 --- a/libsql-replication/src/injector/error.rs +++ b/libsql-replication/src/injector/error.rs @@ -1,4 +1,5 @@ pub type Result = std::result::Result; +pub type BoxError = Box; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -6,6 +7,6 @@ pub enum Error { Io(#[from] std::io::Error), #[error("SQLite error: {0}")] Sqlite(#[from] rusqlite::Error), - #[error("A fatal error occured injecting frames")] - FatalInjectError, + #[error("A fatal error occurred injecting frames: {0}")] + FatalInjectError(BoxError), }
diff --git a/libsql-replication/src/injector/libsql_injector.rs b/libsql-replication/src/injector/libsql_injector.rs new file mode 100644 index 0000000000..946d35e547 --- /dev/null +++ b/libsql-replication/src/injector/libsql_injector.rs @@ -0,0 +1,44 @@ +use std::mem::size_of; + +use libsql_wal::io::StdIO; +use libsql_wal::replication::injector::Injector; +use libsql_wal::segment::Frame as WalFrame; +use zerocopy::{AsBytes, FromZeroes}; + +use crate::frame::{Frame, FrameNo}; + +use super::error::{Error, Result}; + +pub struct LibsqlInjector { + injector: Injector, } + +impl super::Injector for LibsqlInjector { + async fn inject_frame(&mut self, frame: Frame) -> Result> { + // this is a bit annoying because we want to read the frame, and it has to be aligned, so we + // must copy it... + // FIXME: optimize this. + let mut wal_frame = WalFrame::new_box_zeroed(); + if frame.bytes().len() != size_of::() { + todo!("invalid frame"); + } + wal_frame.as_bytes_mut().copy_from_slice(&frame.bytes()[..]); + Ok(self + .injector + .insert_frame(wal_frame) + .await + .map_err(|e| Error::FatalInjectError(e.into()))?) + } + + async fn rollback(&mut self) { + self.injector.rollback(); + } + + async fn flush(&mut self) -> Result> { + self.injector + .flush(None) + .await + .map_err(|e| Error::FatalInjectError(e.into()))?; + Ok(None) + } +}
diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 1d69ae0aab..20a81cfa01 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,6 +1,7 @@ use std::future::Future; pub use sqlite_injector::SqliteInjector; +pub use libsql_injector::LibsqlInjector; use crate::frame::{Frame, FrameNo}; @@ -9,6 +10,7 @@ pub use error::Error; mod error; mod sqlite_injector; +mod libsql_injector; pub trait Injector { /// Inject a singular frame.
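For orientation, the `Injector` trait introduced here is what lets the replicator drive either WAL flavor through a single interface. Below is a minimal sketch of how a caller might push a batch of frames through any `Injector` implementation; it assumes only the trait surface shown above (`inject_frame`, `rollback`, `flush`), and the `inject_batch` helper itself is hypothetical, not part of this patch:

    use libsql_replication::frame::{Frame, FrameNo};
    use libsql_replication::injector::{Error, Injector};

    // Hypothetical driver loop: inject a batch of frames, roll back on
    // error, and flush whatever remains buffered at the end.
    async fn inject_batch<I: Injector>(
        injector: &mut I,
        frames: Vec<Frame>,
    ) -> Result<Option<FrameNo>, Error> {
        let mut committed = None;
        for frame in frames {
            match injector.inject_frame(frame).await {
                // Some(frame_no) signals that a commit boundary was reached
                Ok(Some(frame_no)) => committed = Some(frame_no),
                Ok(None) => (),
                Err(e) => {
                    // discard any buffered, uncommitted frames before bailing out
                    injector.rollback().await;
                    return Err(e);
                }
            }
        }
        // force out anything still sitting in the injector's buffer
        if let Some(frame_no) = injector.flush().await? {
            committed = Some(frame_no);
        }
        Ok(committed)
    }

This mirrors what the replicator's own `inject_frame` does elsewhere in this series, where a `Some(commit_fno)` result triggers `commit_frame_no` on the client.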
diff --git a/libsql-replication/src/injector/sqlite_injector/mod.rs b/libsql-replication/src/injector/sqlite_injector/mod.rs index dea78ce4b5..545fbe810d 100644 --- a/libsql-replication/src/injector/sqlite_injector/mod.rs +++ b/libsql-replication/src/injector/sqlite_injector/mod.rs @@ -192,8 +192,8 @@ impl SqliteInjectorInner { match stmt.execute(()).and_then(|_| connection.cache_flush()) { Ok(_) => panic!("replication hook was not called"), Err(e) => { - if let Some(e) = e.sqlite_error() { - if e.extended_code == LIBSQL_INJECT_OK { + if let Some(err) = e.sqlite_error() { + if err.extended_code == LIBSQL_INJECT_OK { // refresh schema connection.pragma_update(None, "writable_schema", "reset")?; let mut rollback = connection.prepare_cached("ROLLBACK")?; @@ -203,16 +203,16 @@ impl SqliteInjectorInner { let commit_frame_no = self.biggest_uncommitted_seen; self.biggest_uncommitted_seen = 0; return Ok(Some(commit_frame_no)); - } else if e.extended_code == LIBSQL_INJECT_OK_TXN { + } else if err.extended_code == LIBSQL_INJECT_OK_TXN { self.is_txn = true; assert!(self.buffer.lock().is_empty()); return Ok(None); - } else if e.extended_code == LIBSQL_INJECT_FATAL { - return Err(Error::FatalInjectError); + } else if err.extended_code == LIBSQL_INJECT_FATAL { + return Err(Error::FatalInjectError(e.into())); } } - Err(Error::FatalInjectError) + Err(Error::FatalInjectError(e.into())) } } } diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index ebc92cf10c..a9b172db20 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -25,6 +25,8 @@ pub mod replication { #![allow(clippy::all)] use uuid::Uuid; + + use self::hello_request::WalFlavor; include!("generated/wal_log.rs"); pub const NO_HELLO_ERROR_MSG: &str = "NO_HELLO"; @@ -46,9 +48,10 @@ pub mod replication { } impl HelloRequest { - pub fn new() -> Self { + pub fn new(wal_flavor: WalFlavor) -> Self { Self { handshake_version: Some(1), + wal_flavor: Some(wal_flavor.into()) } } } diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 6763c02dfb..934a400786 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -83,7 +83,7 @@ url = { version = "2.3", features = ["serde"] } uuid = { version = "1.3", features = ["v4", "serde", "v7"] } aes = { version = "0.8.3", optional = true } cbc = { version = "0.1.2", optional = true } -zerocopy = { version = "0.7.28", features = ["derive", "alloc"] } +zerocopy = { workspace = true } hashbrown = { version = "0.14.3", features = ["serde"] } hdrhistogram = "7.5.4" crossbeam = "0.8.4" diff --git a/libsql-server/src/replication/replicator_client.rs b/libsql-server/src/replication/replicator_client.rs index 4d12ff7f83..d68c259dc9 100644 --- a/libsql-server/src/replication/replicator_client.rs +++ b/libsql-server/src/replication/replicator_client.rs @@ -7,6 +7,7 @@ use futures::TryStreamExt; use libsql_replication::frame::Frame; use libsql_replication::meta::WalIndexMeta; use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient}; +use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use libsql_replication::rpc::replication::{ verify_session_token, HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, SESSION_TOKEN_KEY, @@ -35,6 +36,7 @@ pub struct Client { // the primary current replication index, as reported by the last handshake pub primary_replication_index: Option, store: NamespaceStore, + wal_flavor: WalFlavor, } impl Client { 
@@ -44,6 +46,7 @@ impl Client { path: &Path, meta_store_handle: MetaStoreHandle, store: NamespaceStore, + wal_flavor: WalFlavor, ) -> crate::Result { let (current_frame_no_notifier, _) = watch::channel(None); let meta = WalIndexMeta::open(path).await?; @@ -57,6 +60,7 @@ impl Client { meta_store_handle, primary_replication_index: None, store, + wal_flavor, }) } @@ -96,7 +100,7 @@ impl ReplicatorClient for Client { #[tracing::instrument(skip(self))] async fn handshake(&mut self) -> Result<(), Error> { tracing::debug!("Attempting to perform handshake with primary."); - let req = self.make_request(HelloRequest::new()); + let req = self.make_request(HelloRequest::new(self.wal_flavor)); let resp = self.client.hello(req).await?; let hello = resp.into_inner(); verify_session_token(&hello.session_token).map_err(Error::Client)?; diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index c0b216739e..628cb4a01d 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -7,6 +7,7 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::stream::BoxStream; use futures_core::Future; +use libsql_replication::rpc::replication::hello_request::WalFlavor; pub use libsql_replication::rpc::replication as rpc; use libsql_replication::rpc::replication::replication_log_server::ReplicationLog; use libsql_replication::rpc::replication::{ @@ -355,6 +356,10 @@ impl ReplicationLog for ReplicationLogService { } } + if let WalFlavor::Libsql = req.get_ref().wal_flavor() { + return Err(Status::invalid_argument("libsql wal not supported")) + } + let (logger, config, version, _, _) = self.logger_from_namespace(namespace, &req, false).await?; diff --git a/libsql-wal/Cargo.toml b/libsql-wal/Cargo.toml index 9624596c28..f24f2e4c59 100644 --- a/libsql-wal/Cargo.toml +++ b/libsql-wal/Cargo.toml @@ -9,6 +9,7 @@ publish = false [dependencies] arc-swap = "1.7.1" async-stream = "0.3.5" +async-lock = "3.4.0" bitflags = "2.5.0" bytes = "1.6.0" chrono = "0.4.38" @@ -29,7 +30,7 @@ tokio-stream = "0.1.15" tracing = "0.1.40" uuid = { version = "1.8.0", features = ["v4"] } walkdir = "2.5.0" -zerocopy = { version = "0.7.32", features = ["derive", "alloc"] } +zerocopy = { workspace = true } aws-config = { version = "1", optional = true, features = ["behavior-version-latest"] } aws-sdk-s3 = { version = "1", optional = true } diff --git a/libsql-wal/src/replication/injector.rs b/libsql-wal/src/replication/injector.rs index 66710bbb22..c3642e196e 100644 --- a/libsql-wal/src/replication/injector.rs +++ b/libsql-wal/src/replication/injector.rs @@ -6,23 +6,23 @@ use crate::error::Result; use crate::io::Io; use crate::segment::Frame; use crate::shared_wal::SharedWal; -use crate::transaction::TxGuard; +use crate::transaction::TxGuardOwned; /// The injector takes frames and injects them in the wal. 
-pub struct Injector<'a, IO: Io> { +pub struct Injector { // The wal to which we are injecting wal: Arc>, buffer: Vec>, /// capacity of the frame buffer capacity: usize, - tx: TxGuard<'a, IO::File>, + tx: TxGuardOwned, max_tx_frame_no: u64, } -impl<'a, IO: Io> Injector<'a, IO> { +impl Injector { pub fn new( wal: Arc>, - tx: TxGuard<'a, IO::File>, + tx: TxGuardOwned, buffer_capacity: usize, ) -> Result { Ok(Self { @@ -34,7 +34,7 @@ impl<'a, IO: Io> Injector<'a, IO> { }) } - pub async fn insert_frame(&mut self, frame: Box) -> Result<()> { + pub async fn insert_frame(&mut self, frame: Box) -> Result> { let size_after = frame.size_after(); self.max_tx_frame_no = self.max_tx_frame_no.max(frame.header().frame_no()); self.buffer.push(frame); @@ -43,10 +43,10 @@ impl<'a, IO: Io> Injector<'a, IO> { self.flush(size_after).await?; } - Ok(()) + Ok(size_after.map(|_| self.max_tx_frame_no)) } - async fn flush(&mut self, size_after: Option) -> Result<()> { + pub async fn flush(&mut self, size_after: Option) -> Result<()> { let buffer = std::mem::take(&mut self.buffer); let current = self.wal.current.load(); let commit_data = size_after.map(|size| (size, self.max_tx_frame_no)); @@ -60,6 +60,11 @@ impl<'a, IO: Io> Injector<'a, IO> { Ok(()) } + + pub fn rollback(&mut self) { + self.buffer.clear(); + self.tx.reset(0); + } } #[cfg(test)] @@ -89,7 +94,7 @@ mod test { let mut tx = crate::transaction::Transaction::Read(replica_shared.begin_read(42)); replica_shared.upgrade(&mut tx).unwrap(); - let guard = tx.as_write_mut().unwrap().lock(); + let guard = tx.into_write().unwrap_or_else(|_| panic!()).into_lock_owned(); let mut injector = Injector::new(replica_shared.clone(), guard, 10).unwrap(); primary_conn.execute("create table test (x)", ()).unwrap();
diff --git a/libsql-wal/src/segment/current.rs b/libsql-wal/src/segment/current.rs index d8d720a145..bda6d5742a 100644 --- a/libsql-wal/src/segment/current.rs +++ b/libsql-wal/src/segment/current.rs @@ -22,7 +22,7 @@ use crate::io::file::FileExt; use crate::io::Inspect; use crate::segment::{checked_frame_offset, SegmentFlags}; use crate::segment::{frame_offset, page_offset, sealed::SealedSegment}; -use crate::transaction::{Transaction, TxGuard}; +use crate::transaction::{Transaction, TxGuard, TxGuardOwned}; use crate::{LIBSQL_MAGIC, LIBSQL_PAGE_SIZE, LIBSQL_WAL_VERSION}; use super::list::SegmentList; @@ -125,7 +125,7 @@ impl CurrentSegment { frames: Vec>, // (size_after, last_frame_no) commit_data: Option<(u32, u64)>, - tx: &mut TxGuard<'_, F>, + tx: &mut TxGuardOwned, ) -> Result>> where F: FileExt,
diff --git a/libsql-wal/src/shared_wal.rs b/libsql-wal/src/shared_wal.rs index 09a2747c5a..461ad13e03 100644 --- a/libsql-wal/src/shared_wal.rs +++ b/libsql-wal/src/shared_wal.rs @@ -20,7 +20,7 @@ use libsql_sys::name::NamespaceName; #[derive(Default)] pub struct WalLock { - pub(crate) tx_id: Arc>>, + pub(crate) tx_id: Arc>>, /// When a writer is popped from the write queue, its write transaction may not be reading from the most recent /// snapshot. In this case, we return `SQLITE_BUSY_SNAPSHOT` to the caller. If no reads were performed /// with that transaction before upgrading, then the caller will call us back immediately after re-acquiring @@ -108,7 +108,7 @@ impl SharedWal { Some(id) if id == read_tx.conn_id => { tracing::trace!("taking reserved slot"); reserved.take(); - let lock = self.wal_lock.tx_id.lock(); + let lock = self.wal_lock.tx_id.lock_blocking(); let write_tx = self.acquire_write(read_tx, lock, reserved)?; *tx = Transaction::Write(write_tx); return Ok(()); @@ -117,7 +117,7 @@ impl SharedWal { } } - let lock = self.wal_lock.tx_id.lock(); + let lock = self.wal_lock.tx_id.lock_blocking(); match *lock { None if self.wal_lock.waiters.is_empty() => { let write_tx = @@ -144,7 +144,7 @@ impl SharedWal { fn acquire_write( &self, read_tx: &ReadTransaction, - mut tx_id_lock: MutexGuard>, + mut tx_id_lock: async_lock::MutexGuard>, mut reserved: MutexGuard>, ) -> Result> { // we read two fields in the header. There is no risk that a transaction commit in
diff --git a/libsql-wal/src/transaction.rs b/libsql-wal/src/transaction.rs index 723cffeae1..f2cdd5be70 100644 --- a/libsql-wal/src/transaction.rs +++ b/libsql-wal/src/transaction.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use std::time::Instant; use libsql_sys::name::NamespaceName; -use parking_lot::{ArcMutexGuard, RawMutex}; use tokio::sync::mpsc; use crate::checkpointer::CheckpointMessage; @@ -31,6 +30,14 @@ impl Transaction { } } + pub fn into_write(self) -> Result, Self> { + if let Self::Write(v) = self { + Ok(v) + } else { + Err(self) + } + } + pub fn max_frame_no(&self) -> u64 { match self { Transaction::Write(w) => w.next_frame_no - 1, @@ -147,8 +154,27 @@ pub struct WriteTransaction { pub recompute_checksum: Option, } +pub struct TxGuardOwned { + _lock: async_lock::MutexGuardArc>, + inner: WriteTransaction, +} + +impl Deref for TxGuardOwned { + type Target = WriteTransaction; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for TxGuardOwned { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + pub struct TxGuard<'a, F> { - _lock: ArcMutexGuard>, + _lock: async_lock::MutexGuardArc>, inner: &'a mut WriteTransaction, } @@ -189,7 +215,7 @@ impl WriteTransaction { todo!("txn has already been committed"); } - let g = self.wal_lock.tx_id.lock_arc(); + let g = self.wal_lock.tx_id.lock_arc_blocking(); match *g { // we still hold the lock, we can proceed Some(id) if self.id == id => TxGuard { @@ -202,6 +228,25 @@ impl WriteTransaction { } } + pub fn into_lock_owned(self) -> TxGuardOwned { + if self.is_commited { + tracing::error!("transaction already committed"); + todo!("txn has already been committed"); + } + + let g = self.wal_lock.tx_id.lock_arc_blocking(); + match *g { + // we still hold the lock, we can proceed + Some(id) if self.id == id => TxGuardOwned { + _lock: g, + inner: self, + }, + // Somebody took the lock from us + Some(_) => todo!("lock stolen"), + None => todo!("not a transaction"), + } + } + pub fn reset(&mut self, savepoint_id: usize) { if savepoint_id >= self.savepoints.len() { unreachable!("savepoint doesn't exist"); @@ -231,7 +276,7 @@ impl WriteTransaction { let Self { wal_lock, read_tx, ..
} = self; - let mut lock = wal_lock.tx_id.lock(); + let mut lock = wal_lock.tx_id.lock_blocking(); match *lock { Some(lock_id) if lock_id == read_tx.id => { lock.take(); diff --git a/libsql/src/replication/remote_client.rs b/libsql/src/replication/remote_client.rs index d0052f50d9..79cffb1c38 100644 --- a/libsql/src/replication/remote_client.rs +++ b/libsql/src/replication/remote_client.rs @@ -8,6 +8,7 @@ use futures::StreamExt as _; use libsql_replication::frame::{Frame, FrameHeader, FrameNo}; use libsql_replication::meta::WalIndexMeta; use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient}; +use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::{ verify_session_token, Frames, HelloRequest, HelloResponse, LogOffset, SESSION_TOKEN_KEY, }; @@ -116,7 +117,7 @@ impl RemoteClient { self.dirty = false; } let prefetch = self.session_token.is_some(); - let hello_req = self.make_request(HelloRequest::new()); + let hello_req = self.make_request(HelloRequest::new(WalFlavor::Sqlite)); let log_offset_req = self.make_request(LogOffset { next_offset: self.next_offset(), }); From 566664e9a0acc7bbd9544c49f5f573eaec6a0a7e Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 15:14:01 +0200 Subject: [PATCH 56/70] fmt --- bottomless/src/replicator.rs | 3 +- libsql-replication/src/injector/error.rs | 2 +- libsql-replication/src/injector/mod.rs | 6 +-- .../src/injector/sqlite_injector/mod.rs | 38 +++++++++---------- libsql-replication/src/replicator.rs | 3 +- libsql-replication/src/rpc.rs | 2 +- libsql-server/src/rpc/replication_log.rs | 4 +- libsql-wal/src/replication/injector.rs | 5 ++- 8 files changed, 33 insertions(+), 30 deletions(-) diff --git a/bottomless/src/replicator.rs b/bottomless/src/replicator.rs index 26e190df66..4e92824778 100644 --- a/bottomless/src/replicator.rs +++ b/bottomless/src/replicator.rs @@ -1455,7 +1455,8 @@ impl Replicator { 4096, libsql_sys::connection::NO_AUTOCHECKPOINT, encryption_config, - ).await?; + ) + .await?; let prefix = format!("{}-{}/", self.db_name, generation); let mut page_buf = { let mut v = Vec::with_capacity(page_size); diff --git a/libsql-replication/src/injector/error.rs b/libsql-replication/src/injector/error.rs index 225960c4d1..ac8f1be711 100644 --- a/libsql-replication/src/injector/error.rs +++ b/libsql-replication/src/injector/error.rs @@ -1,4 +1,4 @@ -pub type Result = std::result::Result; +pub type Result = std::result::Result; pub type BoxError = Box; #[derive(Debug, thiserror::Error)] diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 20a81cfa01..39df68c777 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,16 +1,16 @@ use std::future::Future; -pub use sqlite_injector::SqliteInjector; pub use libsql_injector::LibsqlInjector; +pub use sqlite_injector::SqliteInjector; use crate::frame::{Frame, FrameNo}; -use error::Result; pub use error::Error; +use error::Result; mod error; -mod sqlite_injector; mod libsql_injector; +mod sqlite_injector; pub trait Injector { /// Inject a singular frame. 
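Stepping back to the `WalFlavor` handshake introduced in the previous patch: a replica now announces which kind of WAL frames it expects when it greets the primary. A small illustrative snippet, using only the generated prost helpers and the `HelloRequest::new` constructor shown above (the `main` scaffolding is just for the example, not part of the patch):

    use libsql_replication::rpc::replication::hello_request::WalFlavor;
    use libsql_replication::rpc::replication::HelloRequest;

    fn main() {
        // A replica that wants libsql-wal frames announces it in the hello.
        let hello = HelloRequest::new(WalFlavor::Libsql);
        // prost generates an accessor that decodes the optional enum field.
        assert_eq!(hello.wal_flavor(), WalFlavor::Libsql);

        // The generated helpers round-trip the enum through its proto names.
        assert_eq!(WalFlavor::Libsql.as_str_name(), "Libsql");
        assert_eq!(WalFlavor::from_str_name("Sqlite"), Some(WalFlavor::Sqlite));
    }

On the server side, the log service can then match on `req.get_ref().wal_flavor()` and reject flavors it does not serve, as the guard added to `ReplicationLogService` later in this series does.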
diff --git a/libsql-replication/src/injector/sqlite_injector/mod.rs b/libsql-replication/src/injector/sqlite_injector/mod.rs index 545fbe810d..2f4193e469 100644 --- a/libsql-replication/src/injector/sqlite_injector/mod.rs +++ b/libsql-replication/src/injector/sqlite_injector/mod.rs @@ -25,28 +25,23 @@ pub struct SqliteInjector { } impl Injector for SqliteInjector { - async fn inject_frame( - &mut self, - frame: Frame, - ) -> Result> { + async fn inject_frame(&mut self, frame: Frame) -> Result> { let inner = self.inner.clone(); - spawn_blocking(move || { - inner.lock().inject_frame(frame) - }).await.unwrap() + spawn_blocking(move || inner.lock().inject_frame(frame)) + .await + .unwrap() } async fn rollback(&mut self) { let inner = self.inner.clone(); - spawn_blocking(move || { - inner.lock().rollback() - }).await.unwrap(); + spawn_blocking(move || inner.lock().rollback()) + .await + .unwrap(); } async fn flush(&mut self) -> Result> { let inner = self.inner.clone(); - spawn_blocking(move || { - inner.lock().flush() - }).await.unwrap() + spawn_blocking(move || inner.lock().flush()).await.unwrap() } } @@ -56,13 +51,15 @@ impl SqliteInjector { capacity: usize, auto_checkpoint: u32, encryption_config: Option, - ) ->super::Result { + ) -> super::Result { let inner = spawn_blocking(move || { SqliteInjectorInner::new(path, capacity, auto_checkpoint, encryption_config) - }).await.unwrap()?; + }) + .await + .unwrap()?; Ok(Self { - inner: Arc::new(Mutex::new(inner)) + inner: Arc::new(Mutex::new(inner)), }) } } @@ -278,7 +275,8 @@ mod test { fn test_simple_inject_frames() { let temp = tempfile::tempdir().unwrap(); - let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 10, 10000, None).unwrap(); + let mut injector = + SqliteInjectorInner::new(temp.path().join("data"), 10, 10000, None).unwrap(); let log = wal_log(); for frame in log { injector.inject_frame(frame).unwrap(); @@ -298,7 +296,8 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 1, 10000, None).unwrap(); + let mut injector = + SqliteInjectorInner::new(temp.path().join("data"), 1, 10000, None).unwrap(); let log = wal_log(); for frame in log { injector.inject_frame(frame).unwrap(); @@ -318,7 +317,8 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = SqliteInjectorInner::new(temp.path().join("data"), 10, 1000, None).unwrap(); + let mut injector = + SqliteInjectorInner::new(temp.path().join("data"), 10, 1000, None).unwrap(); let mut frames = wal_log(); assert!(injector diff --git a/libsql-replication/src/replicator.rs b/libsql-replication/src/replicator.rs index 31c766faad..ee75822676 100644 --- a/libsql-replication/src/replicator.rs +++ b/libsql-replication/src/replicator.rs @@ -168,7 +168,7 @@ where auto_checkpoint, encryption_config, ) - .await?; + .await?; Ok(Self::new(client, injector)) } @@ -179,7 +179,6 @@ where C: ReplicatorClient, I: Injector, { - pub fn new(client: C, injector: I) -> Self { Self { client, diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index a9b172db20..8e1165af65 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -51,7 +51,7 @@ pub mod replication { pub fn new(wal_flavor: WalFlavor) -> Self { Self { handshake_version: Some(1), - wal_flavor: Some(wal_flavor.into()) + wal_flavor: Some(wal_flavor.into()), } } } diff --git a/libsql-server/src/rpc/replication_log.rs 
b/libsql-server/src/rpc/replication_log.rs index 628cb4a01d..bf1840a5c6 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -7,8 +7,8 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::stream::BoxStream; use futures_core::Future; -use libsql_replication::rpc::replication::hello_request::WalFlavor; pub use libsql_replication::rpc::replication as rpc; +use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::replication_log_server::ReplicationLog; use libsql_replication::rpc::replication::{ Frame, Frames, HelloRequest, HelloResponse, LogOffset, NAMESPACE_DOESNT_EXIST, @@ -357,7 +357,7 @@ impl ReplicationLog for ReplicationLogService { } if let WalFlavor::Libsql = req.get_ref().wal_flavor() { - return Err(Status::invalid_argument("libsql wal not supported")) + return Err(Status::invalid_argument("libsql wal not supported")); } let (logger, config, version, _, _) =
diff --git a/libsql-wal/src/replication/injector.rs b/libsql-wal/src/replication/injector.rs index c3642e196e..a922330102 100644 --- a/libsql-wal/src/replication/injector.rs +++ b/libsql-wal/src/replication/injector.rs @@ -94,7 +94,10 @@ mod test { let mut tx = crate::transaction::Transaction::Read(replica_shared.begin_read(42)); replica_shared.upgrade(&mut tx).unwrap(); - let guard = tx.into_write().unwrap_or_else(|_| panic!()).into_lock_owned(); + let guard = tx + .into_write() + .unwrap_or_else(|_| panic!()) + .into_lock_owned(); let mut injector = Injector::new(replica_shared.clone(), guard, 10).unwrap(); primary_conn.execute("create table test (x)", ()).unwrap();
From a2bdc805743f87f5c8cf2852c548578dbe496401 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 16:07:42 +0200 Subject: [PATCH 57/70] pass RpcFrame to client methods This is necessary to be able to pass different underlying frame types. --- bottomless/src/replicator.rs | 7 ++- .../src/injector/libsql_injector.rs | 9 ++-- libsql-replication/src/injector/mod.rs | 5 +- .../src/injector/sqlite_injector/mod.rs | 5 +- libsql-replication/src/replicator.rs | 49 +++++++++++++------ .../src/replication/replicator_client.rs | 15 +++--- libsql/src/replication/local_client.rs | 14 ++++-- libsql/src/replication/remote_client.rs | 17 ++++--- 8 files changed, 79 insertions(+), 42 deletions(-)
diff --git a/bottomless/src/replicator.rs b/bottomless/src/replicator.rs index 4e92824778..cd37a70165 100644 --- a/bottomless/src/replicator.rs +++ b/bottomless/src/replicator.rs @@ -18,6 +18,7 @@ use aws_sdk_s3::{Client, Config}; use bytes::{Buf, Bytes}; use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; use libsql_replication::injector::Injector as _; +use libsql_replication::rpc::replication::Frame as RpcFrame; use libsql_sys::{Cipher, EncryptionConfig}; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -1554,7 +1555,11 @@ impl Replicator { }, page_buf.as_slice(), ); - injector.inject_frame(frame_to_inject).await?; + let frame = RpcFrame { + data: frame_to_inject.bytes(), + timestamp: None, + }; + injector.inject_frame(frame).await?; applied_wal_frame = true; } }
diff --git a/libsql-replication/src/injector/libsql_injector.rs b/libsql-replication/src/injector/libsql_injector.rs index 946d35e547..f867a29245 100644 --- a/libsql-replication/src/injector/libsql_injector.rs +++ b/libsql-replication/src/injector/libsql_injector.rs @@ -5,7 +5,8 @@ use libsql_wal::replication::injector::Injector; use libsql_wal::segment::Frame as WalFrame; use zerocopy::{AsBytes, FromZeroes}; -use
crate::frame::{Frame, FrameNo}; +use crate::frame::FrameNo; +use crate::rpc::replication::Frame as RpcFrame; use super::error::{Error, Result}; @@ -14,15 +15,15 @@ pub struct LibsqlInjector { } impl super::Injector for LibsqlInjector { - async fn inject_frame(&mut self, frame: Frame) -> Result> { + async fn inject_frame(&mut self, frame: RpcFrame) -> Result> { // this is a bit annoying because we want to read the frame, and it has to be aligned, so we // must copy it... // FIXME: optimize this. let mut wal_frame = WalFrame::new_box_zeroed(); - if frame.bytes().len() != size_of::() { + if frame.data.len() != size_of::() { todo!("invalid frame"); } - wal_frame.as_bytes_mut().copy_from_slice(&frame.bytes()[..]); + wal_frame.as_bytes_mut().copy_from_slice(&frame.data[..]); Ok(self .injector .insert_frame(wal_frame)
diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 39df68c777..3712458d2f 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,9 +1,10 @@ use std::future::Future; +use super::rpc::replication::Frame as RpcFrame; pub use libsql_injector::LibsqlInjector; pub use sqlite_injector::SqliteInjector; -use crate::frame::{Frame, FrameNo}; +use crate::frame::FrameNo; pub use error::Error; use error::Result; @@ -16,7 +17,7 @@ pub trait Injector { /// Inject a singular frame. fn inject_frame( &mut self, - frame: Frame, + frame: RpcFrame, ) -> impl Future>> + Send; /// Discard any uncommitted frames.
diff --git a/libsql-replication/src/injector/sqlite_injector/mod.rs b/libsql-replication/src/injector/sqlite_injector/mod.rs index 2f4193e469..f6ce2aa89f 100644 --- a/libsql-replication/src/injector/sqlite_injector/mod.rs +++ b/libsql-replication/src/injector/sqlite_injector/mod.rs @@ -7,6 +7,7 @@ use rusqlite::OpenFlags; use tokio::task::spawn_blocking; use crate::frame::{Frame, FrameNo}; +use crate::rpc::replication::Frame as RpcFrame; use self::injector_wal::{ InjectorWal, InjectorWalManager, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN, @@ -25,8 +26,10 @@ pub struct SqliteInjector { } impl Injector for SqliteInjector { - async fn inject_frame(&mut self, frame: Frame) -> Result> { + async fn inject_frame(&mut self, frame: RpcFrame) -> Result> { let inner = self.inner.clone(); + let frame = + Frame::try_from(&frame.data[..]).map_err(|e| Error::FatalInjectError(e.into()))?; spawn_blocking(move || inner.lock().inject_frame(frame)) .await .unwrap()
diff --git a/libsql-replication/src/replicator.rs b/libsql-replication/src/replicator.rs index ee75822676..38cdbf6e7c 100644 --- a/libsql-replication/src/replicator.rs +++ b/libsql-replication/src/replicator.rs @@ -63,7 +63,7 @@ impl From for Error { #[async_trait::async_trait] pub trait ReplicatorClient { - type FrameStream: Stream> + Unpin + Send; + type FrameStream: Stream> + Unpin + Send; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error>; @@ -318,7 +318,7 @@ where } } - async fn inject_frame(&mut self, frame: Frame) -> Result<(), Error> { + async fn inject_frame(&mut self, frame: RpcFrame) -> Result<(), Error> { self.frames_synced += 1; match self.injector.inject_frame(frame).await?
{ @@ -360,6 +360,7 @@ mod test { use async_stream::stream; use crate::frame::{FrameBorrowed, FrameMut}; + use crate::rpc::replication::Frame as RpcFrame; use super::*; @@ -370,7 +371,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -414,7 +416,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -456,7 +459,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -500,7 +504,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -544,7 +549,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -586,7 +592,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -627,7 +634,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -672,7 +680,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -740,7 +749,8 @@ mod test { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = + Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -752,15 +762,26 @@ mod test { let frames = self .frames .iter() + .map(|f| RpcFrame { + data: f.bytes(), + timestamp: None, + }) .take(2) - .cloned() .map(Ok) .chain(Some(Err(Error::Client("some client error".into())))) .collect::>(); Ok(Box::pin(tokio_stream::iter(frames))) } else { - let stream = tokio_stream::iter(self.frames.clone().into_iter().map(Ok)); - Ok(Box::pin(stream)) + let iter = self + .frames + .iter() + .map(|f| RpcFrame { + data: f.bytes(), + timestamp: None, + }) + .map(Ok) + .collect::>(); + Ok(Box::pin(tokio_stream::iter(iter))) } } /// Return a snapshot for the current replication index. 
Called after next_frame has returned a diff --git a/libsql-server/src/replication/replicator_client.rs b/libsql-server/src/replication/replicator_client.rs index d68c259dc9..89e465053b 100644 --- a/libsql-server/src/replication/replicator_client.rs +++ b/libsql-server/src/replication/replicator_client.rs @@ -4,16 +4,17 @@ use std::pin::Pin; use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::TryStreamExt; -use libsql_replication::frame::Frame; use libsql_replication::meta::WalIndexMeta; -use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient}; +use libsql_replication::replicator::{Error, ReplicatorClient}; use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use libsql_replication::rpc::replication::{ - verify_session_token, HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, SESSION_TOKEN_KEY, + verify_session_token, Frame as RpcFrame, HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, + SESSION_TOKEN_KEY, }; use tokio::sync::watch; -use tokio_stream::{Stream, StreamExt}; +use tokio_stream::Stream; + use tonic::metadata::{AsciiMetadataValue, BinaryMetadataValue}; use tonic::transport::Channel; use tonic::{Code, Request, Status}; @@ -95,7 +96,7 @@ impl Client { #[async_trait::async_trait] impl ReplicatorClient for Client { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = Pin> + Send + 'static>>; #[tracing::instrument(skip(self))] async fn handshake(&mut self) -> Result<(), Error> { @@ -169,7 +170,7 @@ impl ReplicatorClient for Client { None => REPLICATION_LATENCY_CACHE_MISS.increment(1), } }) - .map(map_frame_err); + .map_err(Into::into); Ok(Box::pin(stream)) } @@ -181,7 +182,7 @@ impl ReplicatorClient for Client { let req = self.make_request(offset); match self.client.snapshot(req).await { Ok(resp) => { - let stream = resp.into_inner().map(map_frame_err); + let stream = resp.into_inner().map_err(Into::into); Ok(Box::pin(stream)) } Err(e) if e.code() == Code::Unavailable => Err(Error::SnapshotPending), diff --git a/libsql/src/replication/local_client.rs b/libsql/src/replication/local_client.rs index 2d7b940c92..d3c713f530 100644 --- a/libsql/src/replication/local_client.rs +++ b/libsql/src/replication/local_client.rs @@ -3,6 +3,7 @@ use std::pin::Pin; use futures::{StreamExt, TryStreamExt}; use libsql_replication::{ + rpc::replication::Frame as RpcFrame, frame::{Frame, FrameNo}, meta::WalIndexMeta, replicator::{Error, ReplicatorClient}, @@ -35,7 +36,7 @@ impl LocalClient { #[async_trait::async_trait] impl ReplicatorClient for LocalClient { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { @@ -46,7 +47,7 @@ impl ReplicatorClient for LocalClient { async fn next_frames(&mut self) -> Result { match self.frames.take() { Some(Frames::Vec(f)) => { - let iter = f.into_iter().map(Ok); + let iter = f.into_iter().map(|f| RpcFrame { data: f.bytes(), timestamp: None }).map(Ok); Ok(Box::pin(tokio_stream::iter(iter))) } Some(f @ Frames::Snapshot(_)) => { @@ -70,7 +71,8 @@ impl ReplicatorClient for LocalClient { if s.as_mut().peek().await.is_none() { next.header_mut().size_after = size_after.into(); } - yield Frame::from(next); + let frame = Frame::from(next); + yield RpcFrame { data: frame.bytes(), timestamp: None }; } }; @@ -95,8 +97,9 @@ impl ReplicatorClient for LocalClient { #[cfg(test)] mod test { - use 
libsql_replication::snapshot::SnapshotFile; + use libsql_replication::{frame::FrameHeader, snapshot::SnapshotFile}; use tempfile::tempdir; + use zerocopy::FromBytes; use super::*; @@ -111,7 +114,8 @@ mod test { let mut s = client.snapshot().await.unwrap(); assert!(matches!(s.next().await, Some(Ok(_)))); let last = s.next().await.unwrap().unwrap(); - assert_eq!(last.header().size_after.get(), 2); + let header: FrameHeader = FrameHeader::read_from_prefix(&last.data[..]).unwrap(); + assert_eq!(header.size_after.get(), 2); assert!(s.next().await.is_none()); } } diff --git a/libsql/src/replication/remote_client.rs b/libsql/src/replication/remote_client.rs index 79cffb1c38..26e537d18a 100644 --- a/libsql/src/replication/remote_client.rs +++ b/libsql/src/replication/remote_client.rs @@ -4,13 +4,13 @@ use std::pin::Pin; use std::time::{Duration, Instant}; use bytes::Bytes; -use futures::StreamExt as _; -use libsql_replication::frame::{Frame, FrameHeader, FrameNo}; +use futures::{StreamExt as _, TryStreamExt}; +use libsql_replication::frame::{FrameHeader, FrameNo}; use libsql_replication::meta::WalIndexMeta; -use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient}; +use libsql_replication::replicator::{Error, ReplicatorClient}; use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::{ - verify_session_token, Frames, HelloRequest, HelloResponse, LogOffset, SESSION_TOKEN_KEY, + Frame as RpcFrame, verify_session_token, Frames, HelloRequest, HelloResponse, LogOffset, SESSION_TOKEN_KEY, }; use tokio_stream::Stream; use tonic::metadata::AsciiMetadataValue; @@ -161,7 +161,7 @@ impl RemoteClient { let frames_iter = frames .into_iter() - .map(|f| Frame::try_from(&*f.data).map_err(|e| Error::Client(e.into()))); + .map(Ok); let stream = tokio_stream::iter(frames_iter); @@ -197,7 +197,7 @@ impl RemoteClient { .snapshot(req) .await? 
.into_inner() - .map(map_frame_err) + .map_err(|e| e.into()) .peekable(); { @@ -205,7 +205,8 @@ impl RemoteClient { // the first frame is the one with the highest frame_no in the snapshot if let Some(Ok(f)) = frames.peek().await { - self.last_received = Some(f.header().frame_no.get()); + let header: FrameHeader = FrameHeader::read_from_prefix(&f.data[..]).unwrap(); + self.last_received = Some(header.frame_no.get()); } } @@ -240,7 +241,7 @@ fn maybe_log( #[async_trait::async_trait] impl ReplicatorClient for RemoteClient { - type FrameStream = Pin> + Send + 'static>>; + type FrameStream = Pin> + Send + 'static>>; /// Perform handshake with remote async fn handshake(&mut self) -> Result<(), Error> { From e97026c82d4bfc13f2298d98c705a19cea6159be Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 8 Aug 2024 16:33:10 +0200 Subject: [PATCH 58/70] feature gate libsql injector --- libsql-replication/Cargo.toml | 3 ++- libsql-replication/src/injector/mod.rs | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/libsql-replication/Cargo.toml b/libsql-replication/Cargo.toml index 2a03d362bf..068e23a652 100644 --- a/libsql-replication/Cargo.toml +++ b/libsql-replication/Cargo.toml @@ -12,7 +12,7 @@ license = "MIT" tonic = { version = "0.11", features = ["tls"] } prost = "0.12" libsql-sys = { version = "0.7", path = "../libsql-sys", default-features = false, features = ["wal", "rusqlite", "api"] } -libsql-wal = { path = "../libsql-wal/" } +libsql-wal = { path = "../libsql-wal/", optional = true } rusqlite = { workspace = true } parking_lot = "0.12.1" bytes = { version = "1.5.0", features = ["serde"] } @@ -38,3 +38,4 @@ tonic-build = "0.11" [features] encryption = ["libsql-sys/encryption"] +libsql_wal = ["dep:libsql-wal"] diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index 3712458d2f..b139f07cc9 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -1,6 +1,7 @@ use std::future::Future; use super::rpc::replication::Frame as RpcFrame; +#[cfg(feature = "libsql_wal")] pub use libsql_injector::LibsqlInjector; pub use sqlite_injector::SqliteInjector; @@ -10,6 +11,7 @@ pub use error::Error; use error::Result; mod error; +#[cfg(feature = "libsql_wal")] mod libsql_injector; mod sqlite_injector; From d29ca7fabed4440e0eb89eab2b10765b85a38e84 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 10 Aug 2024 11:16:24 +0200 Subject: [PATCH 59/70] fix conflicts --- .../proto/replication_log.proto | 10 +++++----- libsql-replication/src/generated/wal_log.rs | 20 +++++++++---------- libsql-replication/src/rpc.rs | 4 +--- .../src/namespace/configurator/replica.rs | 4 +++- .../src/replication/replicator_client.rs | 6 ++++-- libsql-server/src/rpc/replication_log.rs | 18 +++++++++++------ libsql/src/replication/remote_client.rs | 6 ++++-- 7 files changed, 39 insertions(+), 29 deletions(-) diff --git a/libsql-replication/proto/replication_log.proto b/libsql-replication/proto/replication_log.proto index b3be419319..b358232705 100644 --- a/libsql-replication/proto/replication_log.proto +++ b/libsql-replication/proto/replication_log.proto @@ -5,18 +5,18 @@ import "metadata.proto"; message LogOffset { uint64 next_offset = 1; -} - -message HelloRequest { - optional uint64 handshake_version = 1; enum WalFlavor { Sqlite = 0; Libsql = 1; } - // the type of wal that the client is expecting + // the type of wal frames that the client is expecting optional WalFlavor wal_flavor = 2; } +message HelloRequest { + optional uint64 
handshake_version = 1; +} + message HelloResponse { /// Uuid of the current generation string generation_id = 1; diff --git a/libsql-replication/src/generated/wal_log.rs b/libsql-replication/src/generated/wal_log.rs index 441881c4a7..a34d5e59dd 100644 --- a/libsql-replication/src/generated/wal_log.rs +++ b/libsql-replication/src/generated/wal_log.rs @@ -4,18 +4,12 @@ pub struct LogOffset { #[prost(uint64, tag = "1")] pub next_offset: u64, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct HelloRequest { - #[prost(uint64, optional, tag = "1")] - pub handshake_version: ::core::option::Option, - /// the type of wal that the client is expecting - #[prost(enumeration = "hello_request::WalFlavor", optional, tag = "2")] + /// the type of wal frames that the client is expecting + #[prost(enumeration = "log_offset::WalFlavor", optional, tag = "2")] pub wal_flavor: ::core::option::Option, } -/// Nested message and enum types in `HelloRequest`. -pub mod hello_request { +/// Nested message and enum types in `LogOffset`. +pub mod log_offset { #[derive( Clone, Copy, @@ -55,6 +49,12 @@ pub mod hello_request { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct HelloRequest { + #[prost(uint64, optional, tag = "1")] + pub handshake_version: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct HelloResponse { /// / Uuid of the current generation #[prost(string, tag = "1")] diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index 8e1165af65..a538bc4c28 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -26,7 +26,6 @@ pub mod replication { use uuid::Uuid; - use self::hello_request::WalFlavor; include!("generated/wal_log.rs"); pub const NO_HELLO_ERROR_MSG: &str = "NO_HELLO"; @@ -48,10 +47,9 @@ pub mod replication { } impl HelloRequest { - pub fn new(wal_flavor: WalFlavor) -> Self { + pub fn new() -> Self { Self { handshake_version: Some(1), - wal_flavor: Some(wal_flavor.into()), } } } diff --git a/libsql-server/src/namespace/configurator/replica.rs b/libsql-server/src/namespace/configurator/replica.rs index 84ebadb897..7832d30ef8 100644 --- a/libsql-server/src/namespace/configurator/replica.rs +++ b/libsql-server/src/namespace/configurator/replica.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use futures::Future; use hyper::Uri; +use libsql_replication::rpc::replication::log_offset::WalFlavor; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use tokio::task::JoinSet; use tonic::transport::Channel; @@ -68,10 +69,11 @@ impl ConfigureNamespace for ReplicaConfigurator { &db_path, meta_store_handle.clone(), store.clone(), + WalFlavor::Sqlite, ) .await?; let applied_frame_no_receiver = client.current_frame_no_notifier.subscribe(); - let mut replicator = libsql_replication::replicator::Replicator::new( + let mut replicator = libsql_replication::replicator::Replicator::new_sqlite( client, db_path.join("data"), DEFAULT_AUTO_CHECKPOINT, diff --git a/libsql-server/src/replication/replicator_client.rs b/libsql-server/src/replication/replicator_client.rs index 89e465053b..753baac996 100644 --- a/libsql-server/src/replication/replicator_client.rs +++ b/libsql-server/src/replication/replicator_client.rs @@ -6,7 +6,7 @@ use chrono::{DateTime, Utc}; use futures::TryStreamExt; use libsql_replication::meta::WalIndexMeta; use 
libsql_replication::replicator::{Error, ReplicatorClient}; -use libsql_replication::rpc::replication::hello_request::WalFlavor; +use libsql_replication::rpc::replication::log_offset::WalFlavor; use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient; use libsql_replication::rpc::replication::{ verify_session_token, Frame as RpcFrame, HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, @@ -101,7 +101,7 @@ impl ReplicatorClient for Client { #[tracing::instrument(skip(self))] async fn handshake(&mut self) -> Result<(), Error> { tracing::debug!("Attempting to perform handshake with primary."); - let req = self.make_request(HelloRequest::new(self.wal_flavor)); + let req = self.make_request(HelloRequest::new()); let resp = self.client.hello(req).await?; let hello = resp.into_inner(); verify_session_token(&hello.session_token).map_err(Error::Client)?; @@ -143,6 +143,7 @@ impl ReplicatorClient for Client { async fn next_frames(&mut self) -> Result { let offset = LogOffset { next_offset: self.next_frame_no(), + wal_flavor: Some(self.wal_flavor.into()), }; let req = self.make_request(offset); let stream = self @@ -178,6 +179,7 @@ impl ReplicatorClient for Client { async fn snapshot(&mut self) -> Result { let offset = LogOffset { next_offset: self.next_frame_no(), + wal_flavor: Some(self.wal_flavor.into()), }; let req = self.make_request(offset); match self.client.snapshot(req).await { diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index bf1840a5c6..1ef306daf1 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -8,7 +8,7 @@ use chrono::{DateTime, Utc}; use futures::stream::BoxStream; use futures_core::Future; pub use libsql_replication::rpc::replication as rpc; -use libsql_replication::rpc::replication::hello_request::WalFlavor; +use libsql_replication::rpc::replication::log_offset::WalFlavor; use libsql_replication::rpc::replication::replication_log_server::ReplicationLog; use libsql_replication::rpc::replication::{ Frame, Frames, HelloRequest, HelloResponse, LogOffset, NAMESPACE_DOESNT_EXIST, @@ -260,6 +260,9 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { + if let WalFlavor::Libsql = req.get_ref().wal_flavor() { + return Err(Status::invalid_argument("libsql wal not supported")); + } let namespace = super::extract_namespace(self.disable_namespaces, &req)?; self.authenticate(&req, namespace.clone()).await?; @@ -305,6 +308,9 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { + if let WalFlavor::Libsql = req.get_ref().wal_flavor() { + return Err(Status::invalid_argument("libsql wal not supported")); + } let namespace = super::extract_namespace(self.disable_namespaces, &req)?; self.authenticate(&req, namespace.clone()).await?; @@ -355,11 +361,6 @@ impl ReplicationLog for ReplicationLogService { guard.insert((replica_addr, namespace.clone())); } } - - if let WalFlavor::Libsql = req.get_ref().wal_flavor() { - return Err(Status::invalid_argument("libsql wal not supported")); - } - let (logger, config, version, _, _) = self.logger_from_namespace(namespace, &req, false).await?; @@ -381,7 +382,12 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { + if let WalFlavor::Libsql = req.get_ref().wal_flavor() { + return Err(Status::invalid_argument("libsql wal not supported")); + } + let namespace = 
super::extract_namespace(self.disable_namespaces, &req)?; + self.authenticate(&req, namespace.clone()).await?; let (logger, _, _, stats, _) = self.logger_from_namespace(namespace, &req, true).await?; diff --git a/libsql/src/replication/remote_client.rs b/libsql/src/replication/remote_client.rs index 26e537d18a..864392ddb5 100644 --- a/libsql/src/replication/remote_client.rs +++ b/libsql/src/replication/remote_client.rs @@ -8,7 +8,6 @@ use futures::{StreamExt as _, TryStreamExt}; use libsql_replication::frame::{FrameHeader, FrameNo}; use libsql_replication::meta::WalIndexMeta; use libsql_replication::replicator::{Error, ReplicatorClient}; -use libsql_replication::rpc::replication::hello_request::WalFlavor; use libsql_replication::rpc::replication::{ Frame as RpcFrame, verify_session_token, Frames, HelloRequest, HelloResponse, LogOffset, SESSION_TOKEN_KEY, }; @@ -117,9 +116,10 @@ impl RemoteClient { self.dirty = false; } let prefetch = self.session_token.is_some(); - let hello_req = self.make_request(HelloRequest::new(WalFlavor::Sqlite)); + let hello_req = self.make_request(HelloRequest::new()); let log_offset_req = self.make_request(LogOffset { next_offset: self.next_offset(), + wal_flavor: None, }); let mut client_clone = self.remote.clone(); let hello_fut = time(async { @@ -179,6 +179,7 @@ impl RemoteClient { None => { let req = self.make_request(LogOffset { next_offset: self.next_offset(), + wal_flavor: None, }); time(self.remote.replication.batch_log_entries(req)).await } @@ -190,6 +191,7 @@ impl RemoteClient { async fn do_snapshot(&mut self) -> Result<::FrameStream, Error> { let req = self.make_request(LogOffset { next_offset: self.next_offset(), + wal_flavor: None, }); let mut frames = self .remote From e7a24304d9378806561b9cfa2093c184a0d6fc4d Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 11 Aug 2024 12:53:11 +0400 Subject: [PATCH 60/70] fix potential memory leak --- libsql-sqlite3/src/vectordiskann.c | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index ae832f4400..caaeee64b2 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -954,6 +954,11 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ } static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + return SQLITE_NOMEM_BKPT; + } + loadVectorPair(&pCtx->query, pQuery); + pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; @@ -965,10 +970,6 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - return SQLITE_NOMEM_BKPT; - } - loadVectorPair(&pCtx->query, pQuery); if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ return SQLITE_OK; From 4c38e5f11e75c7debc8caf809daab28db0f5a1eb Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 11 Aug 2024 12:58:50 +0400 Subject: [PATCH 61/70] build bundles --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 9 +++++---- 
libsql-ffi/bundled/src/sqlite3.c | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 6c60bccaab..b68e262935 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -212717,6 +212717,11 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ } static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + return SQLITE_NOMEM_BKPT; + } + loadVectorPair(&pCtx->query, pQuery); + pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; @@ -212728,10 +212733,6 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - return SQLITE_NOMEM_BKPT; - } - loadVectorPair(&pCtx->query, pQuery); if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ return SQLITE_OK; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 6c60bccaab..b68e262935 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -212717,6 +212717,11 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){ } static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){ + if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ + return SQLITE_NOMEM_BKPT; + } + loadVectorPair(&pCtx->query, pQuery); + pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double)); pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*)); pCtx->nCandidates = 0; @@ -212728,10 +212733,6 @@ static int diskAnnSearchCtxInit(const DiskAnnIndex *pIndex, DiskAnnSearchCtx *pC pCtx->visitedList = NULL; pCtx->nUnvisited = 0; pCtx->blobMode = blobMode; - if( initVectorPair(pIndex->nNodeVectorType, pIndex->nEdgeVectorType, pIndex->nVectorDims, &pCtx->query) != 0 ){ - return SQLITE_NOMEM_BKPT; - } - loadVectorPair(&pCtx->query, pQuery); if( pCtx->aDistances != NULL && pCtx->aCandidates != NULL && pCtx->aTopDistances != NULL && pCtx->aTopCandidates != NULL ){ return SQLITE_OK; From 90007425ba039965f2cc7cfd7d01f25b35bc1869 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 11 Aug 2024 18:54:39 +0400 Subject: [PATCH 62/70] add simple pragma test --- libsql-sqlite3/test/libsql_vector_index.test | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 7819ba5076..281943e7c2 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -30,12 +30,17 @@ set testprefix vector sqlite3_db_config_lookaside db 0 0 0 -do_execsql_test vector-integrity { - CREATE TABLE t_integrity( v FLOAT32(3) ); - CREATE INDEX t_integrity_idx ON t_integrity( libsql_vector_idx(v) ); - 
INSERT INTO t_integrity VALUES (vector('[1,2,3]'));
+do_execsql_test vector-pragmas {
+  CREATE TABLE t_pragmas( v FLOAT32(3) );
+  CREATE INDEX t_pragmas_idx ON t_pragmas( libsql_vector_idx(v) );
+  INSERT INTO t_pragmas VALUES (vector('[1,2,3]'));
   PRAGMA integrity_check;
-} {{row 1 missing from index t_integrity_idx} {wrong # of entries in index t_integrity_idx}}
+  PRAGMA index_list='t_pragmas';
+} {
+  {row 1 missing from index t_pragmas_idx}
+  {wrong # of entries in index t_pragmas_idx}
+  0 t_pragmas_idx 0 c 0
+}
 
 do_execsql_test vector-typename {
   CREATE TABLE t_type_spaces( v FLOAT32 ( 3 ) );

From e8e5870dfa54fd6d64993928333dfa95ab604213 Mon Sep 17 00:00:00 2001
From: Nikita Sivukhin
Date: Sun, 11 Aug 2024 18:54:12 +0400
Subject: [PATCH 63/70] don't change idxType as sqlite relies on it heavily

---
 libsql-sqlite3/src/build.c     | 9 ++++-----
 libsql-sqlite3/src/parse.y     | 3 ---
 libsql-sqlite3/src/sqliteInt.h | 9 +++------
 3 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/libsql-sqlite3/src/build.c b/libsql-sqlite3/src/build.c
index 4396f2fae7..d6faa6b071 100644
--- a/libsql-sqlite3/src/build.c
+++ b/libsql-sqlite3/src/build.c
@@ -833,7 +833,7 @@ static void SQLITE_NOINLINE deleteTable(sqlite3 *db, Table *pTable){
   for(pIndex = pTable->pIndex; pIndex; pIndex=pNext){
     pNext = pIndex->pNext;
     assert( pIndex->pSchema==pTable->pSchema
-        || (IsVirtual(pTable) && !IsAppDefIndex(pIndex)) );
+        || (IsVirtual(pTable) && pIndex->idxType!=SQLITE_IDXTYPE_APPDEF) );
     if( db->pnBytesFreed==0 && !IsVirtual(pTable) ){
       char *zName = pIndex->zName;
       TESTONLY ( Index *pOld = ) sqlite3HashInsert(
@@ -4345,13 +4345,12 @@ void sqlite3CreateIndex(
     goto exit_create_index;
   }
   if( vectorIdxRc >= 1 ){
-    idxType = SQLITE_IDXTYPE_VECTOR;
     /*
      * SQLite can use B-Tree indices in some optimizations (like SELECT COUNT(*) can use any full B-Tree index instead of PK index)
      * But, SQLite pretty conservative about usage of unordered indices - that's what we need here
     */
     pIndex->bUnordered = 1;
-    pIndex->idxType = idxType;
+    pIndex->idxIsVector = 1;
   }
   if( vectorIdxRc == 1 ){
     skipRefill = 1;
@@ -4399,7 +4398,7 @@ void sqlite3CreateIndex(
   for(pIdx=pTab->pIndex; pIdx; pIdx=pIdx->pNext){
     int k;
     assert( IsUniqueIndex(pIdx) );
-    assert( !IsAppDefIndex(pIdx) );
+    assert( pIdx->idxType!=SQLITE_IDXTYPE_APPDEF );
     assert( IsUniqueIndex(pIndex) );
 
     if( pIdx->nKeyCol!=pIndex->nKeyCol ) continue;
@@ -4680,7 +4679,7 @@ void sqlite3DropIndex(Parse *pParse, SrcList *pName, int ifExists){
     pParse->checkSchema = 1;
     goto exit_drop_index;
   }
-  if( !IsAppDefIndex(pIndex) ){
+  if( pIndex->idxType!=SQLITE_IDXTYPE_APPDEF ){
     sqlite3ErrorMsg(pParse, "index associated with UNIQUE "
       "or PRIMARY KEY constraint cannot be dropped", 0);
     goto exit_drop_index;
diff --git a/libsql-sqlite3/src/parse.y b/libsql-sqlite3/src/parse.y
index 41e08ad6d6..f866ec5d2c 100644
--- a/libsql-sqlite3/src/parse.y
+++ b/libsql-sqlite3/src/parse.y
@@ -1451,9 +1451,6 @@ paren_exprlist(A) ::= LP exprlist(X) RP.  {A = X;}
 cmd ::= createkw(S) uniqueflag(U) INDEX ifnotexists(NE) nm(X) dbnm(D) indextype(T) ON nm(Y) LP sortlist(Z) RP where_opt(W).
{ u8 idxType = SQLITE_IDXTYPE_APPDEF; - if( T.pUsing!=0 ){ - idxType = SQLITE_IDXTYPE_VECTOR; - } sqlite3CreateIndex(pParse, &X, &D, sqlite3SrcListAppend(pParse,0,&Y,0), Z, U, &S, W, SQLITE_SO_ASC, NE, idxType, T.pUsing); diff --git a/libsql-sqlite3/src/sqliteInt.h b/libsql-sqlite3/src/sqliteInt.h index e2fd32d3c4..0a9dd98d66 100644 --- a/libsql-sqlite3/src/sqliteInt.h +++ b/libsql-sqlite3/src/sqliteInt.h @@ -2799,7 +2799,8 @@ struct Index { u16 nKeyCol; /* Number of columns forming the key */ u16 nColumn; /* Number of columns stored in the index */ u8 onError; /* OE_Abort, OE_Ignore, OE_Replace, or OE_None */ - unsigned idxType:3; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK, 4:VECTOR INDEX */ + unsigned idxType:2; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK */ + unsigned idxIsVector:1; /* 0:Normal 1:VECTOR INDEX */ unsigned bUnordered:1; /* Use this index for == or IN queries only */ unsigned uniqNotNull:1; /* True if UNIQUE and NOT NULL for all columns */ unsigned isResized:1; /* True if resizeIndexObject() has been called */ @@ -2831,7 +2832,6 @@ struct Index { #define SQLITE_IDXTYPE_UNIQUE 1 /* Implements a UNIQUE constraint */ #define SQLITE_IDXTYPE_PRIMARYKEY 2 /* Is the PRIMARY KEY for the table */ #define SQLITE_IDXTYPE_IPK 3 /* INTEGER PRIMARY KEY index */ -#define SQLITE_IDXTYPE_VECTOR 4 /* libSQL vector index */ /* Return true if index X is a PRIMARY KEY index */ #define IsPrimaryKeyIndex(X) ((X)->idxType==SQLITE_IDXTYPE_PRIMARYKEY) @@ -2840,10 +2840,7 @@ struct Index { #define IsUniqueIndex(X) ((X)->onError!=OE_None) /* Return true if index X is a vector index */ -#define IsVectorIndex(X) ((X)->idxType==SQLITE_IDXTYPE_VECTOR) - -/* Return true if index X is an user defined index (APPDEF or VECTOR) */ -#define IsAppDefIndex(X) ((X)->idxType==SQLITE_IDXTYPE_APPDEF||(X)->idxType==SQLITE_IDXTYPE_VECTOR) +#define IsVectorIndex(X) ((X)->idxIsVector==1) /* The Index.aiColumn[] values are normally positive integer. 
But ** there are some negative values that have special meaning: From b1dbaa13bb66bf1aac0f078489de811507a1f630 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Sun, 11 Aug 2024 19:22:44 +0400 Subject: [PATCH 64/70] build bundles --- .../SQLite3MultipleCiphers/src/sqlite3.c | 21 +++++++------------ libsql-ffi/bundled/src/sqlite3.c | 21 +++++++------------ 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 1b670120d5..8ceabfc713 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -19266,7 +19266,8 @@ struct Index { u16 nKeyCol; /* Number of columns forming the key */ u16 nColumn; /* Number of columns stored in the index */ u8 onError; /* OE_Abort, OE_Ignore, OE_Replace, or OE_None */ - unsigned idxType:3; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK, 4:VECTOR INDEX */ + unsigned idxType:2; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK */ + unsigned idxIsVector:1; /* 0:Normal 1:VECTOR INDEX */ unsigned bUnordered:1; /* Use this index for == or IN queries only */ unsigned uniqNotNull:1; /* True if UNIQUE and NOT NULL for all columns */ unsigned isResized:1; /* True if resizeIndexObject() has been called */ @@ -19298,7 +19299,6 @@ struct Index { #define SQLITE_IDXTYPE_UNIQUE 1 /* Implements a UNIQUE constraint */ #define SQLITE_IDXTYPE_PRIMARYKEY 2 /* Is the PRIMARY KEY for the table */ #define SQLITE_IDXTYPE_IPK 3 /* INTEGER PRIMARY KEY index */ -#define SQLITE_IDXTYPE_VECTOR 4 /* libSQL vector index */ /* Return true if index X is a PRIMARY KEY index */ #define IsPrimaryKeyIndex(X) ((X)->idxType==SQLITE_IDXTYPE_PRIMARYKEY) @@ -19307,10 +19307,7 @@ struct Index { #define IsUniqueIndex(X) ((X)->onError!=OE_None) /* Return true if index X is a vector index */ -#define IsVectorIndex(X) ((X)->idxType==SQLITE_IDXTYPE_VECTOR) - -/* Return true if index X is an user defined index (APPDEF or VECTOR) */ -#define IsAppDefIndex(X) ((X)->idxType==SQLITE_IDXTYPE_APPDEF||(X)->idxType==SQLITE_IDXTYPE_VECTOR) +#define IsVectorIndex(X) ((X)->idxIsVector==1) /* The Index.aiColumn[] values are normally positive integer. 
But ** there are some negative values that have special meaning: @@ -123188,7 +123185,7 @@ static void SQLITE_NOINLINE deleteTable(sqlite3 *db, Table *pTable){ for(pIndex = pTable->pIndex; pIndex; pIndex=pNext){ pNext = pIndex->pNext; assert( pIndex->pSchema==pTable->pSchema - || (IsVirtual(pTable) && !IsAppDefIndex(pIndex)) ); + || (IsVirtual(pTable) && pIndex->idxType!=SQLITE_IDXTYPE_APPDEF) ); if( db->pnBytesFreed==0 && !IsVirtual(pTable) ){ char *zName = pIndex->zName; TESTONLY ( Index *pOld = ) sqlite3HashInsert( @@ -126700,13 +126697,12 @@ SQLITE_PRIVATE void sqlite3CreateIndex( goto exit_create_index; } if( vectorIdxRc >= 1 ){ - idxType = SQLITE_IDXTYPE_VECTOR; /* * SQLite can use B-Tree indices in some optimizations (like SELECT COUNT(*) can use any full B-Tree index instead of PK index) * But, SQLite pretty conservative about usage of unordered indices - that's what we need here */ pIndex->bUnordered = 1; - pIndex->idxType = idxType; + pIndex->idxIsVector = 1; } if( vectorIdxRc == 1 ){ skipRefill = 1; @@ -126754,7 +126750,7 @@ SQLITE_PRIVATE void sqlite3CreateIndex( for(pIdx=pTab->pIndex; pIdx; pIdx=pIdx->pNext){ int k; assert( IsUniqueIndex(pIdx) ); - assert( !IsAppDefIndex(pIdx) ); + assert( pIdx->idxType!=SQLITE_IDXTYPE_APPDEF ); assert( IsUniqueIndex(pIndex) ); if( pIdx->nKeyCol!=pIndex->nKeyCol ) continue; @@ -127035,7 +127031,7 @@ SQLITE_PRIVATE void sqlite3DropIndex(Parse *pParse, SrcList *pName, int ifExists pParse->checkSchema = 1; goto exit_drop_index; } - if( !IsAppDefIndex(pIndex) ){ + if( pIndex->idxType!=SQLITE_IDXTYPE_APPDEF ){ sqlite3ErrorMsg(pParse, "index associated with UNIQUE " "or PRIMARY KEY constraint cannot be dropped", 0); goto exit_drop_index; @@ -177910,9 +177906,6 @@ static YYACTIONTYPE yy_reduce( case 242: /* cmd ::= createkw uniqueflag INDEX ifnotexists nm dbnm indextype ON nm LP sortlist RP where_opt */ { u8 idxType = SQLITE_IDXTYPE_APPDEF; - if( yymsp[-6].minor.yy421.pUsing!=0 ){ - idxType = SQLITE_IDXTYPE_VECTOR; - } sqlite3CreateIndex(pParse, &yymsp[-8].minor.yy0, &yymsp[-7].minor.yy0, sqlite3SrcListAppend(pParse,0,&yymsp[-4].minor.yy0,0), yymsp[-2].minor.yy402, yymsp[-11].minor.yy502, &yymsp[-12].minor.yy0, yymsp[0].minor.yy590, SQLITE_SO_ASC, yymsp[-9].minor.yy502, idxType, yymsp[-6].minor.yy421.pUsing); diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 1b670120d5..8ceabfc713 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -19266,7 +19266,8 @@ struct Index { u16 nKeyCol; /* Number of columns forming the key */ u16 nColumn; /* Number of columns stored in the index */ u8 onError; /* OE_Abort, OE_Ignore, OE_Replace, or OE_None */ - unsigned idxType:3; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK, 4:VECTOR INDEX */ + unsigned idxType:2; /* 0:Normal 1:UNIQUE, 2:PRIMARY KEY, 3:IPK */ + unsigned idxIsVector:1; /* 0:Normal 1:VECTOR INDEX */ unsigned bUnordered:1; /* Use this index for == or IN queries only */ unsigned uniqNotNull:1; /* True if UNIQUE and NOT NULL for all columns */ unsigned isResized:1; /* True if resizeIndexObject() has been called */ @@ -19298,7 +19299,6 @@ struct Index { #define SQLITE_IDXTYPE_UNIQUE 1 /* Implements a UNIQUE constraint */ #define SQLITE_IDXTYPE_PRIMARYKEY 2 /* Is the PRIMARY KEY for the table */ #define SQLITE_IDXTYPE_IPK 3 /* INTEGER PRIMARY KEY index */ -#define SQLITE_IDXTYPE_VECTOR 4 /* libSQL vector index */ /* Return true if index X is a PRIMARY KEY index */ #define IsPrimaryKeyIndex(X) ((X)->idxType==SQLITE_IDXTYPE_PRIMARYKEY) @@ 
-19307,10 +19307,7 @@ struct Index { #define IsUniqueIndex(X) ((X)->onError!=OE_None) /* Return true if index X is a vector index */ -#define IsVectorIndex(X) ((X)->idxType==SQLITE_IDXTYPE_VECTOR) - -/* Return true if index X is an user defined index (APPDEF or VECTOR) */ -#define IsAppDefIndex(X) ((X)->idxType==SQLITE_IDXTYPE_APPDEF||(X)->idxType==SQLITE_IDXTYPE_VECTOR) +#define IsVectorIndex(X) ((X)->idxIsVector==1) /* The Index.aiColumn[] values are normally positive integer. But ** there are some negative values that have special meaning: @@ -123188,7 +123185,7 @@ static void SQLITE_NOINLINE deleteTable(sqlite3 *db, Table *pTable){ for(pIndex = pTable->pIndex; pIndex; pIndex=pNext){ pNext = pIndex->pNext; assert( pIndex->pSchema==pTable->pSchema - || (IsVirtual(pTable) && !IsAppDefIndex(pIndex)) ); + || (IsVirtual(pTable) && pIndex->idxType!=SQLITE_IDXTYPE_APPDEF) ); if( db->pnBytesFreed==0 && !IsVirtual(pTable) ){ char *zName = pIndex->zName; TESTONLY ( Index *pOld = ) sqlite3HashInsert( @@ -126700,13 +126697,12 @@ SQLITE_PRIVATE void sqlite3CreateIndex( goto exit_create_index; } if( vectorIdxRc >= 1 ){ - idxType = SQLITE_IDXTYPE_VECTOR; /* * SQLite can use B-Tree indices in some optimizations (like SELECT COUNT(*) can use any full B-Tree index instead of PK index) * But, SQLite pretty conservative about usage of unordered indices - that's what we need here */ pIndex->bUnordered = 1; - pIndex->idxType = idxType; + pIndex->idxIsVector = 1; } if( vectorIdxRc == 1 ){ skipRefill = 1; @@ -126754,7 +126750,7 @@ SQLITE_PRIVATE void sqlite3CreateIndex( for(pIdx=pTab->pIndex; pIdx; pIdx=pIdx->pNext){ int k; assert( IsUniqueIndex(pIdx) ); - assert( !IsAppDefIndex(pIdx) ); + assert( pIdx->idxType!=SQLITE_IDXTYPE_APPDEF ); assert( IsUniqueIndex(pIndex) ); if( pIdx->nKeyCol!=pIndex->nKeyCol ) continue; @@ -127035,7 +127031,7 @@ SQLITE_PRIVATE void sqlite3DropIndex(Parse *pParse, SrcList *pName, int ifExists pParse->checkSchema = 1; goto exit_drop_index; } - if( !IsAppDefIndex(pIndex) ){ + if( pIndex->idxType!=SQLITE_IDXTYPE_APPDEF ){ sqlite3ErrorMsg(pParse, "index associated with UNIQUE " "or PRIMARY KEY constraint cannot be dropped", 0); goto exit_drop_index; @@ -177910,9 +177906,6 @@ static YYACTIONTYPE yy_reduce( case 242: /* cmd ::= createkw uniqueflag INDEX ifnotexists nm dbnm indextype ON nm LP sortlist RP where_opt */ { u8 idxType = SQLITE_IDXTYPE_APPDEF; - if( yymsp[-6].minor.yy421.pUsing!=0 ){ - idxType = SQLITE_IDXTYPE_VECTOR; - } sqlite3CreateIndex(pParse, &yymsp[-8].minor.yy0, &yymsp[-7].minor.yy0, sqlite3SrcListAppend(pParse,0,&yymsp[-4].minor.yy0,0), yymsp[-2].minor.yy402, yymsp[-11].minor.yy502, &yymsp[-12].minor.yy0, yymsp[0].minor.yy590, SQLITE_SO_ASC, yymsp[-9].minor.yy502, idxType, yymsp[-6].minor.yy421.pUsing); From cea87253fbfdd5e6cdf5cb245384b5ebbb4e1044 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 11:43:51 +0400 Subject: [PATCH 65/70] fix DELETE from vector index as there can be no row due to the NULL value of the vector --- libsql-sqlite3/src/vectordiskann.c | 6 +++++- libsql-sqlite3/test/libsql_vector_index.test | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 1f9973d38c..95276f8887 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -1633,7 +1633,11 @@ int diskAnnDelete( DiskAnnTrace(("diskAnnDelete started: rowid=%lld\n", nodeRowid)); rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, 
pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); - if( rc != SQLITE_OK ){ + if( rc == DISKANN_ROW_NOT_FOUND ){ + // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + rc = SQLITE_OK; + goto out; + }else if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(delete): failed to create blob for node row"); goto out; } diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index 281943e7c2..dd1f53d0df 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -135,9 +135,19 @@ do_execsql_test vector-null { CREATE INDEX t_null_idx ON t_null( libsql_vector_idx(v) ); INSERT INTO t_null VALUES(vector('[1,2,3]')); INSERT INTO t_null VALUES(NULL); - INSERT INTO t_null VALUES(vector('[2,3,4]')); + INSERT INTO t_null VALUES(vector('[3,4,5]')); SELECT * FROM vector_top_k('t_null_idx', '[1,2,3]', 2); -} {1 3} + UPDATE t_null SET v = vector('[2,3,4]') WHERE rowid = 2; + SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + UPDATE t_null SET v = NULL WHERE rowid = 3; + SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + UPDATE t_null SET v = NULL; + SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); +} { + 1 3 + 2 3 1 + 2 1 +} do_execsql_test vector-sql { CREATE TABLE t_sql( v FLOAT32(3)); From ac60c100d25a7e80ef8aa8a5df775c126217e0f8 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 11:46:05 +0400 Subject: [PATCH 66/70] build bundles --- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 6 +++++- libsql-ffi/bundled/src/sqlite3.c | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 8ceabfc713..887d3322b9 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -213389,7 +213389,11 @@ int diskAnnDelete( DiskAnnTrace(("diskAnnDelete started: rowid=%lld\n", nodeRowid)); rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); - if( rc != SQLITE_OK ){ + if( rc == DISKANN_ROW_NOT_FOUND ){ + // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + rc = SQLITE_OK; + goto out; + }else if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(delete): failed to create blob for node row"); goto out; } diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 8ceabfc713..887d3322b9 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -213389,7 +213389,11 @@ int diskAnnDelete( DiskAnnTrace(("diskAnnDelete started: rowid=%lld\n", nodeRowid)); rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); - if( rc != SQLITE_OK ){ + if( rc == DISKANN_ROW_NOT_FOUND ){ + // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + rc = SQLITE_OK; + goto out; + }else if( rc != SQLITE_OK ){ *pzErrMsg = sqlite3_mprintf("vector index(delete): failed to create blob for node row"); goto out; } From 405c71073636dea35feed1cf79d150ee6a835493 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 11:56:04 +0400 Subject: [PATCH 67/70] refine comment 
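
The comment refined here documents the invariant behind PATCH 65: rows
whose vector column is NULL are never inserted into the DiskANN index,
so a delete that finds no index entry must succeed silently. As a
minimal sketch (not part of this patch), the same control flow in Rust;
every name below (`blob_spot_create`, `DiskAnnError::RowNotFound`) is a
hypothetical stand-in for the C-side `blobSpotCreate` /
`DISKANN_ROW_NOT_FOUND` pair:

    // Hypothetical Rust stand-ins for the C-side types; illustrative only.
    struct DiskAnnIndex;
    struct BlobSpot;

    #[derive(Debug)]
    enum DiskAnnError {
        RowNotFound,   // plays the role of DISKANN_ROW_NOT_FOUND
        Other(String), // any other blob-open failure
    }

    // Stand-in for blobSpotCreate: odd rowids simulate rows whose vector was NULL.
    fn blob_spot_create(_idx: &DiskAnnIndex, rowid: i64) -> Result<BlobSpot, DiskAnnError> {
        if rowid % 2 == 0 { Ok(BlobSpot) } else { Err(DiskAnnError::RowNotFound) }
    }

    fn disk_ann_delete(idx: &DiskAnnIndex, rowid: i64) -> Result<(), DiskAnnError> {
        let _node = match blob_spot_create(idx, rowid) {
            // The row was never indexed (NULL vector): deleting is a silent no-op.
            Err(DiskAnnError::RowNotFound) => return Ok(()),
            Err(e) => return Err(e),
            Ok(node) => node,
        };
        // ...unlinking from neighbours and freeing the node would follow here...
        Ok(())
    }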
--- libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c | 3 ++- libsql-ffi/bundled/src/sqlite3.c | 3 ++- libsql-sqlite3/src/vectordiskann.c | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 887d3322b9..1fd0252f0d 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -213390,7 +213390,8 @@ int diskAnnDelete( rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); if( rc == DISKANN_ROW_NOT_FOUND ){ - // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + // as we omit rows with NULL values during insert, it can be the case that there is nothing to delete in the index, while row exists in the base table + // so, we must simply silently stop delete process as there is nothing to delete from index rc = SQLITE_OK; goto out; }else if( rc != SQLITE_OK ){ diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 887d3322b9..1fd0252f0d 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -213390,7 +213390,8 @@ int diskAnnDelete( rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); if( rc == DISKANN_ROW_NOT_FOUND ){ - // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + // as we omit rows with NULL values during insert, it can be the case that there is nothing to delete in the index, while row exists in the base table + // so, we must simply silently stop delete process as there is nothing to delete from index rc = SQLITE_OK; goto out; }else if( rc != SQLITE_OK ){ diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index 95276f8887..a6c279b259 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -1634,7 +1634,8 @@ int diskAnnDelete( rc = blobSpotCreate(pIndex, &pNodeBlob, nodeRowid, pIndex->nBlockSize, DISKANN_BLOB_WRITABLE); if( rc == DISKANN_ROW_NOT_FOUND ){ - // we omit rows with NULL values so it can be the case that there is nothing to delete in the index while row exists in the base table + // as we omit rows with NULL values during insert, it can be the case that there is nothing to delete in the index, while row exists in the base table + // so, we must simply silently stop delete process as there is nothing to delete from index rc = SQLITE_OK; goto out; }else if( rc != SQLITE_OK ){ From c135391f60d30c671d164119216ad8545d3768eb Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 12:16:52 +0400 Subject: [PATCH 68/70] fix unstable test --- libsql-sqlite3/test/libsql_vector_index.test | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index dd1f53d0df..c872a20ae0 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -131,21 +131,21 @@ do_execsql_test vector-empty { do_execsql_test vector-null { - CREATE TABLE t_null( v FLOAT32(3)); + CREATE TABLE t_null( v FLOAT32(2)); CREATE INDEX t_null_idx ON t_null( libsql_vector_idx(v) ); - INSERT INTO t_null VALUES(vector('[1,2,3]')); + INSERT INTO t_null 
VALUES(vector('[1,-1]')); INSERT INTO t_null VALUES(NULL); - INSERT INTO t_null VALUES(vector('[3,4,5]')); - SELECT * FROM vector_top_k('t_null_idx', '[1,2,3]', 2); - UPDATE t_null SET v = vector('[2,3,4]') WHERE rowid = 2; - SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + INSERT INTO t_null VALUES(vector('[-2,1]')); + SELECT * FROM vector_top_k('t_null_idx', '[1,1]', 2); + UPDATE t_null SET v = vector('[1,1]') WHERE rowid = 2; + SELECT rowid FROM vector_top_k('t_null_idx', vector('[1,1]'), 3); UPDATE t_null SET v = NULL WHERE rowid = 3; - SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + SELECT rowid FROM vector_top_k('t_null_idx', vector('[1,1]'), 3); UPDATE t_null SET v = NULL; - SELECT rowid FROM vector_top_k('t_null_idx', vector('[2,3,4]'), 3); + SELECT rowid FROM vector_top_k('t_null_idx', vector('[1,1]'), 3); } { 1 3 - 2 3 1 + 2 1 3 2 1 } From 2d325ba8f6c7d74cf3733e134ce952e30f7961c7 Mon Sep 17 00:00:00 2001 From: Nikita Sivukhin Date: Mon, 12 Aug 2024 18:46:14 +0400 Subject: [PATCH 69/70] windows compiler complains about operations with void* pointers - error C2036: 'void *': unknown size ... --- .../bundled/SQLite3MultipleCiphers/src/sqlite3.c | 16 ++++++++-------- libsql-ffi/bundled/src/sqlite3.c | 16 ++++++++-------- libsql-sqlite3/src/vectordiskann.c | 16 ++++++++-------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 1fd0252f0d..2dcd939f21 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -212669,7 +212669,7 @@ int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, fl return nSize < nMaxSize ? 
nSize : -1; } -void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { +void bufferInsert(u8 *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const u8 *pItem, u8 *pLast) { int itemsToMove; assert( nMaxSize > 0 && nItemSize > 0 ); @@ -212687,7 +212687,7 @@ void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItem memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); } -void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { +void bufferDelete(u8 *aBuffer, int nSize, int iDelete, int nItemSize) { int itemsToMove; assert( nItemSize > 0 ); @@ -212850,8 +212850,8 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo if( iInsert < 0 ){ return; } - bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); - bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pNode, NULL); + bufferInsert((u8*)pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } @@ -212872,8 +212872,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); - bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); - bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); + bufferDelete((u8*)pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete((u8*)pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); pCtx->nCandidates--; pCtx->nUnvisited--; @@ -212881,8 +212881,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ DiskAnnNode *pLast = NULL; - bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); - bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pCandidate, (u8*)&pLast); + bufferInsert((u8*)pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); if( pLast != NULL && !pLast->visited ){ // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 1fd0252f0d..2dcd939f21 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -212669,7 +212669,7 @@ int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, fl return nSize < nMaxSize ? 
nSize : -1; } -void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) { +void bufferInsert(u8 *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const u8 *pItem, u8 *pLast) { int itemsToMove; assert( nMaxSize > 0 && nItemSize > 0 ); @@ -212687,7 +212687,7 @@ void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItem memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize); } -void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) { +void bufferDelete(u8 *aBuffer, int nSize, int iDelete, int nItemSize) { int itemsToMove; assert( nItemSize > 0 ); @@ -212850,8 +212850,8 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo if( iInsert < 0 ){ return; } - bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL); - bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pNode, NULL); + bufferInsert((u8*)pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates); } @@ -212872,8 +212872,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL ); diskAnnNodeFree(pCtx->aCandidates[iDelete]); - bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); - bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); + bufferDelete((u8*)pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*)); + bufferDelete((u8*)pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float)); pCtx->nCandidates--; pCtx->nUnvisited--; @@ -212881,8 +212881,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete) static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){ DiskAnnNode *pLast = NULL; - bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast); - bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL); + bufferInsert((u8*)pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pCandidate, (u8*)&pLast); + bufferInsert((u8*)pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), (u8*)&distance, NULL); pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates); if( pLast != NULL && !pLast->visited ){ // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node diff --git a/libsql-sqlite3/src/vectordiskann.c b/libsql-sqlite3/src/vectordiskann.c index a6c279b259..7b29c6f50f 100644 --- a/libsql-sqlite3/src/vectordiskann.c +++ b/libsql-sqlite3/src/vectordiskann.c @@ -913,7 +913,7 @@ int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, fl return nSize < nMaxSize ? 
nSize : -1;
 }
 
-void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) {
+void bufferInsert(u8 *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const u8 *pItem, u8 *pLast) {
   int itemsToMove;
 
   assert( nMaxSize > 0 && nItemSize > 0 );
@@ -931,7 +931,7 @@ void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItem
   memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize);
 }
 
-void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) {
+void bufferDelete(u8 *aBuffer, int nSize, int iDelete, int nItemSize) {
   int itemsToMove;
 
   assert( nItemSize > 0 );
@@ -1094,8 +1094,8 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo
   if( iInsert < 0 ){
     return;
   }
-  bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL);
-  bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL);
+  bufferInsert((u8*)pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pNode, NULL);
+  bufferInsert((u8*)pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), (u8*)&distance, NULL);
   pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates);
 }
 
@@ -1116,8 +1116,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete)
   assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL );
   diskAnnNodeFree(pCtx->aCandidates[iDelete]);
 
-  bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*));
-  bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float));
+  bufferDelete((u8*)pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*));
+  bufferDelete((u8*)pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float));
   pCtx->nCandidates--;
   pCtx->nUnvisited--;
 
@@ -1125,8 +1125,8 @@ static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete)
 static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){
   DiskAnnNode *pLast = NULL;
 
-  bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast);
-  bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL);
+  bufferInsert((u8*)pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), (u8*)&pCandidate, (u8*)&pLast);
+  bufferInsert((u8*)pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), (u8*)&distance, NULL);
   pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates);
   if( pLast != NULL && !pLast->visited ){
     // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node

From 2f1d71d5670a19d8449fecec5a294706db1a59d6 Mon Sep 17 00:00:00 2001
From: Lucio Franco
Date: Mon, 12 Aug 2024 15:39:41 -0400
Subject: [PATCH 70/70] libsql: add `tls` feature

This adds a new `tls` feature that is enabled by default; if this
feature is disabled, building a libsql connection will panic with a
message asking you to configure an http connector.

This allows users to bring their own http connector and, more
importantly, their own TLS library at their own version, without
needing to compile rustls, which we use by default.
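As a rough usage sketch (not part of this diff), an application that
disables `tls` would wire in its own hyper connector roughly as below;
that `Builder::connector` is the hook for this is an assumption here,
so adjust to the actual builder API:

    use libsql::Builder;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Any hyper-compatible connector works; a plain HttpConnector stands in
        // here for whatever TLS stack the application already links against.
        let mut http = hyper::client::HttpConnector::new();
        http.enforce_http(false);

        // Assumed builder hook: pass the connector explicitly instead of relying
        // on the rustls-based default that the `tls` feature would otherwise build.
        let db = Builder::new_remote("http://127.0.0.1:8080".to_string(), "token".to_string())
            .connector(http)
            .build()
            .await?;
        let conn = db.connect()?;
        conn.execute("SELECT 1", ()).await?;
        Ok(())
    }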
This resolves build issues with solana-sdk >2, which uses an older
version of `curve25519-dalek` that pins `zeroize` to `<1.4`. New
versions of `rustls` require `zeroize` `1.7`, causing build failures
when `rustls` is compiled into libsql via the `tls` feature.
---
 Cargo.lock                    |  2 --
 libsql-replication/Cargo.toml |  2 +-
 libsql/Cargo.toml             | 13 ++++++++-----
 libsql/src/database.rs        | 13 ++++++++++++-
 libsql/src/lib.rs             | 24 ++++++++++++++++++++++++
 5 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 7e19e03e9a..cb39f33dda 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6251,7 +6251,6 @@ dependencies = [
  "percent-encoding",
  "pin-project",
  "prost",
- "rustls-native-certs 0.7.1",
  "rustls-pemfile 2.1.2",
  "rustls-pki-types",
  "tokio",
@@ -6261,7 +6260,6 @@ dependencies = [
  "tower-layer",
  "tower-service",
  "tracing",
- "webpki-roots 0.26.3",
 ]
 
 [[package]]
diff --git a/libsql-replication/Cargo.toml b/libsql-replication/Cargo.toml
index 068e23a652..27e3d797bb 100644
--- a/libsql-replication/Cargo.toml
+++ b/libsql-replication/Cargo.toml
@@ -9,7 +9,7 @@ license = "MIT"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-tonic = { version = "0.11", features = ["tls"] }
+tonic = { version = "0.11", default-features = false, features = ["codegen", "prost"] }
 prost = "0.12"
 libsql-sys = { version = "0.7", path = "../libsql-sys", default-features = false, features = ["wal", "rusqlite", "api"] }
 libsql-wal = { path = "../libsql-wal/", optional = true }
diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml
index 3d65f71c73..98112e4959 100644
--- a/libsql/Cargo.toml
+++ b/libsql/Cargo.toml
@@ -16,7 +16,7 @@ libsql-hrana = { version = "0.2", path = "../libsql-hrana", optional = true }
 tokio = { version = "1.29.1", features = ["sync"], optional = true }
 tokio-util = { version = "0.7", features = ["io-util", "codec"], optional = true }
 parking_lot = { version = "0.12.1", optional = true }
-hyper = { workspace = true, features = ["client", "stream"], optional = true }
+hyper = { version = "0.14", features = ["client", "http1", "http2", "stream", "runtime"], optional = true }
 hyper-rustls = { version = "0.25", features = ["webpki-roots"], optional = true }
 base64 = { version = "0.21", optional = true }
 serde = { version = "1", features = ["derive"], optional = true }
@@ -31,7 +31,7 @@ anyhow = { version = "1.0.71", optional = true }
 bytes = { version = "1.4.0", features = ["serde"], optional = true }
 uuid = { version = "1.4.0", features = ["v4", "serde"], optional = true }
 tokio-stream = { version = "0.1.14", optional = true }
-tonic = { version = "0.11", features = ["tls", "tls-roots", "tls-webpki-roots"], optional = true}
+tonic = { version = "0.11", optional = true}
 tonic-web = { version = "0.11", optional = true }
 tower-http = { version = "0.4.4", features = ["trace", "set-header", "util"], optional = true }
 http = { version = "0.2", optional = true }
@@ -53,7 +53,7 @@ tempfile = { version = "3.7.0" }
 rand = "0.8.5"
 
 [features]
-default = ["core", "replication", "remote"]
+default = ["core", "replication", "remote", "tls"]
 core = [
     "libsql-sys",
     "dep:bitflags",
@@ -88,7 +88,6 @@ replication = [
     "dep:tonic",
     "dep:tonic-web",
     "dep:tower-http",
-    "dep:hyper-rustls",
     "dep:futures",
     "dep:libsql_replication",
 ]
@@ -109,11 +108,11 @@ remote = [
     "hrana",
     "dep:tower",
     "dep:hyper",
+    "dep:hyper",
     "dep:http",
     "dep:tokio",
     "dep:futures",
     "dep:bitflags",
-    "dep:hyper-rustls",
 ]
 wasm = ["hrana"]
 cloudflare = [
@@ -121,6 +120,7 @@ cloudflare = [
"dep:worker" ] encryption = ["core", "libsql-sys/encryption", "dep:bytes"] +tls = ["dep:hyper-rustls"] [[bench]] name = "benchmark" @@ -128,3 +128,6 @@ harness = false [package.metadata.docs.rs] rustdoc-args = ["--cfg", "docsrs"] + +[package.metadata.cargo-udeps.ignore] +normal = ["hyper-rustls"] diff --git a/libsql/src/database.rs b/libsql/src/database.rs index e87def367d..67162f60b5 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -582,7 +582,10 @@ impl Database { } } -#[cfg(any(feature = "replication", feature = "remote"))] +#[cfg(any( + all(feature = "tls", feature = "replication"), + all(feature = "tls", feature = "remote") +))] fn connector() -> Result> { let mut http = hyper::client::HttpConnector::new(); http.enforce_http(false); @@ -596,6 +599,14 @@ fn connector() -> Result Result { + panic!("The `tls` feature is disabled, you must provide your own http connector"); +} + impl std::fmt::Debug for Database { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Database").finish() diff --git a/libsql/src/lib.rs b/libsql/src/lib.rs index b74a3b954f..ac9d500596 100644 --- a/libsql/src/lib.rs +++ b/libsql/src/lib.rs @@ -88,8 +88,32 @@ //! that will allow you to sync you remote database locally. //! - `remote` this feature flag only includes HTTP code that will allow you to run queries against //! a remote database. +//! - `tls` this feature flag disables the builtin TLS connector and instead requires that you pass +//! your own connector for any of the features that require HTTP. #![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr( + all( + any( + not(feature = "remote"), + not(feature = "replication"), + not(feature = "core") + ), + feature = "tls" + ), + allow(unused_imports) +)] +#![cfg_attr( + all( + any( + not(feature = "remote"), + not(feature = "replication"), + not(feature = "core") + ), + feature = "tls" + ), + allow(dead_code) +)] #[macro_use] mod macros;