Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/ArrowTypes/src/ArrowTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ export ArrowKind,
toarrow,
arrowname,
fromarrow,
ToArrow
ToArrow,
registertype!

"""
ArrowTypes.ArrowKind(T)
Expand Down Expand Up @@ -285,6 +286,36 @@ arrowname(::Type{IPv6}) = IPV6_SYMBOL
# Map the IPv6 arrow name back to the Julia `IPv6` type.
function JuliaType(::Val{IPV6_SYMBOL})
    return IPv6
end

# Rebuild an `IPv6` from its 16-byte tuple form by converting the bytes to a
# `UInt128` with `_cast` and passing that to the `IPv6` constructor.
function fromarrow(::Type{IPv6}, x::NTuple{16,UInt8})
    bits = _cast(UInt128, x)
    return IPv6(bits)
end

# Enum support
# Name returned by `arrowname` for every `Enum` type.
const ENUM_SYMBOL = Symbol("JuliaLang.Enum")
# Module-level registry from an enum's bare type name to the type itself. It is
# filled by `registertype!` and `arrowmetadata` and consulted by `JuliaType`.
const ENUM_TYPES = Dict{String,DataType}()

"""
    registertype!(::Type{T}) where {T<:Enum}

Register the `Enum` type `T` for deserialization and return `T`.

This is only needed when reading Arrow data in a session that did not write it
(i.e. read-only scenarios). During writing, enum types are registered
automatically via `arrowmetadata`.

The registry is keyed by the bare type name, `string(nameof(T))`, so two enums
that share a name (for example, defined in different modules) occupy the same
slot and the most recent registration wins. A warning is logged whenever a
registration replaces a *different* type, because lookups by that name will
resolve to `T` from then on. Registering the same type again is silent.
"""
function registertype!(::Type{T}) where {T<:Enum}
    key = string(nameof(T))
    # Defaulting to `T` makes "not registered yet" look the same as "already
    # registered as `T`"; only a genuine replacement reaches the warning.
    previous = get(ENUM_TYPES, key, T)
    if previous !== T
        @warn "replacing an Enum type registered under the same name" name = key old =
            previous new = T
    end
    ENUM_TYPES[key] = T
    return T
end

# Every enum is stored as an `Int32`, whatever its declared base type.
function ArrowType(::Type{<:Enum})
    return Int32
end

# Convert an enum instance to its integer code as an `Int32`.
# `Int32(...)` throws `InexactError` for codes outside the `Int32` range.
function toarrow(x::Enum)
    code = Integer(x)
    return Int32(code)
end

# All enum types share the single arrow name `ENUM_SYMBOL`.
function arrowname(::Type{<:Enum})
    return ENUM_SYMBOL
end

# Return the metadata string for an enum column: the bare type name,
# `string(nameof(T))`. As a side effect `T` is registered, so that `JuliaType`
# can later resolve the name back to `T`.
function arrowmetadata(::Type{T}) where {T<:Enum}
    # Delegate to `registertype!` so the registration rule lives in one place
    # instead of being duplicated here.
    registertype!(T)
    return string(nameof(T))
end

# Resolve the enum name produced by `arrowmetadata` back to a registered type.
# Yields `nothing` when no enum is registered under `meta`.
function JuliaType(::Val{ENUM_SYMBOL}, S, meta::AbstractString)
    return get(ENUM_TYPES, meta, nothing)
end

# Rebuild an enum instance from its integer code via the enum's constructor.
function fromarrow(::Type{T}, x::Integer) where {T<:Enum}
    return T(x)
end

# Use the lowest-valued instance of the enum as its default value.
function default(::Type{T}) where {T<:Enum}
    return typemin(T)
end

function _cast(::Type{Y}, x)::Y where {Y}
y = Ref{Y}()
_unsafe_cast!(y, Ref(x), 1)
Expand Down
41 changes: 41 additions & 0 deletions src/ArrowTypes/test/tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,45 @@ end
@test isequal(x, [missing])
end
end

@testset "Enum" begin
    @enum Fruit apple=0 banana=1 cherry=2

    enumsym = ArrowTypes.ENUM_SYMBOL
    # Look up an enum name in the registry via `JuliaType`.
    lookup(name) = ArrowTypes.JuliaType(Val(enumsym), Int32, name)

    # Storage type, plus value conversion in both directions
    @test ArrowTypes.ArrowType(Fruit) == Int32
    for (member, code) in ((apple, Int32(0)), (cherry, Int32(2)))
        @test ArrowTypes.toarrow(member) === code
        @test ArrowTypes.fromarrow(Fruit, code) === member
    end

    # Arrow name shared by all enums
    @test ArrowTypes.arrowname(Fruit) === enumsym
    @test ArrowTypes.hasarrowname(Fruit)

    # arrowmetadata returns the bare type name and registers the type
    key = ArrowTypes.arrowmetadata(Fruit)
    @test key == "Fruit"
    @test ArrowTypes.ENUM_TYPES["Fruit"] === Fruit

    # Registry lookups: a hit for a registered name, `nothing` for an unknown one
    @test lookup("Fruit") === Fruit
    @test lookup("NoSuchEnum") === nothing

    # Default value
    @test ArrowTypes.default(Fruit) === apple

    # Manual registration through registertype!
    @enum Planet mercury=0 venus=1 earth=2
    # Drop any existing entry so the lookup below starts from "unregistered".
    delete!(ArrowTypes.ENUM_TYPES, "Planet")
    @test lookup("Planet") === nothing
    ArrowTypes.registertype!(Planet)
    @test lookup("Planet") === Planet

    # Union{Enum, Missing} passthrough
    maybe_fruit = Union{Fruit,Missing}
    @test ArrowTypes.arrowname(maybe_fruit) === enumsym
    @test ArrowTypes.arrowmetadata(maybe_fruit) == "Fruit"
end
end
35 changes: 35 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,41 @@ end
)
@test t.reject_reason[end] == "POST_ONLY"
end

@testset "Enum roundtrip" begin
    @enum Direction north=0 south=1 east=2 west=3
    @enum Priority low=0 medium=1 high=2

    # Serialize a table to an in-memory buffer and read it straight back.
    roundtrip(tbl) = Arrow.Table(Arrow.tobuffer(tbl))

    # Basic roundtrip
    dirs = [north, south, east, west, north]
    back = roundtrip((dir=dirs,))
    @test back.dir == dirs
    @test eltype(back.dir) == Direction

    # Union{Enum, Missing}
    dirs_m = Union{Direction,Missing}[north, missing, east, missing, west]
    back_m = roundtrip((dir=dirs_m,))
    @test isequal(back_m.dir, dirs_m)
    @test eltype(back_m.dir) == Union{Direction,Missing}

    # Multiple enum columns in one table
    cols = (dir=[north, south, east], pri=[low, high, medium])
    back_c = roundtrip(cols)
    @test back_c.dir == cols.dir
    @test eltype(back_c.dir) == Direction
    @test back_c.pri == cols.pri
    @test eltype(back_c.pri) == Priority

    # Longer column written with `Arrow.write(...; file=false)` to an `IOBuffer`.
    # NOTE(review): this case was labelled "Multiple record batches", but a single
    # unpartitioned table is written here - confirm whether a partitioned input
    # was intended.
    dirs_long = repeat([north, south, east, west], 100)
    io = IOBuffer()
    Arrow.write(io, (dir=dirs_long,); file=false)
    seekstart(io)
    back_l = Arrow.Table(io)
    @test back_l.dir == dirs_long
    @test eltype(back_l.dir) == Direction
end
end # @testset "misc"

@testset "DataAPI.metadata" begin
Expand Down