Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Distributions = "0.25"
JSON = "0.18, 0.19, 0.20, 0.21, 1"
LogDensityProblems = "2.1"
LogExpFunctions = "0.3"
Mooncake = "0.4"
Mooncake = "0.4, 0.5"
OrderedCollections = "1.8"
OrdinaryDiffEq = "6.90"
PosteriorDB = "0.5, 0.6"
Expand Down
25 changes: 25 additions & 0 deletions src/slic_stan/builtin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ end
rep_vector
rep_matrix
linspaced_array
linspaced_int_array
robust_linspaced_int_array
linspaced_vector
to_array_1d
to_array_2d
Expand All @@ -100,6 +102,8 @@ end
append_col
diag_matrix
mdivide_left_tri_low
one_hot_vector
sd
cumulative_sum
log_sum_exp
lgamma
Expand All @@ -123,6 +127,8 @@ autokwargs(::CanonicalExpr{typeof(von_mises)}) = (;lower=0, upper=2pi)
autokwargs(x::CanonicalExpr{typeof(uniform)}) = (;lower=x.args[1], upper=x.args[2])
autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square,scaled_inv_chi_square,exponential,gamma,inv_gamma,weibull,frechet,rayleigh,loglogistic))...}}) = (;lower=0.)

import Statistics

@deffun begin
reduce_sum(args...)::real
reduce_sum_static(args...)::real
Expand All @@ -137,6 +143,8 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square,
positive_infinity()::real
negative_infinity()::real
reject(x)::anything
Base.log1p(x::real)::real
Base.inv(::vector[n])::vector[n]
Base.print(x)::anything
Base.size(x)::int
Base.range(start::int, stop::int)::vector[stop]
Expand All @@ -145,6 +153,7 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square,
Base.sum(x::int[m,n])::int
Base.sum(x::int[m,n,o])::int
Base.:\(A::matrix[m, m], b::vector[m])::vector[m]
Statistics.mean(x)::real
dims(x::anything[_])::int[1]
dims(x::anything[_, _])::int[2]
dims(x::anything[_, _, _])::int[3]
Expand All @@ -154,11 +163,24 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square,
cumulative_sum(x::real[m])::real[m]
cumulative_sum(x::vector[m])::vector[m]
diag_matrix(x::anything[n])::matrix[n,n]
sd(x)::real
one_hot_vector(n, k)::vector[n]
mdivide_left_tri_low(::matrix[m,m], ::vector[m])::vector[m]
mdivide_left_tri_low(::matrix[m,m], ::matrix[m,n])::matrix[m,n]
linspaced_array(n, x, y)::real[n]
linspaced_int_array(n, args...)::int[n]
robust_linspaced_int_array(n, args...)::int[n] = if n == 0
rv::int[n]
rv
else
linspaced_int_array(n, args...)
end
linspaced_vector(n, x, y)::vector[n]
to_matrix(v, m, n)::matrix[m,n]
rep_array(x::int, n)::int[n]
rep_array(x::int, m, n)::int[m, n]
rep_array(x::real, n)::real[n]
rep_array(x::real, m, n)::real[m, n]
rep_vector(v, n)::vector[n]
rep_matrix(v::vector[m], n)::matrix[m, n]
rep_matrix(x::real, m, n)::matrix[m,n]
Expand Down Expand Up @@ -268,6 +290,8 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square,
normal_lcdf(args...)
normal_lccdf(args...)

append_row(x, y, z, args...) = append_row(append_row(x, y), z, args...)
append_col(x, y, z, args...) = append_col(append_col(x, y), z, args...)

Base.invperm(x::int[n])::int[n] = begin
rv = rep_array(0, n)
Expand Down Expand Up @@ -430,6 +454,7 @@ end
(real[n],)=>int
end
typeof(log_sum_exp) => begin
(real, real) => real
(real[n], ) => real
(matrix[m,n], ) => real
(row_vector[n], ) => real
Expand Down
7 changes: 6 additions & 1 deletion src/slic_stan/slic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ meta(x::StanModel) = x.meta
vars(x::StanModel) = x.vars
blocks(x::StanModel) = x.blocks
remake(x::StanModel; kwargs...) = StanModel((;x.meta..., kwargs...), x.vars, x.blocks)
var(x::StanModel, name) = error()#vars(x)[name]
# var(x::StanModel, name) = error()#vars(x)[name]
block(x::StanModel, name) = blocks(x)[name]
Base.getindex(x::StanModel, name) = getindex(vars(x), name)
Base.setindex!(x::StanModel, value, name) = setindex!(vars(x), value, name)
Expand Down Expand Up @@ -224,6 +224,11 @@ stan_type(expr, value::AbstractVector{<:Integer}; kwargs...) = StanType(
stan_expr.((Symbol(expr, "_n"), ), size(value));
value, kwargs..., qual=:data
)
stan_type(expr, value::AbstractMatrix{<:Integer}; kwargs...) = StanType(
types.int,
stan_expr.((Symbol(expr, "_m"), Symbol(expr, "_n"), ), size(value));
value, kwargs..., qual=:data
)
stan_type(expr, value::Function; kwargs...) = StanType(types.func{typeof(value)}; value, qual=:data, kwargs...)
stan_call(f, args...) = stan_expr(CanonicalExpr(f, map(stan_expr, args)...))
stan_expr(x::StanExpr; kwargs...) = weak_remake(x; kwargs...)
Expand Down
Loading