diff --git a/lib/sqlite3.ml b/lib/sqlite3.ml index 500cdbc..ae7b5d5 100644 --- a/lib/sqlite3.ml +++ b/lib/sqlite3.ml @@ -343,6 +343,9 @@ external data_count : stmt -> (int[@untagged]) external column_count : stmt -> (int[@untagged]) = "caml_sqlite3_column_count_bc" "caml_sqlite3_column_count" +external column_is_null : stmt -> (int[@untagged]) -> bool + = "caml_sqlite3_column_is_null_bc" "caml_sqlite3_column_is_null" + external column_blob : stmt -> (int[@untagged]) -> string = "caml_sqlite3_column_blob_bc" "caml_sqlite3_column_blob" diff --git a/lib/sqlite3.mli b/lib/sqlite3.mli index e3776e9..ccc4ef6 100644 --- a/lib/sqlite3.mli +++ b/lib/sqlite3.mli @@ -495,6 +495,15 @@ val column_count : stmt -> int @raise SqliteError if the statement is invalid. *) +val column_is_null : stmt -> int -> bool +(** [column_is_null stmt n] + @return + [true] if the data in column [n] of the result of the last step of + statement [stmt] is NULL, [false] otherwise. + + @raise RangeError if [n] is out of range. + @raise SqliteError if the statement is invalid. *) + val column : stmt -> int -> Data.t (** [column stmt n] @return diff --git a/lib/sqlite3_stubs.c b/lib/sqlite3_stubs.c index 3d0bab7..024691c 100644 --- a/lib/sqlite3_stubs.c +++ b/lib/sqlite3_stubs.c @@ -1137,6 +1137,18 @@ CAMLprim value caml_sqlite3_column_count_bc(value v_stmt) { return Val_int(caml_sqlite3_column_count(v_stmt)); } +/* column_is_null */ + +CAMLprim value caml_sqlite3_column_is_null(value v_stmt, intnat pos) { + sqlite3_stmt *stmt = safe_get_stmtw("column_is_null", v_stmt)->stmt; + range_check(pos, sqlite3_column_count(stmt)); + return Val_bool(sqlite3_column_type(stmt, pos) == SQLITE_NULL); +} + +CAMLprim value caml_sqlite3_column_is_null_bc(value v_stmt, value v_pos) { + return caml_sqlite3_column_is_null(v_stmt, Int_val(v_pos)); +} + /* column_blob */ CAMLprim value caml_sqlite3_column_blob(value v_stmt, intnat pos) { diff --git a/test/test_values.ml b/test/test_values.ml index ea0d225..85a8813 100644 --- a/test/test_values.ml +++ b/test/test_values.ml @@ -65,7 +65,8 @@ let%test "test_values" = assert (column_int32 select_stmt 1 = 0l); assert (column_int64 select_stmt 2 = 0L); assert (column_double select_stmt 3 = 0.0); - assert (column_bool select_stmt 4 = false)); + assert (column_bool select_stmt 4 = false); + assert (column_is_null select_stmt 4)); (* Clean up *) ignore (finalize insert_stmt);