diff --git a/tiledb/core.cc b/tiledb/core.cc index 98c0320cba..05887b3004 100644 --- a/tiledb/core.cc +++ b/tiledb/core.cc @@ -1524,6 +1524,14 @@ class PyQuery { for (auto& buffer_name : buffers_order_) { BufferInfo& buffer_info = buffers_.at(buffer_name); + // Convert validity to bitmap BEFORE creating BufferHolder + int64_t null_count = 0; + if (is_nullable(buffer_name)) { + null_count = count_zeros(buffer_info.validity); + buffer_info.validity = uint8_bool_to_uint8_bitmap( + buffer_info.validity); + } + auto buffer_holder = new BufferHolder( buffer_info.data, buffer_info.validity, buffer_info.offsets); @@ -1538,11 +1546,7 @@ class PyQuery { buffer_holder); if (is_nullable(buffer_name)) { - // count zeros before converting to bitmap - c_pa_array.null_count = count_zeros(buffer_info.validity); - // convert to bitmap - buffer_info.validity = uint8_bool_to_uint8_bitmap( - buffer_info.validity); + c_pa_array.null_count = null_count; c_pa_array.buffers[0] = buffer_info.validity.data(); c_pa_array.n_buffers = is_var(buffer_name) ? 3 : 2; c_pa_schema.flags |= ARROW_FLAG_NULLABLE; diff --git a/tiledb/tests/test_core.py b/tiledb/tests/test_core.py index f25ff466b7..b236a0f879 100644 --- a/tiledb/tests/test_core.py +++ b/tiledb/tests/test_core.py @@ -2,6 +2,7 @@ import random import numpy as np +import pytest from numpy.testing import assert_array_equal import tiledb @@ -157,3 +158,32 @@ def test_import_buffer(self): self.assertTrue("foo" in r) self.assertTrue("str" not in r) del q + + def test_nullable_arrow_buffer(self): + # BufferHolder must hold reference to converted bitmap, not original. + # Corrupted validity buffer causes wrong null positions in .to_pandas(). + pytest.importorskip("pandas") + pyarrow = pytest.importorskip("pyarrow") + + def _read_arrow(uri): + with tiledb.open(uri, "r") as A: + q = core.PyQuery(A.ctx, A, ("a",), (), 0, True) + sub = tiledb.Subarray(A) + sub.add_dim_range(0, (0, 4)) + q.set_subarray(sub) + q.submit() + return q._buffers_to_pa_table() + + uri = self.path("test_nullable_arrow_buffer") + dom = tiledb.Domain(tiledb.Dim("d", domain=(0, 4), tile=1, dtype=np.uint64)) + attr = tiledb.Attr("a", dtype="ascii", var=True, nullable=True) + tiledb.Array.create( + uri, tiledb.ArraySchema(domain=dom, attrs=[attr], sparse=True) + ) + + with tiledb.open(uri, "w") as A: + A[np.arange(5)] = {"a": pyarrow.array(["x", "y", None, None, ""])} + + df = _read_arrow(uri).to_pandas() + + assert df["a"].isna().tolist() == [False, False, True, True, False]