diff --git a/Cargo.lock b/Cargo.lock index 04ae728065..19b0c3bd3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3449,6 +3449,7 @@ dependencies = [ "libsql-sqlite3-parser", "libsql-sys", "libsql_replication", + "libsql_sync", "parking_lot", "pprof", "rand", @@ -3581,6 +3582,7 @@ dependencies = [ "libsql-sys", "libsql-wal", "libsql_replication", + "libsql_sync", "md-5", "metrics", "metrics-exporter-prometheus", @@ -3795,6 +3797,17 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "libsql_sync" +version = "0.4.0" +dependencies = [ + "libsql-sys", + "prost", + "prost-build", + "tonic 0.11.0", + "tonic-build 0.11.0", +] + [[package]] name = "linked-hash-map" version = "0.5.6" diff --git a/Cargo.toml b/Cargo.toml index 92487ecdd0..83408d1882 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "libsql-wal", "libsql-storage", "libsql-storage-server", + "libsql-sync", ] exclude = [ diff --git a/docs/WAL_SYNC.md b/docs/WAL_SYNC.md new file mode 100644 index 0000000000..a24637a4b2 --- /dev/null +++ b/docs/WAL_SYNC.md @@ -0,0 +1,79 @@ +# libSQL Sync Protocol Specification + +## Overview + +This is a protocol for supporting offline writes by allowing a database instance to sync its write-ahead log between clients and a remote server. + +## Operations + +### PushWAL + +Push the local WAL to a remote server. + +**Request:** + +- `database_id`: The ID of the database. +- `checkpoint_seq_num`: The current checkpoint sequence number. +- `frame_num_start`: The number of the first frame to push. +- `frame_num_end`: The number of the last frame to push. +- `frames`: The WAL frames to push. + +**Response:** + +- `status`: SUCCESS, CONFLICT, ERROR, or NEED_FULL_SYNC +- `durable_frame_num`: The highest frame number the server acknowledges as durable. + +A client uses the `PushWAL` operation to push its local WAL to the remote server. The operation is idempotent on frames, which means it is safe for the client to send the same frames multiple times. 
If the server already has them, it ignores them. As an optimization, the client can keep track of the durable (checkpoint sequence number, frame number) tuple acknowledged by the remote server to avoid sending duplicate frames. + +**TODO:** + +- Return remote WAL on conflict if client requests it. +- Allow client to request server to perform checkpointing. +- Checksum support in the WAL frames. + +### PullWAL + +Retrieve new WAL frames from the remote server. + +**Request**: + +- `database_id`: The ID of the database. +- `checkpoint_seq_num`: The current checkpoint sequence number. +- `max_frame_num`: The highest frame number in the local WAL. + +**Response**: +- `status`: SUCCESS, CONFLICT, ERROR, or NEED_FULL_SYNC +- `frames`: List of new WAL frames + +### FetchDatabase + +Retrieve the full database file from the server. + +**Request**: + +- `database_id`: The ID of the database. + +**Response**: + +- Stream of database chunks + +A client uses the `FetchDatabase` operation to bootstrap a database file locally and also for disaster recovery. + +## Checkpointing Process + +1. Client may request a checkpoint during PushWAL. +2. Server decides whether to initiate a checkpoint based on its state and the client's request. +3. If checkpoint is needed, server sets `perform_checkpoint` to true in the PushWAL response. +4. Client performs local checkpoint up to `checkpoint_frame_id` if instructed. +5. Server performs its own checkpoint after sending the response. + +## Conflict Resolution + +- The server returns a `CONFLICT` status if the WAL on the remote is more up-to-date than the client's. +- The server sends its current WAL in the response for the client to merge and retry the push. + +## Bootstrapping + +1. New clients start by calling `FetchDatabase` to get the full database file. +2. Follow up with PullWAL to get any new changes since the database file was generated. +3. Apply received WAL frames to the database file to reach the current state. 
diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index c22f35046f..a282adb4b1 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -10935,6 +10935,10 @@ SQLITE_API int sqlite3_preupdate_blobwrite(sqlite3 *); */ SQLITE_API void *libsql_close_hook(sqlite3 *db, void (*xClose)(void *pCtx, sqlite3 *db), void *arg); +SQLITE_API int libsql_wal_frame_count(sqlite3*, unsigned int*); + +SQLITE_API int libsql_wal_get_frame(sqlite3*, unsigned int, void*, unsigned int); + /* ** CAPI3REF: Low-level system error code ** METHOD: sqlite3 @@ -13960,6 +13964,7 @@ typedef struct libsql_wal_methods { /* Read a page from the write-ahead log, if it is present. */ int (*xFindFrame)(wal_impl* pWal, unsigned int, unsigned int *); int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + int (*xReadFrameRaw)(wal_impl* pWal, unsigned int, int, unsigned char *); /* If the WAL is not empty, return the size of the database. */ unsigned int (*xDbsize)(wal_impl* pWal); @@ -16373,6 +16378,9 @@ SQLITE_PRIVATE int sqlite3PagerReadFileheader(Pager*, int, unsigned char*); SQLITE_PRIVATE void sqlite3PagerSetBusyHandler(Pager*, int(*)(void *), void *); SQLITE_PRIVATE int sqlite3PagerSetPagesize(Pager*, u32*, int); SQLITE_PRIVATE Pgno sqlite3PagerMaxPageCount(Pager*, Pgno); +SQLITE_PRIVATE unsigned int sqlite3PagerWalFrameCount(Pager *); +SQLITE_PRIVATE int sqlite3PagerWalReadFrame(Pager *, unsigned int, void *, unsigned int); + SQLITE_PRIVATE void sqlite3PagerSetCachesize(Pager*, int); SQLITE_PRIVATE int sqlite3PagerSetSpillsize(Pager*, int); SQLITE_PRIVATE void sqlite3PagerSetMmapLimit(Pager *, sqlite3_int64); @@ -57268,6 +57276,7 @@ typedef struct libsql_wal_methods { /* Read a page from the write-ahead log, if it is present. 
*/ int (*xFindFrame)(wal_impl* pWal, unsigned int, unsigned int *); int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + int (*xReadFrameRaw)(wal_impl* pWal, unsigned int, int, unsigned char *); /* If the WAL is not empty, return the size of the database. */ unsigned int (*xDbsize)(wal_impl* pWal); @@ -65212,6 +65221,33 @@ SQLITE_PRIVATE int sqlite3PagerCloseWal(Pager *pPager, sqlite3 *db){ return rc; } +SQLITE_PRIVATE unsigned int sqlite3PagerWalFrameCount(Pager *pPager){ + if( pagerUseWal(pPager) ){ + // TODO: We are under sqlite3 mutex, but do we need something else? + struct sqlite3_wal* pWal = (void*) pPager->wal->pData; + return pWal->hdr.mxFrame; + }else{ + return 0; + } +} + +SQLITE_PRIVATE int sqlite3PagerWalReadFrameRaw( + Pager *pPager, + unsigned int iFrame, + void *pFrameOut, + unsigned int nFrameOutLen +){ + if( pagerUseWal(pPager) ){ + unsigned int nFrameLen = 24+pPager->pageSize; + if( nFrameOutLen!=nFrameLen ) return SQLITE_MISUSE; + return pPager->wal->methods.xReadFrameRaw(pPager->wal->pData, iFrame, pPager->pageSize, pFrameOut); + }else{ + return SQLITE_ERROR; + } +} + + int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + #ifdef SQLITE_ENABLE_SETLK_TIMEOUT /* ** If pager pPager is a wal-mode database not in exclusive locking mode, @@ -67599,9 +67635,10 @@ static int sqlite3WalClose( if( pWal->exclusiveMode==WAL_NORMAL_MODE ){ pWal->exclusiveMode = WAL_EXCLUSIVE_MODE; } - rc = sqlite3WalCheckpoint(pWal, db, - SQLITE_CHECKPOINT_PASSIVE, 0, 0, sync_flags, nBuf, zBuf, 0, 0, NULL, NULL - ); + rc = SQLITE_ERROR; + //rc = sqlite3WalCheckpoint(pWal, db, + // SQLITE_CHECKPOINT_PASSIVE, 0, 0, sync_flags, nBuf, zBuf, 0, 0, NULL, NULL + //); if( rc==SQLITE_OK ){ int bPersist = -1; sqlite3OsFileControlHint( @@ -68729,6 +68766,28 @@ static int sqlite3WalReadFrame( return sqlite3OsRead(pWal->pWalFd, pOut, (nOut>sz ? 
sz : nOut), iOffset); } +/* +** Read the contents of frame iRead from the wal file into buffer pOut +** (which is nOut bytes in size). Return SQLITE_OK if successful, or an +** error code otherwise. +*/ +static int sqlite3WalReadFrameRaw( + Wal *pWal, /* WAL handle */ + u32 iRead, /* Frame to read */ + int nOut, /* Size of buffer pOut in bytes */ + u8 *pOut /* Buffer to write page data to */ +){ + int sz; + i64 iOffset; + sz = pWal->hdr.szPage; + sz = (sz&0xfe00) + ((sz&0x0001)<<16); + testcase( sz<=32768 ); + testcase( sz>=65536 ); + iOffset = walFrameOffset(iRead, sz); + /* testcase( IS_BIG_INT(iOffset) ); // requires a 4GiB WAL */ + return sqlite3OsRead(pWal->pWalFd, pOut, (nOut>sz ? sz : nOut), iOffset); +} + /* ** Return the size of the database in pages (or zero, if unknown). */ @@ -69838,6 +69897,7 @@ static int sqlite3WalOpen( out->methods.xEndReadTransaction = (void (*)(wal_impl *))sqlite3WalEndReadTransaction; out->methods.xFindFrame = (int (*)(wal_impl *, unsigned int, unsigned int *))sqlite3WalFindFrame; out->methods.xReadFrame = (int (*)(wal_impl *, unsigned int, int, unsigned char *))sqlite3WalReadFrame; + out->methods.xReadFrameRaw = (int (*)(wal_impl *, unsigned int, int, unsigned char *))sqlite3WalReadFrameRaw; out->methods.xDbsize = (unsigned int (*)(wal_impl *))sqlite3WalDbsize; out->methods.xBeginWriteTransaction = (int (*)(wal_impl *))sqlite3WalBeginWriteTransaction; out->methods.xEndWriteTransaction = (int (*)(wal_impl *))sqlite3WalEndWriteTransaction; @@ -182863,6 +182923,62 @@ void *libsql_close_hook( return pRet; } +/* +** Return the number of frames in the WAL of the given database. 
+*/ +int libsql_wal_frame_count( + sqlite3* db, + unsigned int *pnFrame +){ + int rc = SQLITE_OK; + Pager *pPager; + +#ifdef SQLITE_OMIT_WAL + *pnFrame = 0; + return SQLITE_OK; +#else +#ifdef SQLITE_ENABLE_API_ARMOR + if( !sqlite3SafetyCheckOk(db) ) return SQLITE_MISUSE_BKPT; +#endif + + sqlite3_mutex_enter(db->mutex); + pPager = sqlite3BtreePager(db->aDb[0].pBt); + *pnFrame = sqlite3PagerWalFrameCount(pPager); + sqlite3_mutex_leave(db->mutex); + + return rc; +#endif +} + +int libsql_wal_get_frame( + sqlite3* db, + unsigned int iFrame, + void *pBuf, + unsigned int nBuf +){ + int rc = SQLITE_OK; + Pager *pPager; + +#ifdef SQLITE_OMIT_WAL + UNUSED_PARAMETER(iFrame); + UNUSED_PARAMETER(nBuf); + UNUSED_PARAMETER(pBuf); + return SQLITE_OK; +#else + +#ifdef SQLITE_ENABLE_API_ARMOR + if( !sqlite3SafetyCheckOk(db) ) return SQLITE_MISUSE_BKPT; +#endif + + sqlite3_mutex_enter(db->mutex); + pPager = sqlite3BtreePager(db->aDb[0].pBt); + rc = sqlite3PagerWalReadFrameRaw(pPager, iFrame, pBuf, nBuf); + sqlite3_mutex_leave(db->mutex); + + return rc; +#endif +} + /* ** Register a function to be invoked prior to each autovacuum that ** determines the number of pages to vacuum. 
diff --git a/libsql-ffi/bundled/bindings/bindgen.rs b/libsql-ffi/bundled/bindings/bindgen.rs index e11d453281..3e8318f9be 100644 --- a/libsql-ffi/bundled/bindings/bindgen.rs +++ b/libsql-ffi/bundled/bindings/bindgen.rs @@ -938,7 +938,7 @@ extern "C" { extern "C" { pub fn sqlite3_vmprintf( arg1: *const ::std::os::raw::c_char, - arg2: *mut __va_list_tag, + arg2: va_list, ) -> *mut ::std::os::raw::c_char; } extern "C" { @@ -954,7 +954,7 @@ extern "C" { arg1: ::std::os::raw::c_int, arg2: *mut ::std::os::raw::c_char, arg3: *const ::std::os::raw::c_char, - arg4: *mut __va_list_tag, + arg4: va_list, ) -> *mut ::std::os::raw::c_char; } extern "C" { @@ -2501,7 +2501,7 @@ extern "C" { pub fn sqlite3_str_vappendf( arg1: *mut sqlite3_str, zFormat: *const ::std::os::raw::c_char, - arg2: *mut __va_list_tag, + arg2: va_list, ); } extern "C" { @@ -2861,6 +2861,20 @@ extern "C" { arg: *mut ::std::os::raw::c_void, ) -> *mut ::std::os::raw::c_void; } +extern "C" { + pub fn libsql_wal_frame_count( + arg1: *mut sqlite3, + arg2: *mut ::std::os::raw::c_uint, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn libsql_wal_get_frame( + arg1: *mut sqlite3, + arg2: ::std::os::raw::c_uint, + arg3: *mut ::std::os::raw::c_void, + arg4: ::std::os::raw::c_uint, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn sqlite3_system_errno(arg1: *mut sqlite3) -> ::std::os::raw::c_int; } @@ -3269,6 +3283,14 @@ pub struct libsql_wal_methods { arg3: *mut ::std::os::raw::c_uchar, ) -> ::std::os::raw::c_int, >, + pub xReadFrameRaw: ::std::option::Option< + unsafe extern "C" fn( + pWal: *mut wal_impl, + arg1: ::std::os::raw::c_uint, + arg2: ::std::os::raw::c_int, + arg3: *mut ::std::os::raw::c_uchar, + ) -> ::std::os::raw::c_int, + >, pub xDbsize: ::std::option::Option ::std::os::raw::c_uint>, pub xBeginWriteTransaction: @@ -3504,12 +3526,4 @@ extern "C" { extern "C" { pub static sqlite3_wal_manager: libsql_wal_manager; } -pub type __builtin_va_list = [__va_list_tag; 1usize]; -#[repr(C)] -#[derive(Debug, 
Copy, Clone)] -pub struct __va_list_tag { - pub gp_offset: ::std::os::raw::c_uint, - pub fp_offset: ::std::os::raw::c_uint, - pub overflow_arg_area: *mut ::std::os::raw::c_void, - pub reg_save_area: *mut ::std::os::raw::c_void, -} +pub type __builtin_va_list = *mut ::std::os::raw::c_char; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index c22f35046f..a282adb4b1 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -10935,6 +10935,10 @@ SQLITE_API int sqlite3_preupdate_blobwrite(sqlite3 *); */ SQLITE_API void *libsql_close_hook(sqlite3 *db, void (*xClose)(void *pCtx, sqlite3 *db), void *arg); +SQLITE_API int libsql_wal_frame_count(sqlite3*, unsigned int*); + +SQLITE_API int libsql_wal_get_frame(sqlite3*, unsigned int, void*, unsigned int); + /* ** CAPI3REF: Low-level system error code ** METHOD: sqlite3 @@ -13960,6 +13964,7 @@ typedef struct libsql_wal_methods { /* Read a page from the write-ahead log, if it is present. */ int (*xFindFrame)(wal_impl* pWal, unsigned int, unsigned int *); int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + int (*xReadFrameRaw)(wal_impl* pWal, unsigned int, int, unsigned char *); /* If the WAL is not empty, return the size of the database. 
*/ unsigned int (*xDbsize)(wal_impl* pWal); @@ -16373,6 +16378,9 @@ SQLITE_PRIVATE int sqlite3PagerReadFileheader(Pager*, int, unsigned char*); SQLITE_PRIVATE void sqlite3PagerSetBusyHandler(Pager*, int(*)(void *), void *); SQLITE_PRIVATE int sqlite3PagerSetPagesize(Pager*, u32*, int); SQLITE_PRIVATE Pgno sqlite3PagerMaxPageCount(Pager*, Pgno); +SQLITE_PRIVATE unsigned int sqlite3PagerWalFrameCount(Pager *); +SQLITE_PRIVATE int sqlite3PagerWalReadFrame(Pager *, unsigned int, void *, unsigned int); + SQLITE_PRIVATE void sqlite3PagerSetCachesize(Pager*, int); SQLITE_PRIVATE int sqlite3PagerSetSpillsize(Pager*, int); SQLITE_PRIVATE void sqlite3PagerSetMmapLimit(Pager *, sqlite3_int64); @@ -57268,6 +57276,7 @@ typedef struct libsql_wal_methods { /* Read a page from the write-ahead log, if it is present. */ int (*xFindFrame)(wal_impl* pWal, unsigned int, unsigned int *); int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + int (*xReadFrameRaw)(wal_impl* pWal, unsigned int, int, unsigned char *); /* If the WAL is not empty, return the size of the database. */ unsigned int (*xDbsize)(wal_impl* pWal); @@ -65212,6 +65221,33 @@ SQLITE_PRIVATE int sqlite3PagerCloseWal(Pager *pPager, sqlite3 *db){ return rc; } +SQLITE_PRIVATE unsigned int sqlite3PagerWalFrameCount(Pager *pPager){ + if( pagerUseWal(pPager) ){ + // TODO: We are under sqlite3 mutex, but do we need something else? 
+ struct sqlite3_wal* pWal = (void*) pPager->wal->pData; + return pWal->hdr.mxFrame; + }else{ + return 0; + } +} + +SQLITE_PRIVATE int sqlite3PagerWalReadFrameRaw( + Pager *pPager, + unsigned int iFrame, + void *pFrameOut, + unsigned int nFrameOutLen +){ + if( pagerUseWal(pPager) ){ + unsigned int nFrameLen = 24+pPager->pageSize; + if( nFrameOutLen!=nFrameLen ) return SQLITE_MISUSE; + return pPager->wal->methods.xReadFrameRaw(pPager->wal->pData, iFrame, pPager->pageSize, pFrameOut); + }else{ + return SQLITE_ERROR; + } +} + + int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + #ifdef SQLITE_ENABLE_SETLK_TIMEOUT /* ** If pager pPager is a wal-mode database not in exclusive locking mode, @@ -67599,9 +67635,10 @@ static int sqlite3WalClose( if( pWal->exclusiveMode==WAL_NORMAL_MODE ){ pWal->exclusiveMode = WAL_EXCLUSIVE_MODE; } - rc = sqlite3WalCheckpoint(pWal, db, - SQLITE_CHECKPOINT_PASSIVE, 0, 0, sync_flags, nBuf, zBuf, 0, 0, NULL, NULL - ); + rc = SQLITE_ERROR; + //rc = sqlite3WalCheckpoint(pWal, db, + // SQLITE_CHECKPOINT_PASSIVE, 0, 0, sync_flags, nBuf, zBuf, 0, 0, NULL, NULL + //); if( rc==SQLITE_OK ){ int bPersist = -1; sqlite3OsFileControlHint( @@ -68729,6 +68766,28 @@ static int sqlite3WalReadFrame( return sqlite3OsRead(pWal->pWalFd, pOut, (nOut>sz ? sz : nOut), iOffset); } +/* +** Read the contents of frame iRead from the wal file into buffer pOut +** (which is nOut bytes in size). Return SQLITE_OK if successful, or an +** error code otherwise. +*/ +static int sqlite3WalReadFrameRaw( + Wal *pWal, /* WAL handle */ + u32 iRead, /* Frame to read */ + int nOut, /* Size of buffer pOut in bytes */ + u8 *pOut /* Buffer to write page data to */ +){ + int sz; + i64 iOffset; + sz = pWal->hdr.szPage; + sz = (sz&0xfe00) + ((sz&0x0001)<<16); + testcase( sz<=32768 ); + testcase( sz>=65536 ); + iOffset = walFrameOffset(iRead, sz); + /* testcase( IS_BIG_INT(iOffset) ); // requires a 4GiB WAL */ + return sqlite3OsRead(pWal->pWalFd, pOut, (nOut>sz ? 
sz : nOut), iOffset); +} + /* ** Return the size of the database in pages (or zero, if unknown). */ @@ -69838,6 +69897,7 @@ static int sqlite3WalOpen( out->methods.xEndReadTransaction = (void (*)(wal_impl *))sqlite3WalEndReadTransaction; out->methods.xFindFrame = (int (*)(wal_impl *, unsigned int, unsigned int *))sqlite3WalFindFrame; out->methods.xReadFrame = (int (*)(wal_impl *, unsigned int, int, unsigned char *))sqlite3WalReadFrame; + out->methods.xReadFrameRaw = (int (*)(wal_impl *, unsigned int, int, unsigned char *))sqlite3WalReadFrameRaw; out->methods.xDbsize = (unsigned int (*)(wal_impl *))sqlite3WalDbsize; out->methods.xBeginWriteTransaction = (int (*)(wal_impl *))sqlite3WalBeginWriteTransaction; out->methods.xEndWriteTransaction = (int (*)(wal_impl *))sqlite3WalEndWriteTransaction; @@ -182863,6 +182923,62 @@ void *libsql_close_hook( return pRet; } +/* +** Return the number of frames in the WAL of the given database. +*/ +int libsql_wal_frame_count( + sqlite3* db, + unsigned int *pnFrame +){ + int rc = SQLITE_OK; + Pager *pPager; + +#ifdef SQLITE_OMIT_WAL + *pnFrame = 0; + return SQLITE_OK; +#else +#ifdef SQLITE_ENABLE_API_ARMOR + if( !sqlite3SafetyCheckOk(db) ) return SQLITE_MISUSE_BKPT; +#endif + + sqlite3_mutex_enter(db->mutex); + pPager = sqlite3BtreePager(db->aDb[0].pBt); + *pnFrame = sqlite3PagerWalFrameCount(pPager); + sqlite3_mutex_leave(db->mutex); + + return rc; +#endif +} + +int libsql_wal_get_frame( + sqlite3* db, + unsigned int iFrame, + void *pBuf, + unsigned int nBuf +){ + int rc = SQLITE_OK; + Pager *pPager; + +#ifdef SQLITE_OMIT_WAL + UNUSED_PARAMETER(iFrame); + UNUSED_PARAMETER(nBuf); + UNUSED_PARAMETER(pBuf); + return SQLITE_OK; +#else + +#ifdef SQLITE_ENABLE_API_ARMOR + if( !sqlite3SafetyCheckOk(db) ) return SQLITE_MISUSE_BKPT; +#endif + + sqlite3_mutex_enter(db->mutex); + pPager = sqlite3BtreePager(db->aDb[0].pBt); + rc = sqlite3PagerWalReadFrameRaw(pPager, iFrame, pBuf, nBuf); + sqlite3_mutex_leave(db->mutex); + + return rc; +#endif 
+} + /* ** Register a function to be invoked prior to each autovacuum that ** determines the number of pages to vacuum. diff --git a/libsql-ffi/bundled/src/sqlite3.h b/libsql-ffi/bundled/src/sqlite3.h index d526834332..a21d6ed206 100644 --- a/libsql-ffi/bundled/src/sqlite3.h +++ b/libsql-ffi/bundled/src/sqlite3.h @@ -10549,6 +10549,10 @@ SQLITE_API int sqlite3_preupdate_blobwrite(sqlite3 *); */ SQLITE_API void *libsql_close_hook(sqlite3 *db, void (*xClose)(void *pCtx, sqlite3 *db), void *arg); +SQLITE_API int libsql_wal_frame_count(sqlite3*, unsigned int*); + +SQLITE_API int libsql_wal_get_frame(sqlite3*, unsigned int, void*, unsigned int); + /* ** CAPI3REF: Low-level system error code ** METHOD: sqlite3 @@ -13574,6 +13578,7 @@ typedef struct libsql_wal_methods { /* Read a page from the write-ahead log, if it is present. */ int (*xFindFrame)(wal_impl* pWal, unsigned int, unsigned int *); int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + int (*xReadFrameRaw)(wal_impl* pWal, unsigned int, int, unsigned char *); /* If the WAL is not empty, return the size of the database. 
*/ unsigned int (*xDbsize)(wal_impl* pWal); diff --git a/libsql-replication/src/injector/injector_wal.rs b/libsql-replication/src/injector/injector_wal.rs index eb92941d96..34273053c8 100644 --- a/libsql-replication/src/injector/injector_wal.rs +++ b/libsql-replication/src/injector/injector_wal.rs @@ -110,6 +110,10 @@ impl Wal for InjectorWal { self.inner.read_frame(frame_no, buffer) } + fn read_frame_raw(&mut self, frame_no: NonZeroU32, buffer: &mut [u8]) -> Result<()> { + self.inner.read_frame_raw(frame_no, buffer) + } + fn db_size(&self) -> u32 { self.inner.db_size() } diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 6763c02dfb..978e0b37d8 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -37,6 +37,7 @@ itertools = "0.10.5" jsonwebtoken = "9" libsql = { path = "../libsql/", optional = true } libsql_replication = { path = "../libsql-replication" } +libsql_sync = { path = "../libsql-sync" } libsql-wal = { path = "../libsql-wal/" } libsql-storage = { path = "../libsql-storage", optional = true } metrics = "0.21.1" diff --git a/libsql-server/src/rpc/mod.rs b/libsql-server/src/rpc/mod.rs index 6359556518..bd012c37f9 100644 --- a/libsql-server/src/rpc/mod.rs +++ b/libsql-server/src/rpc/mod.rs @@ -24,6 +24,7 @@ pub mod replica_proxy; pub mod replication_log; pub mod replication_log_proxy; pub mod streaming_exec; +pub mod sync; pub async fn run_rpc_server( proxy_service: ProxyService, diff --git a/libsql-server/src/rpc/sync.rs b/libsql-server/src/rpc/sync.rs new file mode 100644 index 0000000000..c4c6f1745f --- /dev/null +++ b/libsql-server/src/rpc/sync.rs @@ -0,0 +1,31 @@ +use futures::stream::BoxStream; +use libsql_sync::sync::rpc::wal_sync_server::WalSync; + +pub struct SyncService {} + +#[tonic::async_trait] +impl WalSync for SyncService { + type FetchDatabaseStream = + BoxStream<'static, Result>; + + async fn fetch_database( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + unimplemented!() + } + + 
async fn pull_wal( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + unimplemented!() + } + + async fn push_wal( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + unimplemented!() + } +} diff --git a/libsql-sqlite3/src/main.c b/libsql-sqlite3/src/main.c index 5946501731..c6eb5d6628 100644 --- a/libsql-sqlite3/src/main.c +++ b/libsql-sqlite3/src/main.c @@ -2425,6 +2425,62 @@ void *libsql_close_hook( return pRet; } +/* +** Return the number of frames in the WAL of the given database. +*/ +int libsql_wal_frame_count( + sqlite3* db, + unsigned int *pnFrame +){ + int rc = SQLITE_OK; + Pager *pPager; + +#ifdef SQLITE_OMIT_WAL + *pnFrame = 0; + return SQLITE_OK; +#else +#ifdef SQLITE_ENABLE_API_ARMOR + if( !sqlite3SafetyCheckOk(db) ) return SQLITE_MISUSE_BKPT; +#endif + + sqlite3_mutex_enter(db->mutex); + pPager = sqlite3BtreePager(db->aDb[0].pBt); + *pnFrame = sqlite3PagerWalFrameCount(pPager); + sqlite3_mutex_leave(db->mutex); + + return rc; +#endif +} + +int libsql_wal_get_frame( + sqlite3* db, + unsigned int iFrame, + void *pBuf, + unsigned int nBuf +){ + int rc = SQLITE_OK; + Pager *pPager; + +#ifdef SQLITE_OMIT_WAL + UNUSED_PARAMETER(iFrame); + UNUSED_PARAMETER(nBuf); + UNUSED_PARAMETER(pBuf); + return SQLITE_OK; +#else + +#ifdef SQLITE_ENABLE_API_ARMOR + if( !sqlite3SafetyCheckOk(db) ) return SQLITE_MISUSE_BKPT; +#endif + + sqlite3_mutex_enter(db->mutex); + pPager = sqlite3BtreePager(db->aDb[0].pBt); + rc = sqlite3PagerWalReadFrameRaw(pPager, iFrame, pBuf, nBuf); + sqlite3_mutex_leave(db->mutex); + + return rc; +#endif +} + /* ** Register a function to be invoked prior to each autovacuum that ** determines the number of pages to vacuum. 
diff --git a/libsql-sqlite3/src/pager.c b/libsql-sqlite3/src/pager.c index 8b8348919c..a846967482 100644 --- a/libsql-sqlite3/src/pager.c +++ b/libsql-sqlite3/src/pager.c @@ -7771,6 +7771,33 @@ int sqlite3PagerCloseWal(Pager *pPager, sqlite3 *db){ return rc; } +unsigned int sqlite3PagerWalFrameCount(Pager *pPager){ + if( pagerUseWal(pPager) ){ + // TODO: We are under sqlite3 mutex, but do we need something else? + struct sqlite3_wal* pWal = (void*) pPager->wal->pData; + return pWal->hdr.mxFrame; + }else{ + return 0; + } +} + +int sqlite3PagerWalReadFrameRaw( + Pager *pPager, + unsigned int iFrame, + void *pFrameOut, + unsigned int nFrameOutLen +){ + if( pagerUseWal(pPager) ){ + unsigned int nFrameLen = 24+pPager->pageSize; + if( nFrameOutLen!=nFrameLen ) return SQLITE_MISUSE; + return pPager->wal->methods.xReadFrameRaw(pPager->wal->pData, iFrame, pPager->pageSize, pFrameOut); + }else{ + return SQLITE_ERROR; + } +} + + int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + #ifdef SQLITE_ENABLE_SETLK_TIMEOUT /* ** If pager pPager is a wal-mode database not in exclusive locking mode, diff --git a/libsql-sqlite3/src/pager.h b/libsql-sqlite3/src/pager.h index 875b1fc17d..3991d8975b 100644 --- a/libsql-sqlite3/src/pager.h +++ b/libsql-sqlite3/src/pager.h @@ -133,6 +133,9 @@ int sqlite3PagerReadFileheader(Pager*, int, unsigned char*); void sqlite3PagerSetBusyHandler(Pager*, int(*)(void *), void *); int sqlite3PagerSetPagesize(Pager*, u32*, int); Pgno sqlite3PagerMaxPageCount(Pager*, Pgno); +unsigned int sqlite3PagerWalFrameCount(Pager *); +int sqlite3PagerWalReadFrame(Pager *, unsigned int, void *, unsigned int); + void sqlite3PagerSetCachesize(Pager*, int); int sqlite3PagerSetSpillsize(Pager*, int); void sqlite3PagerSetMmapLimit(Pager *, sqlite3_int64); diff --git a/libsql-sqlite3/src/sqlite.h.in b/libsql-sqlite3/src/sqlite.h.in index 7df3afc632..26a6390fed 100644 --- a/libsql-sqlite3/src/sqlite.h.in +++ b/libsql-sqlite3/src/sqlite.h.in @@ -10549,6 
+10549,10 @@ int sqlite3_preupdate_blobwrite(sqlite3 *); */ void *libsql_close_hook(sqlite3 *db, void (*xClose)(void *pCtx, sqlite3 *db), void *arg); +int libsql_wal_frame_count(sqlite3*, unsigned int*); + +int libsql_wal_get_frame(sqlite3*, unsigned int, void*, unsigned int); + /* ** CAPI3REF: Low-level system error code ** METHOD: sqlite3 diff --git a/libsql-sqlite3/src/wal.c b/libsql-sqlite3/src/wal.c index 0f6f699d5c..734b8f3edc 100644 --- a/libsql-sqlite3/src/wal.c +++ b/libsql-sqlite3/src/wal.c @@ -2256,9 +2256,10 @@ static int sqlite3WalClose( if( pWal->exclusiveMode==WAL_NORMAL_MODE ){ pWal->exclusiveMode = WAL_EXCLUSIVE_MODE; } - rc = sqlite3WalCheckpoint(pWal, db, - SQLITE_CHECKPOINT_PASSIVE, 0, 0, sync_flags, nBuf, zBuf, 0, 0, NULL, NULL - ); + rc = SQLITE_ERROR; + //rc = sqlite3WalCheckpoint(pWal, db, + // SQLITE_CHECKPOINT_PASSIVE, 0, 0, sync_flags, nBuf, zBuf, 0, 0, NULL, NULL + //); if( rc==SQLITE_OK ){ int bPersist = -1; sqlite3OsFileControlHint( @@ -3386,6 +3387,28 @@ static int sqlite3WalReadFrame( return sqlite3OsRead(pWal->pWalFd, pOut, (nOut>sz ? sz : nOut), iOffset); } +/* +** Read the contents of frame iRead from the wal file into buffer pOut +** (which is nOut bytes in size). Return SQLITE_OK if successful, or an +** error code otherwise. +*/ +static int sqlite3WalReadFrameRaw( + Wal *pWal, /* WAL handle */ + u32 iRead, /* Frame to read */ + int nOut, /* Size of buffer pOut in bytes */ + u8 *pOut /* Buffer to write page data to */ +){ + int sz; + i64 iOffset; + sz = pWal->hdr.szPage; + sz = (sz&0xfe00) + ((sz&0x0001)<<16); + testcase( sz<=32768 ); + testcase( sz>=65536 ); + iOffset = walFrameOffset(iRead, sz); + /* testcase( IS_BIG_INT(iOffset) ); // requires a 4GiB WAL */ + return sqlite3OsRead(pWal->pWalFd, pOut, (nOut>sz ? sz : nOut), iOffset); +} + /* ** Return the size of the database in pages (or zero, if unknown). 
*/ @@ -4495,6 +4518,7 @@ static int sqlite3WalOpen( out->methods.xEndReadTransaction = (void (*)(wal_impl *))sqlite3WalEndReadTransaction; out->methods.xFindFrame = (int (*)(wal_impl *, unsigned int, unsigned int *))sqlite3WalFindFrame; out->methods.xReadFrame = (int (*)(wal_impl *, unsigned int, int, unsigned char *))sqlite3WalReadFrame; + out->methods.xReadFrameRaw = (int (*)(wal_impl *, unsigned int, int, unsigned char *))sqlite3WalReadFrameRaw; out->methods.xDbsize = (unsigned int (*)(wal_impl *))sqlite3WalDbsize; out->methods.xBeginWriteTransaction = (int (*)(wal_impl *))sqlite3WalBeginWriteTransaction; out->methods.xEndWriteTransaction = (int (*)(wal_impl *))sqlite3WalEndWriteTransaction; diff --git a/libsql-sqlite3/src/wal.h b/libsql-sqlite3/src/wal.h index 1939485002..a443a89b3a 100644 --- a/libsql-sqlite3/src/wal.h +++ b/libsql-sqlite3/src/wal.h @@ -56,6 +56,7 @@ typedef struct libsql_wal_methods { /* Read a page from the write-ahead log, if it is present. */ int (*xFindFrame)(wal_impl* pWal, unsigned int, unsigned int *); int (*xReadFrame)(wal_impl* pWal, unsigned int, int, unsigned char *); + int (*xReadFrameRaw)(wal_impl* pWal, unsigned int, int, unsigned char *); /* If the WAL is not empty, return the size of the database. 
*/ unsigned int (*xDbsize)(wal_impl* pWal); diff --git a/libsql-sync/Cargo.toml b/libsql-sync/Cargo.toml new file mode 100644 index 0000000000..0989797bc7 --- /dev/null +++ b/libsql-sync/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "libsql_sync" +version = "0.4.0" +edition = "2021" +description = "libSQL WAL sync protocol" +repository = "https://github.com/tursodatabase/libsql" +license = "MIT" + +[dependencies] +libsql-sys = { version = "0.6", path = "../libsql-sys" } +tonic = { version = "0.11", features = ["tls"] } +prost = "0.12" + +[dev-dependencies] +prost-build = "0.12" +tonic-build = "0.11" diff --git a/libsql-sync/proto/walsync.proto b/libsql-sync/proto/walsync.proto new file mode 100644 index 0000000000..a3505c37ff --- /dev/null +++ b/libsql-sync/proto/walsync.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package walsync; + +service WALSync { + // Fetch the database file + rpc FetchDatabase (FetchDatabaseRequest) returns (stream DatabaseChunk) {} + + // Pull the WAL from the server + rpc PullWAL (PullWALRequest) returns (PullWALResponse) {} + + // Push local changes to the server and potentially trigger checkpointing + rpc PushWAL (PushWALRequest) returns (PushWALResponse) {} +} + +message FetchDatabaseRequest { + string client_id = 1; +} + +message DatabaseChunk { + bytes data = 1; + uint64 offset = 2; + bool is_last_chunk = 3; +} + +message PullWALRequest { + string client_id = 1; + uint64 client_last_checkpoint_frame_id = 2; +} + +message PullWALResponse { + repeated WALFrame wal = 1; + uint64 server_last_checkpoint_frame_id = 2; + bool need_full_db_sync = 3; +} + +message WALFrame { + uint64 frame_id = 1; + bytes data = 2; +} + +message PushWALRequest { + string client_id = 1; + uint64 base_frame_id = 2; + repeated WALFrame new_frames = 3; + uint64 last_checkpoint_frame_id = 4; + bool request_checkpoint = 5; +} + +message PushWALResponse { + enum Status { + SUCCESS = 0; + CONFLICT = 1; + ERROR = 2; + NEED_FULL_SYNC = 3; + } + Status status = 1; + 
string message = 2; + repeated WALFrame server_wal = 3; + uint64 server_last_checkpoint_frame_id = 4; + bool perform_checkpoint = 5; + uint64 checkpoint_frame_id = 6; +} diff --git a/libsql-sync/src/generated/walsync.rs b/libsql-sync/src/generated/walsync.rs new file mode 100644 index 0000000000..ab02d77148 --- /dev/null +++ b/libsql-sync/src/generated/walsync.rs @@ -0,0 +1,564 @@ +// This file is @generated by prost-build. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FetchDatabaseRequest { + #[prost(string, tag = "1")] + pub client_id: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DatabaseChunk { + #[prost(bytes = "bytes", tag = "1")] + pub data: ::prost::bytes::Bytes, + #[prost(uint64, tag = "2")] + pub offset: u64, + #[prost(bool, tag = "3")] + pub is_last_chunk: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PullWalRequest { + #[prost(string, tag = "1")] + pub client_id: ::prost::alloc::string::String, + #[prost(uint64, tag = "2")] + pub client_last_checkpoint_frame_id: u64, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PullWalResponse { + #[prost(message, repeated, tag = "1")] + pub wal: ::prost::alloc::vec::Vec, + #[prost(uint64, tag = "2")] + pub server_last_checkpoint_frame_id: u64, + #[prost(bool, tag = "3")] + pub need_full_db_sync: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WalFrame { + #[prost(uint64, tag = "1")] + pub frame_id: u64, + #[prost(bytes = "bytes", tag = "2")] + pub data: ::prost::bytes::Bytes, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PushWalRequest { + #[prost(string, tag = "1")] + pub client_id: 
::prost::alloc::string::String, + #[prost(uint64, tag = "2")] + pub base_frame_id: u64, + #[prost(message, repeated, tag = "3")] + pub new_frames: ::prost::alloc::vec::Vec, + #[prost(uint64, tag = "4")] + pub last_checkpoint_frame_id: u64, + #[prost(bool, tag = "5")] + pub request_checkpoint: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PushWalResponse { + #[prost(enumeration = "push_wal_response::Status", tag = "1")] + pub status: i32, + #[prost(string, tag = "2")] + pub message: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "3")] + pub server_wal: ::prost::alloc::vec::Vec, + #[prost(uint64, tag = "4")] + pub server_last_checkpoint_frame_id: u64, + #[prost(bool, tag = "5")] + pub perform_checkpoint: bool, + #[prost(uint64, tag = "6")] + pub checkpoint_frame_id: u64, +} +/// Nested message and enum types in `PushWALResponse`. +pub mod push_wal_response { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum Status { + Success = 0, + Conflict = 1, + Error = 2, + NeedFullSync = 3, + } + impl Status { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Status::Success => "SUCCESS", + Status::Conflict => "CONFLICT", + Status::Error => "ERROR", + Status::NeedFullSync => "NEED_FULL_SYNC", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SUCCESS" => Some(Self::Success), + "CONFLICT" => Some(Self::Conflict), + "ERROR" => Some(Self::Error), + "NEED_FULL_SYNC" => Some(Self::NeedFullSync), + _ => None, + } + } + } +} +/// Generated client implementations. +pub mod wal_sync_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct WalSyncClient { + inner: tonic::client::Grpc, + } + impl WalSyncClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl WalSyncClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> WalSyncClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + WalSyncClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. 
+ #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// Fetch the database file + pub async fn fetch_database( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/walsync.WALSync/FetchDatabase", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("walsync.WALSync", "FetchDatabase")); + self.inner.server_streaming(req, path, codec).await + } + /// Pull the WAL from the server + pub async fn pull_wal( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static("/walsync.WALSync/PullWAL"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new("walsync.WALSync", "PullWAL")); + self.inner.unary(req, path, codec).await + } + /// Push local changes to the server and potentially trigger 
checkpointing + pub async fn push_wal( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static("/walsync.WALSync/PushWAL"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new("walsync.WALSync", "PushWAL")); + self.inner.unary(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod wal_sync_server { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with WalSyncServer. + #[async_trait] + pub trait WalSync: Send + Sync + 'static { + /// Server streaming response type for the FetchDatabase method. 
+ type FetchDatabaseStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + Send + + 'static; + /// Fetch the database file + async fn fetch_database( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + /// Pull the WAL from the server + async fn pull_wal( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// Push local changes to the server and potentially trigger checkpointing + async fn push_wal( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + } + #[derive(Debug)] + pub struct WalSyncServer { + inner: _Inner, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + struct _Inner(Arc); + impl WalSyncServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + let inner = _Inner(inner); + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. 
+ /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for WalSyncServer + where + T: WalSync, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + let inner = self.inner.clone(); + match req.uri().path() { + "/walsync.WALSync/FetchDatabase" => { + #[allow(non_camel_case_types)] + struct FetchDatabaseSvc(pub Arc); + impl< + T: WalSync, + > tonic::server::ServerStreamingService + for FetchDatabaseSvc { + type Response = super::DatabaseChunk; + type ResponseStream = T::FetchDatabaseStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::fetch_database(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = FetchDatabaseSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + 
send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/walsync.WALSync/PullWAL" => { + #[allow(non_camel_case_types)] + struct PullWALSvc(pub Arc); + impl tonic::server::UnaryService + for PullWALSvc { + type Response = super::PullWalResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::pull_wal(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = PullWALSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/walsync.WALSync/PushWAL" => { + #[allow(non_camel_case_types)] + struct PushWALSvc(pub Arc); + impl tonic::server::UnaryService + for PushWALSvc { + type Response = super::PushWalResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::push_wal(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings 
= self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = PushWALSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) + } + } + } + } + impl Clone for WalSyncServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + impl Clone for _Inner { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } + } + impl std::fmt::Debug for _Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl tonic::server::NamedService for WalSyncServer { + const NAME: &'static str = "walsync.WALSync"; + } +} diff --git a/libsql-sync/src/lib.rs b/libsql-sync/src/lib.rs new file mode 100644 index 0000000000..d086d5bd67 --- /dev/null +++ b/libsql-sync/src/lib.rs @@ -0,0 +1 @@ +pub mod sync; diff --git a/libsql-sync/src/sync.rs b/libsql-sync/src/sync.rs new file mode 100644 index 0000000000..a810a93dba --- /dev/null +++ b/libsql-sync/src/sync.rs @@ -0,0 +1,20 @@ +pub mod rpc { + #![allow(clippy::all)] + 
include!("generated/walsync.rs"); +} + +pub struct SyncContext { + durable_frame_num: u32, +} + +impl SyncContext { + pub fn new() -> Self { + Self { + durable_frame_num: 0, + } + } + /// Highest WAL frame number recorded as durable by this context; starts at 0. + pub fn durable_frame_num(&self) -> u32 { + self.durable_frame_num + } +} diff --git a/libsql-sync/tests/bootstrap.rs b/libsql-sync/tests/bootstrap.rs new file mode 100644 index 0000000000..694aab0484 --- /dev/null +++ b/libsql-sync/tests/bootstrap.rs @@ -0,0 +1,33 @@ +use std::{path::PathBuf, process::Command}; + +#[test] +fn bootstrap() { + let iface_files = &["proto/walsync.proto"]; + let dirs = &["proto"]; + + let out_dir = PathBuf::from(std::env!("CARGO_MANIFEST_DIR")) + .join("src") + .join("generated"); + + let mut config = prost_build::Config::new(); + config.bytes([".walsync"]); + + tonic_build::configure() + .build_client(true) + .build_server(true) + .build_transport(true) + .out_dir(&out_dir) + .type_attribute(".proxy", "#[derive(serde::Serialize, serde::Deserialize)]") + .compile_with_config(config, iface_files, dirs) + .unwrap(); + + let status = Command::new("git") + .arg("diff") + .arg("--exit-code") + .arg("--") + .arg(&out_dir) + .status() + .unwrap(); + + assert!(status.success(), "You should commit the protobuf files"); +} diff --git a/libsql-sys/src/wal/either.rs b/libsql-sys/src/wal/either.rs index 1f9a67609e..2239539211 100644 --- a/libsql-sys/src/wal/either.rs +++ b/libsql-sys/src/wal/either.rs @@ -39,6 +39,12 @@ macro_rules! 
create_either { } } + fn read_frame_raw(&mut self, frame_no: std::num::NonZeroU32, buffer: &mut [u8]) -> super::Result<()> { + match self { + $( $name::$t(inner) => inner.read_frame_raw(frame_no, buffer) ),* + } + } + fn db_size(&self) -> u32 { match self { $( $name::$t(inner) => inner.db_size() ),* diff --git a/libsql-sys/src/wal/ffi.rs b/libsql-sys/src/wal/ffi.rs index 2dedcd2567..fcb22c35b4 100644 --- a/libsql-sys/src/wal/ffi.rs +++ b/libsql-sys/src/wal/ffi.rs @@ -23,6 +23,7 @@ pub(crate) fn construct_libsql_wal(wal: *mut W) -> libsql_wal { xEndReadTransaction: Some(end_read_transaction::), xFindFrame: Some(find_frame::), xReadFrame: Some(read_frame::), + xReadFrameRaw: Some(read_frame_raw::), xDbsize: Some(db_size::), xBeginWriteTransaction: Some(begin_write_transaction::), xEndWriteTransaction: Some(end_write_transaction::), @@ -210,6 +211,23 @@ pub unsafe extern "C" fn read_frame( } } +pub unsafe extern "C" fn read_frame_raw( + wal: *mut wal_impl, + frame: u32, + n_out: c_int, + p_out: *mut u8, +) -> i32 { + let this = &mut (*(wal as *mut T)); + let buffer = std::slice::from_raw_parts_mut(p_out, n_out as usize); + match this.read_frame_raw( + NonZeroU32::new(frame).expect("invalid frame number"), + buffer, + ) { + Ok(_) => SQLITE_OK, + Err(code) => code.extended_code, + } +} + pub unsafe extern "C" fn db_size(wal: *mut wal_impl) -> u32 { let this = &mut (*(wal as *mut T)); this.db_size() diff --git a/libsql-sys/src/wal/mod.rs b/libsql-sys/src/wal/mod.rs index 3577cb44e2..57e11401a5 100644 --- a/libsql-sys/src/wal/mod.rs +++ b/libsql-sys/src/wal/mod.rs @@ -186,6 +186,7 @@ pub trait Wal { fn find_frame(&mut self, page_no: NonZeroU32) -> Result>; /// reads frame `frame_no` into buffer. 
fn read_frame(&mut self, frame_no: NonZeroU32, buffer: &mut [u8]) -> Result<()>; + fn read_frame_raw(&mut self, frame_no: NonZeroU32, buffer: &mut [u8]) -> Result<()>; fn db_size(&self) -> u32; diff --git a/libsql-sys/src/wal/sqlite3_wal.rs b/libsql-sys/src/wal/sqlite3_wal.rs index 548434a2a2..38fee5220c 100644 --- a/libsql-sys/src/wal/sqlite3_wal.rs +++ b/libsql-sys/src/wal/sqlite3_wal.rs @@ -203,6 +203,22 @@ impl Wal for Sqlite3Wal { } } + fn read_frame_raw(&mut self, frame_no: NonZeroU32, buffer: &mut [u8]) -> Result<()> { + let rc = unsafe { + (self.inner.methods.xReadFrameRaw.unwrap())( + self.inner.pData, + frame_no.into(), + buffer.len() as _, + buffer.as_mut_ptr(), + ) + }; + if rc != 0 { + Err(Error::new(rc)) + } else { + Ok(()) + } + } + fn db_size(&self) -> u32 { unsafe { (self.inner.methods.xDbsize.unwrap())(self.inner.pData) } } diff --git a/libsql-sys/src/wal/wrapper.rs b/libsql-sys/src/wal/wrapper.rs index 713ba0347b..f141a0a878 100644 --- a/libsql-sys/src/wal/wrapper.rs +++ b/libsql-sys/src/wal/wrapper.rs @@ -61,6 +61,10 @@ impl, W: Wal> Wal for WalRef { unsafe { (*self.wrapper).read_frame(&mut *self.wrapped, frame_no, buffer) } } + fn read_frame_raw(&mut self, frame_no: NonZeroU32, buffer: &mut [u8]) -> super::Result<()> { + unsafe { (*self.wrapper).read_frame_raw(&mut *self.wrapped, frame_no, buffer) } + } + fn db_size(&self) -> u32 { unsafe { (*self.wrapper).db_size(&*self.wrapped) } } @@ -234,6 +238,10 @@ where self.wrapper.read_frame(&mut self.wrapped, frame_no, buffer) } + fn read_frame_raw(&mut self, frame_no: NonZeroU32, buffer: &mut [u8]) -> super::Result<()> { + self.wrapper.read_frame_raw(&mut self.wrapped, frame_no, buffer) + } + fn db_size(&self) -> u32 { self.wrapper.db_size(&self.wrapped) } @@ -355,6 +363,16 @@ pub trait WrapWal { wrapped.read_frame(frame_no, buffer) } + /// Reads frame `frame_no` into `buffer` by forwarding to the wrapped WAL's `read_frame_raw`. + fn read_frame_raw( + &mut self, + wrapped: &mut W, + frame_no: NonZeroU32, + buffer: &mut [u8], + ) -> super::Result<()> { + wrapped.read_frame_raw(frame_no, buffer) + } 
+ fn db_size(&self, wrapped: &W) -> u32 { wrapped.db_size() } diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index fa89cc68ad..ce56b94533 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -13,6 +13,7 @@ thiserror = "1.0.40" futures = { version = "0.3.28", optional = true } libsql-sys = { version = "0.6", path = "../libsql-sys", optional = true } libsql-hrana = { version = "0.2", path = "../libsql-hrana", optional = true } +libsql_sync = { version = "0.4", path = "../libsql-sync", optional = true } tokio = { version = "1.29.1", features = ["sync"], optional = true } tokio-util = { version = "0.7", features = ["io-util", "codec"], optional = true } parking_lot = { version = "0.12.1", optional = true } @@ -91,6 +92,7 @@ replication = [ "dep:hyper-rustls", "dep:futures", "dep:libsql_replication", + "dep:libsql_sync", ] hrana = [ "parser", diff --git a/libsql/examples/offline_writes.rs b/libsql/examples/offline_writes.rs new file mode 100644 index 0000000000..2a1ef939b2 --- /dev/null +++ b/libsql/examples/offline_writes.rs @@ -0,0 +1,85 @@ +// Example of using offline writes with libSQL. + +use libsql::{params, Builder}; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + // The local database path where the data will be stored. + let db_path = std::env::var("LIBSQL_DB_PATH") + .map_err(|_| { + eprintln!( + "Please set the LIBSQL_DB_PATH environment variable to set to local database path." + ) + }) + .unwrap(); + + // The remote sync URL to use. + let sync_url = std::env::var("LIBSQL_SYNC_URL") + .map_err(|_| { + eprintln!( + "Please set the LIBSQL_SYNC_URL environment variable to set to remote sync URL." + ) + }) + .unwrap(); + + let namespace = std::env::var("LIBSQL_NAMESPACE").ok(); + + // The authentication token to use. 
+ let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or("".to_string()); + + let db_builder = if let Some(ns) = namespace { + Builder::new_offline_replica(db_path, sync_url, auth_token).namespace(&ns) + } else { + Builder::new_offline_replica(db_path, sync_url, auth_token) + }; + + let db = match db_builder.build().await { + Ok(db) => db, + Err(error) => { + eprintln!("Error connecting to remote sync server: {}", error); + return; + } + }; + + let conn = db.connect().unwrap(); + + conn.execute( + r#" + CREATE TABLE IF NOT EXISTS guest_book_entries ( + text TEXT + )"#, + (), + ) + .await + .unwrap(); + + let mut input = String::new(); + println!("Please write your entry to the guestbook:"); + match std::io::stdin().read_line(&mut input) { + Ok(_) => { + println!("You entered: {}", input); + let params = params![input.as_str()]; + conn.execute("INSERT INTO guest_book_entries (text) VALUES (?)", params) + .await + .unwrap(); + } + Err(error) => { + eprintln!("Error reading input: {}", error); + } + } + let mut results = conn + .query("SELECT * FROM guest_book_entries", ()) + .await + .unwrap(); + println!("Guest book entries:"); + while let Some(row) = results.next().await.unwrap() { + let text: String = row.get(0).unwrap(); + println!(" {}", text); + } + + print!("Syncing database to remote..."); + db.sync().await.unwrap(); + println!(" done"); +} diff --git a/libsql/src/database.rs b/libsql/src/database.rs index e87def367d..d2508f1729 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -46,6 +46,8 @@ enum DbType { db: crate::local::Database, encryption_config: Option, }, + #[cfg(feature = "replication")] + Offline { db: crate::local::Database }, #[cfg(feature = "remote")] Remote { url: String, @@ -65,6 +67,8 @@ impl fmt::Debug for DbType { Self::File { .. } => write!(f, "File"), #[cfg(feature = "replication")] Self::Sync { .. } => write!(f, "Sync"), + #[cfg(feature = "replication")] + Self::Offline { .. 
} => write!(f, "Offline"), #[cfg(feature = "remote")] Self::Remote { .. } => write!(f, "Remote"), _ => write!(f, "no database type set"), @@ -324,10 +328,10 @@ cfg_replication! { /// Sync database from remote, and returns the committed frame_no after syncing, if /// applicable. pub async fn sync(&self) -> Result { - if let DbType::Sync { db, encryption_config: _ } = &self.db_type { - db.sync().await - } else { - Err(Error::SyncNotSupported(format!("{:?}", self.db_type))) + match &self.db_type { + DbType::Sync { db, encryption_config: _ } => db.sync().await, + DbType::Offline { db } => db.push().await, + _ => Err(Error::SyncNotSupported(format!("{:?}", self.db_type))), } } @@ -558,6 +562,17 @@ impl Database { Ok(Connection { conn }) } + #[cfg(feature = "replication")] + DbType::Offline { db } => { + use crate::local::impls::LibsqlConnection; + + let conn = db.connect()?; + + let conn = std::sync::Arc::new(LibsqlConnection { conn }); + + Ok(Connection { conn }) + } + #[cfg(feature = "remote")] DbType::Remote { url, diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index 8749b6452b..495aa84b98 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -12,6 +12,8 @@ use super::DbType; /// it does no networking and does not connect to any remote database. /// - `new_remote_replica`/`RemoteReplica` creates an embedded replica database that will be able /// to sync from the remote url and delegate writes to the remote primary. +/// - `new_offline_replica`/`OfflineReplica` creates an embedded replica database that supports +/// offline writes. /// - `new_local_replica`/`LocalReplica` creates an embedded replica similar to the remote version /// except you must use `Database::sync_frames` to sync with the remote. This version also /// includes the ability to delegate writes to a remote primary. @@ -66,6 +68,30 @@ impl Builder<()> { } } + cfg_replication! { + /// Create a new offline embedded replica. 
+ pub fn new_offline_replica( + path: impl AsRef, + url: String, + auth_token: String, + ) -> Builder { + Builder { + inner: OfflineReplica { + path: path.as_ref().to_path_buf(), + flags: crate::OpenFlags::default(), + remote: Remote { + url, + auth_token, + connector: None, + version: None, + }, + http_request_callback: None, + namespace: None + }, + } + } + } + /// Create a new local replica. pub fn new_local_replica(path: impl AsRef) -> Builder { Builder { @@ -170,6 +196,15 @@ cfg_replication! { namespace: Option, } + /// Offline replica configuration type in [`Builder`]. + pub struct OfflineReplica { + path: std::path::PathBuf, + flags: crate::OpenFlags, + remote: Remote, + http_request_callback: Option, + namespace: Option, + } + /// Local replica configuration type in [`Builder`]. pub struct LocalReplica { path: std::path::PathBuf, @@ -295,6 +330,90 @@ cfg_replication! { } } + impl Builder { + /// Provide a custom http connector that will be used to create http connections. + pub fn connector(mut self, connector: C) -> Builder + where + C: tower::Service + Send + Clone + Sync + 'static, + C::Response: crate::util::Socket, + C::Future: Send + 'static, + C::Error: Into>, + { + self.inner.remote = self.inner.remote.connector(connector); + self + } + + pub fn http_request_callback(mut self, f: F) -> Builder + where + F: Fn(&mut http::Request<()>) + Send + Sync + 'static + { + self.inner.http_request_callback = Some(std::sync::Arc::new(f)); + self + + } + + /// Set the namespace that will be communicated to remote replica in the http header. + pub fn namespace(mut self, namespace: impl Into) -> Builder + { + self.inner.namespace = Some(namespace.into()); + self + } + + #[doc(hidden)] + pub fn version(mut self, version: String) -> Builder { + self.inner.remote = self.inner.remote.version(version); + self + } + + /// Build the offline embedded replica database. 
+ pub async fn build(self) -> Result { + let OfflineReplica { + path, + flags, + remote: + Remote { + url, + auth_token, + connector, + version, + }, + http_request_callback, + namespace + } = self.inner; + + let connector = if let Some(connector) = connector { + connector + } else { + let https = super::connector()?; + use tower::ServiceExt; + + let svc = https + .map_err(|e| e.into()) + .map_response(|s| Box::new(s) as Box); + + crate::util::ConnectorService::new(svc) + }; + + let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned(); + + let db = crate::local::Database::open_local_with_offline_writes( + connector, + path, + flags, + url, + auth_token, + version, + http_request_callback, + namespace, + ) + .await?; + + Ok(Database { + db_type: DbType::Offline { db }, + }) + } + } + impl Builder { /// Set [`OpenFlags`] for this database. pub fn flags(mut self, flags: crate::OpenFlags) -> Builder { diff --git a/libsql/src/local/connection.rs b/libsql/src/local/connection.rs index bcf48ff23c..09a7e5431d 100644 --- a/libsql/src/local/connection.rs +++ b/libsql/src/local/connection.rs @@ -8,7 +8,7 @@ use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction}; use crate::TransactionBehavior; -use libsql_sys::ffi; +use libsql_sys::{ffi, wal}; use std::{ffi::c_int, fmt, path::Path, sync::Arc}; /// A connection to a libSQL database. @@ -57,13 +57,20 @@ impl Connection { ))); } } - - Ok(Connection { + let conn = Connection { raw, drop_ref: Arc::new(()), #[cfg(feature = "replication")] writer: db.writer()?, - }) + }; + if let Some(_) = db.sync_ctx { + // We need to make sure database is in WAL mode with checkpointing + // disabled so that we can sync our changes back to a remote + // server. 
+ conn.query("PRAGMA journal_mode = WAL", Params::None)?; + conn.query("PRAGMA wal_autocheckpoint = 0", Params::None)?; + } + Ok(conn) } /// Get a raw handle to the underlying libSQL connection diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 2892d809cc..abd01fa29f 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -9,13 +9,15 @@ cfg_replication!( use crate::replication::local_client::LocalClient; use crate::replication::remote_client::RemoteClient; use crate::replication::EmbeddedReplicator; - pub use crate::replication::Frames; + pub use crate::replication::{Replicated, Frames}; pub struct ReplicationContext { pub(crate) replicator: EmbeddedReplicator, client: Option, read_your_writes: bool, } + + use libsql_sync::sync::SyncContext; ); use crate::{database::OpenFlags, local::connection::Connection}; @@ -28,6 +30,8 @@ pub struct Database { pub flags: OpenFlags, #[cfg(feature = "replication")] pub replication_ctx: Option, + #[cfg(feature = "replication")] + pub sync_ctx: Option, } impl Database { @@ -122,6 +126,32 @@ impl Database { Ok(db) } + #[cfg(feature = "replication")] + #[doc(hidden)] + pub async fn open_local_with_offline_writes( + connector: crate::util::ConnectorService, + db_path: impl Into, + flags: OpenFlags, + endpoint: String, + auth_token: String, + version: Option, + http_request_callback: Option, + namespace: Option + + ) -> Result { + use std::path::PathBuf; + + let db_path = db_path.into(); + let mut db = Database::open(&db_path, flags)?; + db.sync_ctx = Some(SyncContext::new()); // only offline-write databases carry a sync context; connections check it to force WAL mode and disable autocheckpoint + let path = PathBuf::from(db_path); + let client = LocalClient::new(&path) + .await + .map_err(|e| crate::Error::Replication(e.into()))?; + + Ok(db) + } + #[cfg(feature = "replication")] pub async fn open_local_sync( db_path: impl Into, @@ -228,6 +258,8 @@ impl Database { flags, #[cfg(feature = "replication")] replication_ctx: None, + #[cfg(feature = "replication")] + sync_ctx: None, // set only by open_local_with_offline_writes } } @@ -313,6 +345,35 @@ impl Database { } } + 
#[cfg(feature = "replication")] + /// Reads every local WAL frame after the last durable frame, up to and including the newest; frames are only read and logged for now, nothing is sent yet. + pub async fn push(&self) -> Result { + let conn = self.connect()?; + let conn = conn.handle(); + let mut max_frame_no: std::os::raw::c_uint = 0; + unsafe { libsql_sys::ffi::libsql_wal_frame_count(conn, &mut max_frame_no) }; + println!("Maximum frame: {}", max_frame_no); + let sync_ctx = self.sync_ctx.as_ref().expect("push requires a sync context"); + let start_frame_no = sync_ctx.durable_frame_num() + 1; + let end_frame_no = max_frame_no; + for frame_no in start_frame_no..=end_frame_no { + const FRAME_SIZE: usize = 24+4096; // FIXME: make dynamic (assumes 24-byte WAL frame header + 4096-byte page) + let mut frame: [u8; FRAME_SIZE] = [0; FRAME_SIZE]; + let rc = unsafe { + libsql_sys::ffi::libsql_wal_get_frame(conn, frame_no, frame.as_mut_ptr() as *mut _, FRAME_SIZE as u32) + }; + if rc != 0 { + println!("Failed to get frame: {}", rc); + } else { + println!("Pushing frame: {:?}", frame); + } + } + Ok(Replicated{ + frame_no: None, + frames_synced: 0, + }) + } + pub(crate) fn path(&self) -> &str { &self.db_path } diff --git a/libsql/src/replication/mod.rs b/libsql/src/replication/mod.rs index 69cc0b5db2..63f511a25e 100644 --- a/libsql/src/replication/mod.rs +++ b/libsql/src/replication/mod.rs @@ -35,8 +35,8 @@ pub(crate) mod remote_client; #[derive(Debug)] pub struct Replicated { - frame_no: Option, - frames_synced: usize, + pub(crate) frame_no: Option, + pub(crate) frames_synced: usize, } impl Replicated { diff --git a/tlaplus/walsync/walsync.tla b/tlaplus/walsync/walsync.tla new file mode 100644 index 0000000000..0653b280e1 --- /dev/null +++ b/tlaplus/walsync/walsync.tla @@ -0,0 +1,182 @@ +---- MODULE walsync ---- +EXTENDS Integers, Sequences, FiniteSets, TLC + +CONSTANTS Writer, Conflictor, MaxFrameID + +VARIABLES + clientDB, \* Client database state + clientWAL, \* Client WAL frames + clientCheckpoint,\* Client's last checkpoint frame ID + serverDB, \* Server database state + serverWAL, \* Server WAL frames + serverCheckpoint,\* Server's last checkpoint frame ID + messages \* Messages in 
transit + +vars == << clientDB, clientWAL, clientCheckpoint, + serverDB, serverWAL, serverCheckpoint, messages >> + +Clients == {Writer, Conflictor} + +Message == [type: {"FetchDatabase", "PullWAL", "PushWAL"}, + sender: Clients, + payload: [clientId: Clients, + baseFrameId: Nat, + frames: Seq(Nat), + lastCheckpointFrameId: Nat, + requestCheckpoint: BOOLEAN]] + +Response == [type: {"DatabaseChunk", "PullWALResponse", "PushWALResponse"}, + receiver: Clients, + payload: [status: {"SUCCESS", "CONFLICT", "ERROR", "NEED_FULL_SYNC"}, + frames: Seq(Nat), + serverLastCheckpointFrameId: Nat, + performCheckpoint: BOOLEAN, + checkpointFrameId: Nat]] + +TypeOK == + /\ clientDB \in [Clients -> Nat] + /\ clientWAL \in [Clients -> Seq(Nat)] + /\ clientCheckpoint \in [Clients -> Nat] + /\ serverDB \in Nat + /\ serverWAL \in Seq(Nat) + /\ serverCheckpoint \in Nat + /\ messages \subseteq (Message \union Response) + +Init == + /\ clientDB = [c \in Clients |-> 0] + /\ clientWAL = [c \in Clients |-> <<>>] + /\ clientCheckpoint = [c \in Clients |-> 0] + /\ serverDB = 0 + /\ serverWAL = <<>> + /\ serverCheckpoint = 0 + /\ messages = {} + +\* Helper function to get the last frame ID +LastFrameId(wal) == IF Len(wal) = 0 THEN 0 ELSE wal[Len(wal)] + +ClientWrite == + /\ LET newFrame == LastFrameId(clientWAL[Writer]) + 1 + IN /\ newFrame <= MaxFrameID + /\ clientWAL' = [clientWAL EXCEPT ![Writer] = Append(@, newFrame)] + /\ UNCHANGED << clientDB, clientCheckpoint, serverDB, serverWAL, serverCheckpoint, messages >> + +RequestFetchDatabase(c) == + /\ messages' = messages \union {[type |-> "FetchDatabase", sender |-> c, + payload |-> [clientId |-> c]]} + /\ UNCHANGED << clientDB, clientWAL, clientCheckpoint, serverDB, serverWAL, serverCheckpoint >> + +RespondFetchDatabase == + \E m \in messages : + /\ m.type = "FetchDatabase" + /\ LET response == [type |-> "DatabaseChunk", + receiver |-> m.sender, + payload |-> [status |-> "SUCCESS", + frames |-> <<>>, + serverLastCheckpointFrameId |-> 
serverCheckpoint, + performCheckpoint |-> FALSE, + checkpointFrameId |-> 0]] + IN /\ messages' = (messages \ {m}) \union {response} + /\ clientDB' = [clientDB EXCEPT ![m.sender] = serverDB] + /\ UNCHANGED << clientWAL, clientCheckpoint, serverDB, serverWAL, serverCheckpoint >> + +RequestPullWAL(c) == + /\ messages' = messages \union {[type |-> "PullWAL", sender |-> c, + payload |-> [clientId |-> c, + lastCheckpointFrameId |-> clientCheckpoint[c]]]} + /\ UNCHANGED << clientDB, clientWAL, clientCheckpoint, serverDB, serverWAL, serverCheckpoint >> + +RespondPullWAL == + \E m \in messages : + /\ m.type = "PullWAL" + /\ LET newFrames == SubSeq(serverWAL, m.payload.lastCheckpointFrameId + 1, Len(serverWAL)) + response == [type |-> "PullWALResponse", + receiver |-> m.sender, + payload |-> [status |-> "SUCCESS", + frames |-> newFrames, + serverLastCheckpointFrameId |-> serverCheckpoint, + performCheckpoint |-> FALSE, + checkpointFrameId |-> 0]] + IN /\ messages' = (messages \ {m}) \union {response} + /\ clientWAL' = [clientWAL EXCEPT ![m.sender] = @ \o newFrames] + /\ UNCHANGED << clientDB, clientCheckpoint, serverDB, serverWAL, serverCheckpoint >> + +RequestPushWAL(c) == + /\ LET request == [type |-> "PushWAL", + sender |-> c, + payload |-> [clientId |-> c, + baseFrameId |-> LastFrameId(clientWAL[c]), + frames |-> clientWAL[c], + lastCheckpointFrameId |-> clientCheckpoint[c], + requestCheckpoint |-> TRUE]] + IN messages' = messages \union {request} + /\ UNCHANGED << clientDB, clientWAL, clientCheckpoint, serverDB, serverWAL, serverCheckpoint >> + +RespondPushWAL == + \E m \in messages : + /\ m.type = "PushWAL" + /\ LET doCheckpoint == \/ Len(serverWAL) > 2 * Len(m.payload.frames) + \/ m.payload.requestCheckpoint + newServerWAL == IF /\ m.sender = Writer + /\ m.payload.baseFrameId = LastFrameId(serverWAL) + THEN serverWAL \o m.payload.frames + ELSE serverWAL + newCheckpointId == IF doCheckpoint THEN LastFrameId(newServerWAL) ELSE serverCheckpoint + response == [type |-> 
"PushWALResponse", + receiver |-> m.sender, + payload |-> [status |-> IF /\ m.sender = Writer + /\ m.payload.baseFrameId = LastFrameId(serverWAL) + THEN "SUCCESS" ELSE "CONFLICT", + frames |-> IF /\ m.sender = Writer + /\ m.payload.baseFrameId = LastFrameId(serverWAL) + THEN <<>> ELSE serverWAL, + serverLastCheckpointFrameId |-> newCheckpointId, + performCheckpoint |-> doCheckpoint, + checkpointFrameId |-> newCheckpointId]] + IN /\ messages' = (messages \ {m}) \union {response} + /\ serverWAL' = newServerWAL + /\ serverCheckpoint' = newCheckpointId + /\ IF doCheckpoint + THEN serverDB' = LastFrameId(newServerWAL) + ELSE UNCHANGED serverDB + /\ UNCHANGED << clientDB, clientWAL, clientCheckpoint >> + +HandlePushWALResponse == + \E m \in messages : + /\ m.type = "PushWALResponse" + /\ LET client == m.receiver + IN /\ IF m.payload.status = "SUCCESS" + THEN /\ IF m.payload.performCheckpoint + THEN /\ clientCheckpoint' = [clientCheckpoint EXCEPT ![client] = m.payload.checkpointFrameId] + /\ clientDB' = [clientDB EXCEPT ![client] = LastFrameId(clientWAL[client])] + /\ clientWAL' = [clientWAL EXCEPT ![client] = SubSeq(@, m.payload.checkpointFrameId + 1, Len(@))] + ELSE UNCHANGED << clientCheckpoint, clientDB, clientWAL >> + ELSE /\ clientWAL' = [clientWAL EXCEPT ![client] = m.payload.frames] + /\ UNCHANGED << clientCheckpoint, clientDB >> + /\ messages' = messages \ {m} + /\ UNCHANGED << serverDB, serverWAL, serverCheckpoint >> + +Next == + \/ ClientWrite + \/ \E c \in Clients : RequestFetchDatabase(c) + \/ RespondFetchDatabase + \/ \E c \in Clients : RequestPullWAL(c) + \/ RespondPullWAL + \/ \E c \in Clients : RequestPushWAL(c) + \/ RespondPushWAL + \/ HandlePushWALResponse + +Spec == Init /\ [][Next]_vars + +\* Invariant: Writer's WAL and remote WAL match +WriterWALMatchesServer == + \A i \in 1..Len(clientWAL[Writer]) : + i <= Len(serverWAL) => clientWAL[Writer][i] = serverWAL[i] + +\* Invariant: Remote WAL never has conflicting writes +NoConflictingWrites == + \A i, j 
\in 1..Len(serverWAL) : + i # j => serverWAL[i] # serverWAL[j] + +THEOREM Spec => [](TypeOK /\ WriterWALMatchesServer /\ NoConflictingWrites) + +==== \ No newline at end of file