diff --git a/Examples/Examples.xcodeproj/project.pbxproj b/Examples/Examples.xcodeproj/project.pbxproj index dc37568..01be9b6 100644 --- a/Examples/Examples.xcodeproj/project.pbxproj +++ b/Examples/Examples.xcodeproj/project.pbxproj @@ -9,6 +9,8 @@ /* Begin PBXBuildFile section */ C441694F2F38F3B100051412 /* SQLiteUndo in Frameworks */ = {isa = PBXBuildFile; productRef = C441694E2F38F3B100051412 /* SQLiteUndo */; }; C44169512F38F3B100051412 /* SQLiteUndoTCA in Frameworks */ = {isa = PBXBuildFile; productRef = C44169502F38F3B100051412 /* SQLiteUndoTCA */; }; + C45245EF2F3D353800F31BB8 /* SQLiteUndo in Frameworks */ = {isa = PBXBuildFile; productRef = C45245EE2F3D353800F31BB8 /* SQLiteUndo */; }; + C45245F12F3D353800F31BB8 /* SQLiteUndoTCA in Frameworks */ = {isa = PBXBuildFile; productRef = C45245F02F3D353800F31BB8 /* SQLiteUndoTCA */; }; C4B1976A2F33B52B001EAFC2 /* SQLiteUndo in Frameworks */ = {isa = PBXBuildFile; productRef = C4B197692F33B52B001EAFC2 /* SQLiteUndo */; }; C4B197712F33B5D9001EAFC2 /* ComposableArchitecture in Frameworks */ = {isa = PBXBuildFile; productRef = C4B197702F33B5D9001EAFC2 /* ComposableArchitecture */; }; C4B197F02F33C28A001EAFC2 /* SQLiteUndo in Frameworks */ = {isa = PBXBuildFile; productRef = C4B197EF2F33C28A001EAFC2 /* SQLiteUndo */; }; @@ -35,11 +37,13 @@ buildActionMask = 2147483647; files = ( C441694F2F38F3B100051412 /* SQLiteUndo in Frameworks */, + C45245EF2F3D353800F31BB8 /* SQLiteUndo in Frameworks */, C4B197712F33B5D9001EAFC2 /* ComposableArchitecture in Frameworks */, C4B197F22F33C28A001EAFC2 /* SQLiteUndoTCA in Frameworks */, C4B197F52F33C2B0001EAFC2 /* SQLiteUndo in Frameworks */, C4B197F02F33C28A001EAFC2 /* SQLiteUndo in Frameworks */, C44169512F38F3B100051412 /* SQLiteUndoTCA in Frameworks */, + C45245F12F3D353800F31BB8 /* SQLiteUndoTCA in Frameworks */, C4B1976A2F33B52B001EAFC2 /* SQLiteUndo in Frameworks */, C4B197F72F33C2B0001EAFC2 /* SQLiteUndoTCA in Frameworks */, ); @@ -92,6 +96,8 @@ C4B197F62F33C2B0001EAFC2 /* SQLiteUndoTCA */, C441694E2F38F3B100051412 /* SQLiteUndo */, C44169502F38F3B100051412 /* SQLiteUndoTCA */, + C45245EE2F3D353800F31BB8 /* SQLiteUndo */, + C45245F02F3D353800F31BB8 /* SQLiteUndoTCA */, ); productName = UndoForMacOS; productReference = C4B1975D2F33B502001EAFC2 /* UndoForMacOS.app */; @@ -123,7 +129,7 @@ minimizedProjectReferenceProxies = 1; packageReferences = ( C4B1976F2F33B5D9001EAFC2 /* XCRemoteSwiftPackageReference "swift-composable-architecture" */, - C441694D2F38F3B100051412 /* XCLocalSwiftPackageReference "../../sqlite-undo" */, + C45245ED2F3D353800F31BB8 /* XCLocalSwiftPackageReference "../../sqlite-undo" */, ); preferredProjectObjectVersion = 77; productRefGroup = C4B197462F33B4B4001EAFC2 /* Products */; @@ -364,7 +370,7 @@ /* End XCConfigurationList section */ /* Begin XCLocalSwiftPackageReference section */ - C441694D2F38F3B100051412 /* XCLocalSwiftPackageReference "../../sqlite-undo" */ = { + C45245ED2F3D353800F31BB8 /* XCLocalSwiftPackageReference "../../sqlite-undo" */ = { isa = XCLocalSwiftPackageReference; relativePath = "../../sqlite-undo"; }; @@ -390,6 +396,14 @@ isa = XCSwiftPackageProductDependency; productName = SQLiteUndoTCA; }; + C45245EE2F3D353800F31BB8 /* SQLiteUndo */ = { + isa = XCSwiftPackageProductDependency; + productName = SQLiteUndo; + }; + C45245F02F3D353800F31BB8 /* SQLiteUndoTCA */ = { + isa = XCSwiftPackageProductDependency; + productName = SQLiteUndoTCA; + }; C4B197692F33B52B001EAFC2 /* SQLiteUndo */ = { isa = XCSwiftPackageProductDependency; productName = SQLiteUndo; diff --git a/Examples/UndoForMacOS/UndoForMacOSApp.swift b/Examples/UndoForMacOS/UndoForMacOSApp.swift index 75a0644..bc9d1c3 100644 --- a/Examples/UndoForMacOS/UndoForMacOSApp.swift +++ b/Examples/UndoForMacOS/UndoForMacOSApp.swift @@ -38,6 +38,7 @@ struct DemoFeature { case undoManager(UndoManagingAction) case addItem case addItemInBackground + case addItemWithoutTracking case addUntrackedItem case incrementCount(Int) case incrementAll @@ -75,6 +76,17 @@ struct DemoFeature { } } + case .addItemWithoutTracking: + withErrorReporting { + try withUndoDisabled { + try database.write { db in + let nextID = (try DemoItem.all.fetchAll(db).map(\.id).max() ?? 0) + 1 + try DemoItem.insert { DemoItem(id: nextID, name: "Item \(nextID)") }.execute(db) + } + } + } + return .none + case .addUntrackedItem: withErrorReporting { try undoable("Add Untracked Item") { @@ -161,25 +173,33 @@ struct DemoView: View { } .frame(minHeight: 200) - HStack { - Button("Add Item") { - store.send(.addItem) - } - .buttonStyle(.borderedProminent) - Button("Add Item (Background)") { - store.send(.addItemInBackground) - } - .buttonStyle(.bordered) - Button("Increment All") { - store.send(.incrementAll) + VStack { + HStack { + Button("Add Item") { + store.send(.addItem) + } + .buttonStyle(.borderedProminent) + Button("Increment All") { + store.send(.incrementAll) + } + .buttonStyle(.bordered) + .disabled(store.items.isEmpty) } - .buttonStyle(.bordered) - .disabled(store.items.isEmpty) - Divider() - Button("Add Untracked Item") { - store.send(.addUntrackedItem) + HStack { + Button("Add Item without tracking") { + store.send(.addItemWithoutTracking) + } + .buttonStyle(.bordered) + Button("Add Item (Background)") { + store.send(.addItemInBackground) + } + .buttonStyle(.bordered) + Button("Add Untracked Item") { + store.send(.addUntrackedItem) + } + .buttonStyle(.bordered) } - .buttonStyle(.bordered) + .fixedSize() } } .padding() diff --git a/README.md b/README.md index 9ea0889..8c3df7f 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,28 @@ [![CI](https://github.com/latentco/sqlite-undo/actions/workflows/ci.yml/badge.svg)](https://github.com/latentco/sqlite-undo/actions/workflows/ci.yml) -SQLite-based undo/redo for Swift apps using [SQLiteData](https://github.com/pointfreeco/sqlite-data). Uses database triggers to capture changes automatically using the pattern described in [Automatic Undo/Redo Using SQLite](https://www.sqlite.org/undoredo.html) +SQLite-based undo/redo for Swift apps using [SQLiteData](https://github.com/pointfreeco/sqlite-data) and [StructuredQueries](https://github.com/pointfreeco/swift-structured-queries). Uses database triggers to automatically capture reverse SQL for all changes to tracked tables, following the pattern described in [Automatic Undo/Redo Using SQLite](https://www.sqlite.org/undoredo.html). + +Changes are grouped into barriers that represent single user actions (e.g., "Set Rating", "Delete Item"). Barriers integrate with `NSUndoManager` so undo/redo works with the standard Edit menu, keyboard shortcuts, and shake-to-undo. + +Two libraries are provided: + +- **SQLiteUndo** — core undo engine, barriers, and free functions (`undoable`, `withUndoDisabled`) +- **SQLiteUndoTCA** — [ComposableArchitecture](https://github.com/pointfreeco/swift-composable-architecture) integration for `UndoManager` wiring in SwiftUI + +## Adding SQLiteUndo as a dependency + +Add the following to your `Package.swift`: + +```swift +.package(url: "https://github.com/latentco/sqlite-undo.git", from: "0.1.0"), +``` + +Then add the product to your target's dependencies: + +```swift +.product(name: "SQLiteUndo", package: "sqlite-undo"), +``` ## Setup @@ -38,6 +59,49 @@ try await undoable("Set Rating") { } ``` +### Disabling undo tracking + +Use `withUndoDisabled` for operations that shouldn't be undoable (e.g., batch imports, programmatic state rebuilds): + +```swift +try withUndoDisabled { + try database.write { db in + try Article.insert { Article(id: 1, name: "Imported") }.execute(db) + } +} +``` + +### Application triggers + +If your app has triggers that cascade writes (e.g., clearing a flag on other rows, updating derived state), they **must** include `UndoEngine.isReplaying()` in their WHEN clause: + +```swift +Article.createTemporaryTrigger( + after: .update { $0.isPrimary }, + forEachRow: { old, new in + // Clear isPrimary on all other rows + Article.where { $0.id != new.id } + .update { $0.isPrimary = false } + }, + when: { old, new in + !UndoEngine.isReplaying() + } +) +``` + +Or in raw SQL: + +```sql +CREATE TRIGGER clear_primary +AFTER UPDATE OF "isPrimary" ON "articles" +WHEN NOT "sqliteundo_isReplaying"() +BEGIN + UPDATE "articles" SET "isPrimary" = 0 WHERE "id" != NEW."id"; +END +``` + +> **Note:** The undo system uses BEFORE triggers to capture original values and records all effects of a change (including cascades) in the undo log. During undo/redo replay, each effect is replayed individually, so cascade triggers must be suppressed to avoid corrupting the restored state. Without the `isReplaying` guard, a cascade trigger would fire again during replay and overwrite values that the undo system is trying to restore. + ### With explicit barrier management ```swift diff --git a/Sources/SQLiteUndo/UndoCoordinator.swift b/Sources/SQLiteUndo/UndoCoordinator.swift index d8a4764..0b1654b 100644 --- a/Sources/SQLiteUndo/UndoCoordinator.swift +++ b/Sources/SQLiteUndo/UndoCoordinator.swift @@ -94,7 +94,7 @@ final class UndoCoordinator: Sendable { return nil } - return try database.read { db in + return try database.write { db in guard let endSeq = try db.undoLogMaxSeq(), endSeq >= openBarrier.startSeq else { let tables = registeredTables.sorted() logger.warning( @@ -110,6 +110,14 @@ final class UndoCoordinator: Sendable { return nil } + // Reconcile duplicate entries from cascading BEFORE triggers + try db.reconcileUndoLogEntries(from: openBarrier.startSeq, to: endSeq) + + // Re-read endSeq since reconciliation may have removed entries + guard let endSeq = try db.undoLogMaxSeq(), endSeq >= openBarrier.startSeq else { + return nil + } + let barrier = UndoBarrier( id: id, name: openBarrier.name, @@ -215,20 +223,4 @@ final class UndoCoordinator: Sendable { } } } - - /// Temporarily disable undo tracking. - /// - /// Use this for bulk operations, migrations, or imports where you don't - /// want individual changes tracked. - func withUndoDisabled(_ operation: () throws -> T) throws -> T { - try database.write { db in - try UndoState.find(1).update { $0.isActive = false }.execute(db) - } - defer { - try? database.write { db in - try UndoState.find(1).update { $0.isActive = true }.execute(db) - } - } - return try operation() - } } diff --git a/Sources/SQLiteUndo/UndoEngine.swift b/Sources/SQLiteUndo/UndoEngine.swift index 6dfea8a..d8f2366 100644 --- a/Sources/SQLiteUndo/UndoEngine.swift +++ b/Sources/SQLiteUndo/UndoEngine.swift @@ -19,29 +19,30 @@ private let logger = Logger(subsystem: "SQLiteUndo", category: "UndoEngine") /// $0.defaultUndoStack = .live(windowUndoManager) /// $0.defaultUndoEngine = try! UndoEngine( /// for: $0.defaultDatabase, -/// tables: ProjectItem.self, ProjectEdit.self +/// tables: Item.self, Edit.self /// ) /// } /// ``` /// /// ## Usage /// +/// Wrap database changes in ``undoable(_:operation:)-3cgh0`` to make them undoable: +/// /// ```swift -/// @Dependency(\.defaultUndoEngine) var undoEngine +/// try undoable("Set Rating") { +/// try database.write { db in +/// try Item.find(id).update { $0.rating = rating }.execute(db) +/// } +/// } +/// ``` /// -/// // Simple operation -/// let barrierId = try undoEngine.beginBarrier("Set Rating") -/// try database.write { /* make changes */ } -/// try undoEngine.endBarrier(barrierId) +/// Use ``withUndoDisabled(_:)`` for operations that shouldn't be tracked: /// -/// // With error handling -/// do { -/// let barrierId = try undoEngine.beginBarrier("Set Rating") -/// try database.write { /* make changes */ } -/// try undoEngine.endBarrier(barrierId) -/// } catch { -/// try undoEngine.cancelBarrier(barrierId) -/// throw error +/// ```swift +/// try withUndoDisabled { +/// try database.write { db in +/// try Item.insert { Item(id: 1, name: "Imported") }.execute(db) +/// } /// } /// ``` @DependencyClient @@ -65,15 +66,41 @@ public struct UndoEngine: Sendable { /// /// - Parameter id: The barrier ID from `beginBarrier` public var cancelBarrier: @Sendable (_ id: UUID) throws -> Void +} + +/// Whether undo tracking is active. Default true; set false inside `withUndoDisabled`. +@TaskLocal var _undoIsActive = true + +/// Whether the undo system is replaying entries (undo/redo in progress). +@TaskLocal var _undoIsReplaying = false + +@DatabaseFunction("sqliteundo_isActive") +func undoIsActiveFunction() -> Bool { + _undoIsActive +} - /// Temporarily disable undo tracking for an operation. +@DatabaseFunction("sqliteundo_isReplaying") +func undoIsReplayingFunction() -> Bool { + _undoIsReplaying +} + +extension UndoEngine { + /// A SQL expression that evaluates to true when the undo system is replaying entries. /// - /// Use for migrations, bulk imports, or other operations that shouldn't - /// be individually undoable. + /// Use `!UndoEngine.isReplaying()` in application trigger WHEN clauses to suppress + /// cascading writes during undo/redo replay: /// - /// - Parameter operation: The operation to perform without tracking - public var withUndoDisabled: @Sendable (_ operation: () throws -> Void) throws -> Void = { - try $0() + /// ```swift + /// Table.createTemporaryTrigger( + /// after: .update { $0.isSelected } + /// forEachRow: { old, new in ... } + /// when: { old, new in + /// someCondition.and(!UndoEngine.isReplaying()) + /// } + /// ) + /// ``` + public static func isReplaying() -> some QueryExpression { + $undoIsReplayingFunction() } } @@ -104,7 +131,10 @@ extension UndoEngine { let registeredNames = Set(tables.map { $0.tableName }) let untrackedNames = Set(untracked.map { $0.tableName }) self = .make( - database: database, registeredTables: registeredNames, untrackedTables: untrackedNames) + database: database, + registeredTables: registeredNames, + untrackedTables: untrackedNames + ) } /// Create an UndoEngine for a database with the specified tracked tables. @@ -126,7 +156,10 @@ extension UndoEngine { let registeredNames = Set(tables.map { $0.tableName }) let untrackedNames = Set(untracked.map { $0.tableName }) self = .make( - database: database, registeredTables: registeredNames, untrackedTables: untrackedNames) + database: database, + registeredTables: registeredNames, + untrackedTables: untrackedNames + ) } private static func install(for database: any DatabaseWriter, tables: [any Table.Type]) @@ -191,9 +224,6 @@ extension UndoEngine: DependencyKey { }, cancelBarrier: { id in try coordinator.cancelBarrier(id) - }, - withUndoDisabled: { operation in - try coordinator.withUndoDisabled(operation) } ) } diff --git a/Sources/SQLiteUndo/UndoOperations.swift b/Sources/SQLiteUndo/UndoOperations.swift index 2b866c2..17c6933 100644 --- a/Sources/SQLiteUndo/UndoOperations.swift +++ b/Sources/SQLiteUndo/UndoOperations.swift @@ -60,16 +60,23 @@ extension Database { // Get current max seq before executing (new entries will be added after this) let seqBefore = try undoLogMaxSeq() ?? 0 - // Execute with triggers ENABLED - this captures the reverse SQL - for entry in entries { - logger.trace("Executing SQL: \(entry.sql)") - try #sql("\(raw: entry.sql)").execute(self) + // Execute with triggers ENABLED - this captures the reverse SQL. + // Set isReplaying so app-level triggers suppress cascading writes. + // The undo log already contains all effects (including cascades), + // so replaying them individually is sufficient. + try $_undoIsReplaying.withValue(true) { + for entry in entries { + logger.trace("Executing SQL: \(entry.sql)") + try #sql("\(raw: entry.sql)").execute(self) + } } // Get new seq range for captured entries let seqAfter = try undoLogMaxSeq() ?? seqBefore if seqAfter > seqBefore { let newRange = UndoCoordinator.SeqRange(startSeq: seqBefore + 1, endSeq: seqAfter) + // Reconcile duplicates from BEFORE triggers firing during replay + try reconcileUndoLogEntries(from: newRange.startSeq, to: newRange.endSeq) logger.debug("New seq range: \(newRange.startSeq)...\(newRange.endSeq)") return newRange } @@ -101,4 +108,80 @@ extension Database { .fetchAll(self) return Set(tableNames) } + + /// Reconcile undolog entries in a seq range to remove duplicates. + /// + /// BEFORE triggers and replay can produce multiple entries for the same row within + /// a single barrier. This keeps only the first entry (lowest seq = true original) + /// per (tableName, trackedRowid) group, with special handling: + /// - INSERT (DELETE-reverse) + DELETE (INSERT-reverse) of same row → remove both (no-op) + /// - INSERT (DELETE-reverse) + UPDATE → keep just the DELETE-reverse (undo = delete) + /// - Multiple UPDATEs → keep first (true original values) + func reconcileUndoLogEntries(from startSeq: Int, to endSeq: Int) throws { + // Fast path: check if any duplicates exist before fetching all entries + let hasDuplicates = try #sql( + """ + SELECT 1 FROM undolog + WHERE seq >= \(startSeq) AND seq <= \(endSeq) AND trackedRowid != 0 + GROUP BY tableName, trackedRowid + HAVING COUNT(*) > 1 + LIMIT 1 + """, + as: Int.self + ).fetchOne(self) + + guard hasDuplicates != nil else { return } + + let entries = + try UndoLogEntry + .where { $0.seq >= startSeq && $0.seq <= endSeq } + .order { $0.seq.asc() } + .fetchAll(self) + + var groups: [String: [UndoLogEntry]] = [:] + for entry in entries { + guard entry.trackedRowid != 0 else { continue } + let key = "\(entry.tableName):\(entry.trackedRowid)" + groups[key, default: []].append(entry) + } + + var seqsToDelete: [Int] = [] + + for (_, group) in groups { + guard group.count > 1 else { continue } + + let first = group[0] + let last = group[group.count - 1] + + let firstIsDeleteReverse = first.sql.hasPrefix("DELETE FROM") + let lastIsInsertReverse = last.sql.hasPrefix("INSERT INTO") + + if firstIsDeleteReverse && lastIsInsertReverse { + // INSERT then DELETE in same barrier → no-op, remove all + for entry in group { + seqsToDelete.append(entry.seq) + } + } else if firstIsDeleteReverse { + // INSERT then UPDATEs → keep DELETE-reverse (undo = delete), remove rest + for entry in group.dropFirst() { + seqsToDelete.append(entry.seq) + } + } else { + // First is UPDATE-reverse or INSERT-reverse (pre-existing row). + // Remove only subsequent UPDATE-reverses (cascade duplicates). + // Keep INSERT-reverses (from DELETE) since replay needs them for row re-creation. + for entry in group.dropFirst() { + if entry.sql.hasPrefix("UPDATE") { + seqsToDelete.append(entry.seq) + } + } + } + } + + if !seqsToDelete.isEmpty { + let placeholders = seqsToDelete.map { "\($0)" }.joined(separator: ",") + try #sql("DELETE FROM undolog WHERE seq IN (\(raw: placeholders))").execute(self) + logger.debug("Reconciled: removed \(seqsToDelete.count) duplicate undolog entries") + } + } } diff --git a/Sources/SQLiteUndo/UndoSchema.swift b/Sources/SQLiteUndo/UndoSchema.swift index 2a2ec02..4c8237e 100644 --- a/Sources/SQLiteUndo/UndoSchema.swift +++ b/Sources/SQLiteUndo/UndoSchema.swift @@ -11,22 +11,12 @@ struct UndoLogEntry: Sendable { var seq: Int /// The name of the table that was modified. var tableName: String + /// The rowid of the tracked row, for deduplication during reconciliation. + var trackedRowid: Int = 0 /// The SQL statement to reverse the change. var sql: String } -/// Singleton row tracking whether undo tracking is active. -/// -/// This table always contains exactly one row (id=1). -/// Stack management is handled by NSUndoManager, not stored in the database. -@Table("undoState") -struct UndoState: Sendable { - /// Always 1 (singleton constraint) - var id: Int = 1 - /// Whether undo tracking triggers are active. - var isActive: Bool = true -} - extension DatabaseWriter { func installUndoSystem() throws { try write { db in @@ -38,21 +28,14 @@ extension DatabaseWriter { CREATE TABLE undolog ( seq INTEGER PRIMARY KEY AUTOINCREMENT, tableName TEXT NOT NULL, + trackedRowid INTEGER NOT NULL DEFAULT 0, sql TEXT NOT NULL ) """ ).execute(db) - try #sql( - """ - CREATE TABLE undoState ( - id INTEGER PRIMARY KEY CHECK (id = 1), - isActive INTEGER NOT NULL DEFAULT 1 - ) - """ - ).execute(db) - - try #sql("INSERT INTO undoState (id, isActive) VALUES (1, 1)").execute(db) + db.add(function: $undoIsActiveFunction) + db.add(function: $undoIsReplayingFunction) } } } diff --git a/Sources/SQLiteUndo/UndoTracked.swift b/Sources/SQLiteUndo/UndoTracked.swift index 386ae05..f08806f 100644 --- a/Sources/SQLiteUndo/UndoTracked.swift +++ b/Sources/SQLiteUndo/UndoTracked.swift @@ -5,8 +5,8 @@ extension StructuredQueries.Table { /// Generate and install undo triggers for this table. /// /// Creates three TEMPORARY triggers (INSERT, UPDATE, DELETE) that record - /// reverse SQL into the undolog table. All triggers check the `isActive` - /// flag in undoState before recording. + /// reverse SQL into the undolog table. All triggers call the `sqliteundo_isActive()` + /// database function before recording. public static func installUndoTriggers(_ db: Database) throws { let triggers = generateUndoTriggers() for sql in triggers { @@ -31,15 +31,16 @@ extension StructuredQueries.Table { """ CREATE TEMPORARY TRIGGER IF NOT EXISTS _undo_\(table)_insert AFTER INSERT ON "\(table)" - WHEN (SELECT isActive FROM undoState WHERE id = 1) + WHEN "sqliteundo_isActive"() BEGIN - INSERT INTO undolog(tableName, sql) - VALUES('\(table)', 'DELETE FROM "\(table)" WHERE rowid='||NEW.rowid); + INSERT INTO undolog(tableName, trackedRowid, sql) + VALUES('\(table)', NEW.rowid, 'DELETE FROM "\(table)" WHERE rowid='||NEW.rowid); END """ } /// UPDATE trigger: Records an UPDATE statement with old values. + /// Uses BEFORE timing to capture true original values before cascading triggers fire. private static func generateUpdateTrigger(table: String, columns: [String]) -> String { // Build: col1='||quote(OLD.col1)||',col2='||quote(OLD.col2)||'... let setClauses = columns.map { col in @@ -48,16 +49,17 @@ extension StructuredQueries.Table { return """ CREATE TEMPORARY TRIGGER IF NOT EXISTS _undo_\(table)_update - AFTER UPDATE ON "\(table)" - WHEN (SELECT isActive FROM undoState WHERE id = 1) + BEFORE UPDATE ON "\(table)" + WHEN "sqliteundo_isActive"() BEGIN - INSERT INTO undolog(tableName, sql) - VALUES('\(table)', 'UPDATE "\(table)" SET '||\(setClauses)||' WHERE rowid='||OLD.rowid); + INSERT INTO undolog(tableName, trackedRowid, sql) + VALUES('\(table)', OLD.rowid, 'UPDATE "\(table)" SET '||\(setClauses)||' WHERE rowid='||OLD.rowid); END """ } /// DELETE trigger: Records an INSERT statement with old values. + /// Uses BEFORE timing to capture true original values before cascading triggers fire. private static func generateDeleteTrigger(table: String, columns: [String]) -> String { // Build column list: "col1","col2",... let columnList = columns.map { "\"\($0)\"" }.joined(separator: ",") @@ -69,11 +71,11 @@ extension StructuredQueries.Table { return """ CREATE TEMPORARY TRIGGER IF NOT EXISTS _undo_\(table)_delete - AFTER DELETE ON "\(table)" - WHEN (SELECT isActive FROM undoState WHERE id = 1) + BEFORE DELETE ON "\(table)" + WHEN "sqliteundo_isActive"() BEGIN - INSERT INTO undolog(tableName, sql) - VALUES('\(table)', 'INSERT INTO "\(table)"(rowid,\(columnList)) VALUES('||OLD.rowid||','||\(valueExpressions)||')'); + INSERT INTO undolog(tableName, trackedRowid, sql) + VALUES('\(table)', OLD.rowid, 'INSERT INTO "\(table)"(rowid,\(columnList)) VALUES('||OLD.rowid||','||\(valueExpressions)||')'); END """ } diff --git a/Sources/SQLiteUndo/Undoable.swift b/Sources/SQLiteUndo/Undoable.swift index b1432f6..4e3766b 100644 --- a/Sources/SQLiteUndo/Undoable.swift +++ b/Sources/SQLiteUndo/Undoable.swift @@ -1,28 +1,26 @@ import Dependencies import Foundation -/// Execute an async operation within an undoable barrier. -/// -/// Use this for database operations that should be undoable: +/// Execute an operation within an undoable barrier. /// /// ```swift -/// try await undoable("Set Rating") { -/// try await database.write { db in -/// try ProjectItem.find(id).update { $0.rating = rating }.execute(db) +/// try undoable("Set Rating") { +/// try database.write { db in +/// try Item.find(id).update { $0.rating = rating }.execute(db) /// } /// } /// ``` /// /// The barrier is automatically cancelled if the operation throws. -public func undoable( +public func undoable( _ actionName: String, - operation: @Sendable () async throws -> T -) async throws -> T { + operation: () throws -> T +) throws -> T { @Dependency(\.defaultUndoEngine) var undoEngine let barrierId = try undoEngine.beginBarrier(actionName) do { - let result = try await operation() + let result = try operation() try undoEngine.endBarrier(barrierId) return result } catch { @@ -31,28 +29,26 @@ public func undoable( } } -/// Execute a synchronous operation within an undoable barrier. -/// -/// Use this for simple, inline undoable operations: +/// Execute an async operation within an undoable barrier. /// /// ```swift -/// undoable("Set Rating") { -/// try database.write { db in -/// try ProjectItem.find(id).update { $0.rating = rating }.execute(db) +/// try await undoable("Set Rating") { +/// try await database.write { db in +/// try Item.find(id).update { $0.rating = rating }.execute(db) /// } /// } /// ``` /// /// The barrier is automatically cancelled if the operation throws. -public func undoable( +public func undoable( _ actionName: String, - operation: () throws -> T -) throws -> T { + operation: @Sendable () async throws -> T +) async throws -> T { @Dependency(\.defaultUndoEngine) var undoEngine let barrierId = try undoEngine.beginBarrier(actionName) do { - let result = try operation() + let result = try await operation() try undoEngine.endBarrier(barrierId) return result } catch { @@ -60,3 +56,41 @@ public func undoable( throw error } } + +/// Execute an operation with undo tracking disabled. +/// +/// Changes made within this block are not captured in the undo log. +/// Use this for programmatic operations that shouldn't be undoable +/// (e.g., initial app state, batch imports). +/// +/// ```swift +/// try withUndoDisabled { +/// try database.write { db in +/// try Item.insert { Item(id: 1, name: "Imported") }.execute(db) +/// } +/// } +/// ``` +public func withUndoDisabled(_ operation: () throws -> T) throws -> T { + try $_undoIsActive.withValue(false) { + try operation() + } +} + +/// Execute an async operation with undo tracking disabled. +/// +/// Changes made within this block are not captured in the undo log. +/// Use this for programmatic operations that shouldn't be undoable +/// (e.g., initial app state, batch imports). +/// +/// ```swift +/// try await withUndoDisabled { +/// try await database.write { db in +/// try Item.insert { Item(id: 1, name: "Imported") }.execute(db) +/// } +/// } +/// ``` +public func withUndoDisabled(_ operation: @Sendable () async throws -> T) async throws -> T { + try await $_undoIsActive.withValue(false) { + try await operation() + } +} diff --git a/Tests/SQLiteUndoTests/CascadeTriggerTests.swift b/Tests/SQLiteUndoTests/CascadeTriggerTests.swift new file mode 100644 index 0000000..ed46cf3 --- /dev/null +++ b/Tests/SQLiteUndoTests/CascadeTriggerTests.swift @@ -0,0 +1,305 @@ +import Foundation +import SQLiteData +import Testing + +@testable import SQLiteUndo + +@Suite(.serialized) +struct CascadeTriggerTests { + + @Suite + struct SameRowCascade { + + @Test + func undoRevertsOriginalValues() throws { + // App trigger: AFTER UPDATE OF value → sets flag=1 on the same row + // Undo should restore both value and flag to their originals + let (database, engine) = try makeCascadeDatabase(trigger: .sameRowFlag) + + try withUndoDisabled { + try database.write { db in + try db.execute(sql: """ + INSERT INTO "cascadeItems" ("id", "value", "flag") VALUES (1, 'original', 0) + """) + } + } + + let barrierId = try engine.beginBarrier("Update Value") + try database.write { db in + try db.execute(sql: """ + UPDATE "cascadeItems" SET "value" = 'changed' WHERE "id" = 1 + """) + } + let barrier = try engine.endBarrier(barrierId)! + + // Verify the cascade fired: flag should be 1 + try database.read { db in + let item = try CascadeItem.find(1).fetchOne(db)! + #expect(item.value == "changed") + #expect(item.flag == 1) + } + + try engine.performUndo(barrier: barrier) + + // After undo: both value and flag should be restored to originals + try database.read { db in + let item = try CascadeItem.find(1).fetchOne(db)! + #expect(item.value == "original") + #expect(item.flag == 0) + } + } + } + + @Suite + struct CrossRowCascade { + + @Test + func undoRevertsBothRows() throws { + // App trigger: AFTER UPDATE OF value ON cascadeItems + // → UPDATE cascadeItems SET flag=1 WHERE id != NEW.id + // Updating row A cascades to set flag=1 on row B. + // Undo should revert both. + let (database, engine) = try makeCascadeDatabase(trigger: .crossRowFlag) + + try withUndoDisabled { + try database.write { db in + try db.execute(sql: """ + INSERT INTO "cascadeItems" ("id", "value", "flag") VALUES (1, 'A', 0); + INSERT INTO "cascadeItems" ("id", "value", "flag") VALUES (2, 'B', 0); + """) + } + } + + let barrierId = try engine.beginBarrier("Update A") + try database.write { db in + try db.execute(sql: """ + UPDATE "cascadeItems" SET "value" = 'A-changed' WHERE "id" = 1 + """) + } + let barrier = try engine.endBarrier(barrierId)! + + // Verify cascade: row A updated, row B got flag=1 + try database.read { db in + let a = try CascadeItem.find(1).fetchOne(db)! + #expect(a.value == "A-changed") + let b = try CascadeItem.find(2).fetchOne(db)! + #expect(b.flag == 1) + } + + try engine.performUndo(barrier: barrier) + + // After undo: both rows should be restored + try database.read { db in + let a = try CascadeItem.find(1).fetchOne(db)! + #expect(a.value == "A") + #expect(a.flag == 0) + let b = try CascadeItem.find(2).fetchOne(db)! + #expect(b.value == "B") + #expect(b.flag == 0) + } + } + } + + @Suite + struct EdgeCases { + + @Test + func insertThenDeleteIsNoOp() throws { + let (database, engine) = try makeCascadeDatabase(trigger: .none) + + let barrierId = try engine.beginBarrier("Insert Then Delete") + try database.write { db in + try db.execute(sql: """ + INSERT INTO "cascadeItems" ("id", "value", "flag") VALUES (1, 'temp', 0) + """) + try db.execute(sql: """ + DELETE FROM "cascadeItems" WHERE "id" = 1 + """) + } + let barrier = try engine.endBarrier(barrierId) + + // The barrier may be nil (if reconciliation removes all entries) + // or non-nil but undo should be a no-op + if let barrier { + try engine.performUndo(barrier: barrier) + } + + try database.read { db in + let count = try CascadeItem.all.fetchCount(db) + #expect(count == 0) + } + } + + @Test + func insertThenUpdateUndoDeletesRow() throws { + let (database, engine) = try makeCascadeDatabase(trigger: .none) + + let barrierId = try engine.beginBarrier("Insert Then Update") + try database.write { db in + try db.execute(sql: """ + INSERT INTO "cascadeItems" ("id", "value", "flag") VALUES (1, 'initial', 0) + """) + try db.execute(sql: """ + UPDATE "cascadeItems" SET "value" = 'modified' WHERE "id" = 1 + """) + } + let barrier = try engine.endBarrier(barrierId)! + + try database.read { db in + let item = try CascadeItem.find(1).fetchOne(db)! + #expect(item.value == "modified") + } + + // Undo should delete the row (reverse of the INSERT) + try engine.performUndo(barrier: barrier) + + try database.read { db in + let count = try CascadeItem.all.fetchCount(db) + #expect(count == 0) + } + } + + @Test + func updateThenDeleteUndoReInsertsOriginal() throws { + let (database, engine) = try makeCascadeDatabase(trigger: .none) + + try withUndoDisabled { + try database.write { db in + try db.execute(sql: """ + INSERT INTO "cascadeItems" ("id", "value", "flag") VALUES (1, 'original', 0) + """) + } + } + + let barrierId = try engine.beginBarrier("Update Then Delete") + try database.write { db in + try db.execute(sql: """ + UPDATE "cascadeItems" SET "value" = 'modified' WHERE "id" = 1 + """) + try db.execute(sql: """ + DELETE FROM "cascadeItems" WHERE "id" = 1 + """) + } + let barrier = try engine.endBarrier(barrierId)! + + try database.read { db in + let count = try CascadeItem.all.fetchCount(db) + #expect(count == 0) + } + + // Undo should re-insert with original pre-update values + try engine.performUndo(barrier: barrier) + + try database.read { db in + let item = try CascadeItem.find(1).fetchOne(db)! + #expect(item.value == "original") + #expect(item.flag == 0) + } + } + + @Test + func undoRedoRoundTrip() throws { + let (database, engine) = try makeCascadeDatabase(trigger: .sameRowFlag) + + try withUndoDisabled { + try database.write { db in + try db.execute(sql: """ + INSERT INTO "cascadeItems" ("id", "value", "flag") VALUES (1, 'original', 0) + """) + } + } + + let barrierId = try engine.beginBarrier("Update") + try database.write { db in + try db.execute(sql: """ + UPDATE "cascadeItems" SET "value" = 'changed' WHERE "id" = 1 + """) + } + let barrier = try engine.endBarrier(barrierId)! + + // Undo + try engine.performUndo(barrier: barrier) + try database.read { db in + let item = try CascadeItem.find(1).fetchOne(db)! + #expect(item.value == "original") + #expect(item.flag == 0) + } + + // Redo + try engine.performRedo(barrier: barrier) + try database.read { db in + let item = try CascadeItem.find(1).fetchOne(db)! + #expect(item.value == "changed") + #expect(item.flag == 1) + } + + // Undo again + try engine.performUndo(barrier: barrier) + try database.read { db in + let item = try CascadeItem.find(1).fetchOne(db)! + #expect(item.value == "original") + #expect(item.flag == 0) + } + } + } +} + +@Table("cascadeItems") +private struct CascadeItem: Identifiable { + @Column(primaryKey: true) var id: Int + var value: String = "" + var flag: Int = 0 +} + +private enum CascadeTrigger { + case none + case sameRowFlag + case crossRowFlag +} + +private func makeCascadeDatabase( + trigger: CascadeTrigger +) throws -> (any DatabaseWriter, UndoCoordinator) { + let database = try DatabaseQueue(configuration: Configuration()) + + try database.write { db in + try db.execute(sql: """ + CREATE TABLE "cascadeItems" ( + "id" INTEGER PRIMARY KEY, + "value" TEXT NOT NULL DEFAULT '', + "flag" INTEGER NOT NULL DEFAULT 0 + ) + """) + } + + try database.installUndoSystem() + try database.write { db in + for sql in CascadeItem.generateUndoTriggers() { + try db.execute(sql: sql) + } + } + + if trigger != .none { + try database.write { db in + let whereClause: String + switch trigger { + case .none: + fatalError() + case .sameRowFlag: + whereClause = "rowid = NEW.rowid" + case .crossRowFlag: + whereClause = "id != NEW.id" + } + try db.execute(sql: """ + CREATE TEMPORARY TRIGGER cascade_trigger + AFTER UPDATE OF "value" ON "cascadeItems" + WHEN NOT "sqliteundo_isReplaying"() + BEGIN + UPDATE "cascadeItems" SET "flag" = 1 WHERE \(whereClause); + END + """) + } + } + + return (database, UndoCoordinator(database: database)) +} diff --git a/Tests/SQLiteUndoTests/UndoEngineTests.swift b/Tests/SQLiteUndoTests/UndoEngineTests.swift index c650639..91d6916 100644 --- a/Tests/SQLiteUndoTests/UndoEngineTests.swift +++ b/Tests/SQLiteUndoTests/UndoEngineTests.swift @@ -23,26 +23,26 @@ enum UndoEngineTests { """ CREATE TEMPORARY TRIGGER IF NOT EXISTS _undo_testRecords_insert AFTER INSERT ON "testRecords" - WHEN (SELECT isActive FROM undoState WHERE id = 1) + WHEN "sqliteundo_isActive"() BEGIN - INSERT INTO undolog(tableName, sql) - VALUES('testRecords', 'DELETE FROM "testRecords" WHERE rowid='||NEW.rowid); + INSERT INTO undolog(tableName, trackedRowid, sql) + VALUES('testRecords', NEW.rowid, 'DELETE FROM "testRecords" WHERE rowid='||NEW.rowid); END CREATE TEMPORARY TRIGGER IF NOT EXISTS _undo_testRecords_update - AFTER UPDATE ON "testRecords" - WHEN (SELECT isActive FROM undoState WHERE id = 1) + BEFORE UPDATE ON "testRecords" + WHEN "sqliteundo_isActive"() BEGIN - INSERT INTO undolog(tableName, sql) - VALUES('testRecords', 'UPDATE "testRecords" SET '||'"id"='||quote(OLD."id")||','||'"name"='||quote(OLD."name")||','||'"value"='||quote(OLD."value")||' WHERE rowid='||OLD.rowid); + INSERT INTO undolog(tableName, trackedRowid, sql) + VALUES('testRecords', OLD.rowid, 'UPDATE "testRecords" SET '||'"id"='||quote(OLD."id")||','||'"name"='||quote(OLD."name")||','||'"value"='||quote(OLD."value")||' WHERE rowid='||OLD.rowid); END CREATE TEMPORARY TRIGGER IF NOT EXISTS _undo_testRecords_delete - AFTER DELETE ON "testRecords" - WHEN (SELECT isActive FROM undoState WHERE id = 1) + BEFORE DELETE ON "testRecords" + WHEN "sqliteundo_isActive"() BEGIN - INSERT INTO undolog(tableName, sql) - VALUES('testRecords', 'INSERT INTO "testRecords"(rowid,"id","name","value") VALUES('||OLD.rowid||','||quote(OLD."id")||','||quote(OLD."name")||','||quote(OLD."value")||')'); + INSERT INTO undolog(tableName, trackedRowid, sql) + VALUES('testRecords', OLD.rowid, 'INSERT INTO "testRecords"(rowid,"id","name","value") VALUES('||OLD.rowid||','||quote(OLD."id")||','||quote(OLD."name")||','||quote(OLD."value")||')'); END """ } @@ -129,7 +129,7 @@ enum UndoEngineTests { func undoUpdate() throws { let (database, engine) = try makeTestDatabaseWithUndo() - try engine.withUndoDisabled { + try withUndoDisabled { try database.write { db in try TestRecord.insert { TestRecord(id: 1, name: "Original", value: 10) }.execute(db) } @@ -163,7 +163,7 @@ enum UndoEngineTests { func undoDelete() throws { let (database, engine) = try makeTestDatabaseWithUndo() - try engine.withUndoDisabled { + try withUndoDisabled { try database.write { db in try TestRecord.insert { TestRecord(id: 1, name: "ToDelete", value: 42) }.execute(db) } @@ -193,7 +193,7 @@ enum UndoEngineTests { func redo() throws { let (database, engine) = try makeTestDatabaseWithUndo() - try engine.withUndoDisabled { + try withUndoDisabled { try database.write { db in try TestRecord.insert { TestRecord(id: 1, name: "Test", value: nil) }.execute(db) } @@ -247,13 +247,68 @@ enum UndoEngineTests { } @Suite - struct DisabledTrackingTests { + struct ReplayStateTests { @Test - func withUndoDisabled() throws { + func isReplayingTrueDuringUndo() throws { let (database, engine) = try makeTestDatabaseWithUndo() - try engine.withUndoDisabled { + // Create an audit table and a trigger that only fires when NOT replaying + try database.write { db in + try db.execute(sql: """ + CREATE TABLE "auditLog" ("id" INTEGER PRIMARY KEY AUTOINCREMENT, "action" TEXT NOT NULL) + """) + try db.execute(sql: """ + CREATE TEMPORARY TRIGGER audit_insert + AFTER INSERT ON "testRecords" + WHEN NOT "sqliteundo_isReplaying"() + BEGIN + INSERT INTO "auditLog"("action") VALUES('insert ' || NEW."name"); + END + """) + try db.execute(sql: """ + CREATE TEMPORARY TRIGGER audit_delete + AFTER DELETE ON "testRecords" + WHEN NOT "sqliteundo_isReplaying"() + BEGIN + INSERT INTO "auditLog"("action") VALUES('delete ' || OLD."name"); + END + """) + } + + // Normal insert — trigger should fire + let barrierId = try engine.beginBarrier("Insert") + try database.write { db in + try TestRecord.insert { TestRecord(id: 1, name: "Alice") }.execute(db) + } + let barrier = try engine.endBarrier(barrierId)! + + try database.read { db in + let actions = try String.fetchAll(db, sql: "SELECT action FROM auditLog ORDER BY id") + #expect(actions == ["insert Alice"]) + } + + // Undo — trigger should NOT fire (isReplaying is true) + try engine.performUndo(barrier: barrier) + + try database.read { db in + let count = try TestRecord.all.fetchCount(db) + #expect(count == 0, "Row should be deleted by undo") + + let actions = try String.fetchAll(db, sql: "SELECT action FROM auditLog ORDER BY id") + #expect(actions == ["insert Alice"], "No new audit entry during replay") + } + } + } + + @Suite + struct DisabledTrackingTests { + + @Test + func disablesUndoTracking() throws { + let (database, _) = try makeTestDatabaseWithUndo() + + try withUndoDisabled { try database.write { db in try TestRecord.insert { TestRecord(id: 1, name: "Untracked") }.execute(db) } @@ -340,7 +395,7 @@ enum UndoEngineTests { @Dependency(\.defaultDatabase) var database @Dependency(\.defaultUndoEngine) var undoEngine - try undoEngine.withUndoDisabled { + try withUndoDisabled { try database.write { db in try TestRecord.insert { TestRecord(id: 1, name: "Original") }.execute(db) } @@ -414,7 +469,8 @@ enum UndoEngineTests { #expect(try database.read { db in try TestRecord.all.fetchCount(db) } == 1) #expect( try database.read { db in try TestRecord.find(1).fetchOne(db) } != nil, - "Item 1 should be back after first redo") + "Item 1 should be back after first redo" + ) // Redo should bring back item 2 #expect(testUndoManager.redoActionName == "Create Item 2") @@ -422,7 +478,8 @@ enum UndoEngineTests { #expect(try database.read { db in try TestRecord.all.fetchCount(db) } == 2) #expect( try database.read { db in try TestRecord.find(2).fetchOne(db) } != nil, - "Item 2 should be back after second redo") + "Item 2 should be back after second redo" + ) } }