diff --git a/Project.toml b/Project.toml index 2a71587..4470748 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ Distributions = "0.25" JSON = "0.18, 0.19, 0.20, 0.21, 1" LogDensityProblems = "2.1" LogExpFunctions = "0.3" -Mooncake = "0.4" +Mooncake = "0.4, 0.5" OrderedCollections = "1.8" OrdinaryDiffEq = "6.90" PosteriorDB = "0.5, 0.6" diff --git a/src/slic_stan/builtin.jl b/src/slic_stan/builtin.jl index 0eedd90..7c9b3fb 100644 --- a/src/slic_stan/builtin.jl +++ b/src/slic_stan/builtin.jl @@ -78,6 +78,8 @@ end rep_vector rep_matrix linspaced_array + linspaced_int_array + robust_linspaced_int_array linspaced_vector to_array_1d to_array_2d @@ -100,6 +102,8 @@ end append_col diag_matrix mdivide_left_tri_low + one_hot_vector + sd cumulative_sum log_sum_exp lgamma @@ -123,6 +127,8 @@ autokwargs(::CanonicalExpr{typeof(von_mises)}) = (;lower=0, upper=2pi) autokwargs(x::CanonicalExpr{typeof(uniform)}) = (;lower=x.args[1], upper=x.args[2]) autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square,scaled_inv_chi_square,exponential,gamma,inv_gamma,weibull,frechet,rayleigh,loglogistic))...}}) = (;lower=0.) +import Statistics + @deffun begin reduce_sum(args...)::real reduce_sum_static(args...)::real @@ -137,6 +143,8 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square, positive_infinity()::real negative_infinity()::real reject(x)::anything + Base.log1p(x::real)::real + Base.inv(::vector[n])::vector[n] Base.print(x)::anything Base.size(x)::int Base.range(start::int, stop::int)::vector[stop] @@ -145,6 +153,7 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square, Base.sum(x::int[m,n])::int Base.sum(x::int[m,n,o])::int Base.:\(A::matrix[m, m], b::vector[m])::vector[m] + Statistics.mean(x)::real dims(x::anything[_])::int[1] dims(x::anything[_, _])::int[2] dims(x::anything[_, _, _])::int[3] @@ -154,11 +163,24 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square, cumulative_sum(x::real[m])::real[m] cumulative_sum(x::vector[m])::vector[m] diag_matrix(x::anything[n])::matrix[n,n] + sd(x)::real + one_hot_vector(n, k)::vector[n] + mdivide_left_tri_low(::matrix[m,m], ::vector[m])::vector[m] + mdivide_left_tri_low(::matrix[m,m], ::matrix[m,n])::matrix[m,n] linspaced_array(n, x, y)::real[n] + linspaced_int_array(n, args...)::int[n] + robust_linspaced_int_array(n, args...)::int[n] = if n == 0 + rv::int[n] + rv + else + linspaced_int_array(n, args...) + end linspaced_vector(n, x, y)::vector[n] to_matrix(v, m, n)::matrix[m,n] rep_array(x::int, n)::int[n] + rep_array(x::int, m, n)::int[m, n] rep_array(x::real, n)::real[n] + rep_array(x::real, m, n)::real[m, n] rep_vector(v, n)::vector[n] rep_matrix(v::vector[m], n)::matrix[m, n] rep_matrix(x::real, m, n)::matrix[m,n] @@ -268,6 +290,8 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square, normal_lcdf(args...) normal_lccdf(args...) + append_row(x, y, z, args...) = append_row(append_row(x, y), z, args...) + append_col(x, y, z, args...) = append_col(append_col(x, y), z, args...) Base.invperm(x::int[n])::int[n] = begin rv = rep_array(0, n) @@ -430,6 +454,7 @@ end (real[n],)=>int end typeof(log_sum_exp) => begin + (real, real) => real (real[n], ) => real (matrix[m,n], ) => real (row_vector[n], ) => real diff --git a/src/slic_stan/slic.jl b/src/slic_stan/slic.jl index 6a9b4e3..076eae0 100644 --- a/src/slic_stan/slic.jl +++ b/src/slic_stan/slic.jl @@ -91,7 +91,7 @@ meta(x::StanModel) = x.meta vars(x::StanModel) = x.vars blocks(x::StanModel) = x.blocks remake(x::StanModel; kwargs...) = StanModel((;x.meta..., kwargs...), x.vars, x.blocks) -var(x::StanModel, name) = error()#vars(x)[name] +# var(x::StanModel, name) = error()#vars(x)[name] block(x::StanModel, name) = blocks(x)[name] Base.getindex(x::StanModel, name) = getindex(vars(x), name) Base.setindex!(x::StanModel, value, name) = setindex!(vars(x), value, name) @@ -224,6 +224,11 @@ stan_type(expr, value::AbstractVector{<:Integer}; kwargs...) = StanType( stan_expr.((Symbol(expr, "_n"), ), size(value)); value, kwargs..., qual=:data ) +stan_type(expr, value::AbstractMatrix{<:Integer}; kwargs...) = StanType( + types.int, + stan_expr.((Symbol(expr, "_m"), Symbol(expr, "_n"), ), size(value)); + value, kwargs..., qual=:data +) stan_type(expr, value::Function; kwargs...) = StanType(types.func{typeof(value)}; value, qual=:data, kwargs...) stan_call(f, args...) = stan_expr(CanonicalExpr(f, map(stan_expr, args)...)) stan_expr(x::StanExpr; kwargs...) = weak_remake(x; kwargs...)