diff --git a/src/MarkovModels.jl b/src/MarkovModels.jl index b4bbc7d..b12cc85 100644 --- a/src/MarkovModels.jl +++ b/src/MarkovModels.jl @@ -38,9 +38,15 @@ export # Inference expand, + compile, αrecursion, βrecursion, - pdfposteriors + pdfposteriors, + CompiledFSM, + + # LBP + FactorialFSM, + CompiledFactorialFSM include("utils.jl") include("fsm.jl") @@ -49,6 +55,7 @@ include("algorithms.jl") include("lmfsm.jl") include("linalg.jl") include("inference.jl") +include("lbp_inference.jl") #export maxstateposteriors diff --git a/src/inference.jl b/src/inference.jl index 8e497cc..9369688 100644 --- a/src/inference.jl +++ b/src/inference.jl @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT """ - expand(V::AbstractMatrix{K}, seqlength = size(lhs, 2)) where K + expand(V::AbstractMatrix{K}, seqlength = size(V, 2)) where K Expand the ``D x N`` matrix of likelihoods `V` to a ``D+1 x N+1`` matrix `V̂`. This function is to prepare the matrix of likelihood to @@ -58,17 +58,27 @@ function βrecursion(T̂::AbstractMatrix{K}, lhs::AbstractMatrix{K}) where K B end +""" + pdfposteriors(fsm::FSM{K}, V̂s, Ĉs) where K + +Compute pdf posteriors. + +args: + fsm - batched FSM + V̂s - vector of pdf likelihoods of shape (n_pdfs x N) + Ĉs - vector of state-pdf mappings +""" function pdfposteriors(fsm::FSM{K}, V̂s, Ĉs) where K - V̂ = vcat(V̂s...) + V̂ = vcat(V̂s...) # B*n_pdfs x N V̂k = copyto!(similar(V̂, K), V̂) - Ĉ = blockdiag(Ĉs...) - ĈV̂ = (Ĉ * V̂k) + Ĉ = blockdiag(Ĉs...) # B*n_states x B*n_pdfs + ĈV̂ = (Ĉ * V̂k) # B*n_states x N state_A = αrecursion(fsm.α̂, fsm.T̂', ĈV̂) state_B = βrecursion(fsm.T̂, ĈV̂) - state_AB = broadcast!(*, state_A, state_A, state_B) - AB = Ĉ' * state_AB - Ẑ = permutedims(reshape(AB, :, length(V̂s), size(V̂, 2)), (2, 1, 3)) - sums = sum(Ẑ, dims = 2) + state_AB = broadcast!(*, state_A, state_A, state_B) # B*n_states x N + AB = Ĉ' * state_AB # B*n_pdfs x N + Ẑ = permutedims(reshape(AB, :, length(V̂s), size(V̂, 2)), (2, 1, 3)) # B x n_pdfs x N + sums = sum(Ẑ, dims = 2) # B x 1 x N Ẑ = broadcast!(/, Ẑ, Ẑ, sums) ttl = dropdims(minimum(sums, dims = (2, 3)), dims = (2, 3)) (exp ∘ val).(Ẑ[:, 1:end-1, 1:end-1]), val.(ttl) diff --git a/src/lbp_inference.jl b/src/lbp_inference.jl new file mode 100644 index 0000000..7b492ff --- /dev/null +++ b/src/lbp_inference.jl @@ -0,0 +1,230 @@ +struct CompiledFSM{K<:Semiring} + α̂ + T̂ + T̂ᵀ + Ĉ + Ĉᵀ +end +compile(fsm::FSM, Ĉ::AbstractMatrix) = CompiledFSM{eltype(fsm.α̂)}(fsm.α̂, fsm.T̂, copy(fsm.T̂'), Ĉ, copy(Ĉ')) +nstates(fsm::CompiledFSM) = length(fsm.α̂) - 1 + +function Adapt.adapt_structure(::Type{<:CuArray}, cfsm::CompiledFSM{K}) where K + T̂ = CuSparseMatrixCSC(cfsm.T̂) + T̂ᵀ = CuSparseMatrixCSC(cfsm.T̂ᵀ) + Ĉ = CuSparseMatrixCSC(cfsm.Ĉ) + Ĉᵀ = CuSparseMatrixCSC(cfsm.Ĉᵀ) + CompiledFSM{K}( + CuSparseVector(cfsm.α̂), + CuSparseMatrixCSR(T̂ᵀ.colPtr, T̂ᵀ.rowVal, T̂ᵀ.nzVal, T̂.dims), + CuSparseMatrixCSR(T̂.colPtr, T̂.rowVal, T̂.nzVal, T̂ᵀ.dims), + CuSparseMatrixCSR(Ĉᵀ.colPtr, Ĉᵀ.rowVal, Ĉᵀ.nzVal, Ĉ.dims), + CuSparseMatrixCSR(Ĉ.colPtr, Ĉ.rowVal, Ĉ.nzVal, Ĉᵀ.dims), + ) +end + +struct FactorialFSM{K<:Semiring} + fsms::Vector{FSM{K}} +end + +struct CompiledFactorialFSM{K<:Semiring, J} + fsms::Vector{CompiledFSM{K}} + Ĉ::AbstractSparseMatrix{K} + Ĉᵀ::AbstractSparseMatrix{K} +end + +function compile(ffsm::FactorialFSM{K}, Ĉs::AbstractMatrix{K}...) where K + J = length(ffsm.fsms) + @assert length(Ĉs) == J "Number of state maps `Ĉs` has to be same as number of FSMs in FactorialFSM" + Ĉ = foldl(joint_smap, sparse.(Ĉs)) + CompiledFactorialFSM{K, J}( + map(zip(ffsm.fsms, Ĉs)) do (fsm, smap) compile(fsm, smap) end, + #[compile(fsm, smap) for (fsm, smap) in zip(ffsm.fsms, Ĉs)], + Ĉ, + copy(Ĉ') + ) +end + +Base.getindex(cffsm::CompiledFactorialFSM{K}, key::Integer) where K = cffsm.fsms[key] +nfsms(cffsm::CompiledFactorialFSM{K, J}) where {K, J} = J + +function Adapt.adapt_structure(T::Type{<:CuArray}, cffsm::CompiledFactorialFSM{K}) where K + fsms = adapt_structure.(T, cffsm.fsms) + Ĉ = CuSparseMatrixCSC(cffsm.Ĉ) + Ĉᵀ = CuSparseMatrixCSC(cffsm.Ĉᵀ) + CompiledFactorialFSM( + fsms, + CuSparseMatrixCSR(Ĉᵀ.colPtr, Ĉᵀ.rowVal, Ĉᵀ.nzVal, Ĉ.dims), + CuSparseMatrixCSR(Ĉ.colPtr, Ĉ.rowVal, Ĉ.nzVal, Ĉᵀ.dims), + ) +end + +function batch(cffsm1::CompiledFactorialFSM{K, J}, cffsms::CompiledFactorialFSM{K, J}...) where {K, J} + fsms = Vector{CompiledFSM{K}}[] + for j in 1:J + push!(fsms, batch(cffsm1.fsms[j], map(cffsm -> cffsm.fsms[j], cffsms)...)) + end + CompiledFactorialFSM{K, J}( + fsms, + blockdiag(cffsm1.Ĉ, map(cffsm -> cffsm.Ĉ, cffsms)...), + blockdiag(cffsm1.Ĉᵀ, map(cffsm -> cffsm.Ĉᵀ, cffsms)...) + ) +end + +function joint_smap(smap1::AbstractSparseMatrix{K}, smap2::AbstractSparseMatrix{K}) where K + S1, P1 = size(smap1) + S2, P2 = size(smap2) + @assert P1 == P2 + + I, J, V = Int[], Int[], K[] + for (i1, j1, v2) in zip(findnz(smap1)...) + for (i2, j2, v2) in zip(findnz(smap2)...) + push!(I, (i2-1) * S1 + i1) + push!(J, (j2-1) * P1 + j1) + push!(V, one(K)) + end + end + sparse(I, J, V, S1*S2, P1*P2) +end + +function init_messages(ffsm::CompiledFactorialFSM{K, J}, N::Integer) where {K,J} + S = [nstates(fsm) + 1 for fsm in ffsm.fsms] # + virtual state + + m1 = [Array{K}(undef, s, N) for s in S] + m2 = [Array{K}(undef, s, N) for s in S] + m3 = [Array{K}(undef, s, N) for s in S] + + @views for j in 1:length(S) + fsm = ffsm[j] + S = nstates(fsm) + m1j, m2j, m3j = m1[j], m2[j], m3[j] + fill!(m2j[:, 2:end], one(K) / K(S)) + m2j[:, 1] = fsm.α̂ + fill!(m3j, one(K) / K(S)) + end + + (m1 = m1, m2 = m2, m3 = m3) +end + +function compare_msgs(old_mgs, new_msg; eps=1.2) + n_spkrs = length(old_mgs[:m1]) + changes = [] + for n_msg in 1:3 + om, nm = old_mgs[Symbol("m", n_msg)], new_msg[Symbol("m", n_msg)] + + for j in 1:n_spkrs + diff = isapprox.(val.(om[j]), val.(nm[j]); atol=eps) + diff_N = findall(vec(.!all(diff, dims=1))) + n_diff_N = length(diff_N) + append!(changes, zip(repeat([j], n_diff_N), repeat([n_msg], n_diff_N), diff_N)) + end + end + changes +end + +""" + pdfposteriors(ffsm::FactorialFSM{K}, llhs::AbstractMatrix{K}; eps=1e-4, max_iter=10) where {K<:Semiring} + +Compute pdf posteriors given loglikehood `llhs` and FactorialFSM `ffsm`. +Likehoods `llhs` are represented as 2D matrix of shape +`(S1*S2, N)`, where `S1` is number of states in 1st FSM and `S2` is number +of states in 2nd FSM in `ffsm`. + +args: + ffsm - FactorialFSM od 2 FSMs with states `S1` and `S2` + V̂s - B-element vector of expanded loglikehoods with size `(S1*S2, N)` (see `expand`) +""" +function pdfposteriors(ffsm::CompiledFactorialFSM{K, J}, V̂s::Vector{<:AbstractMatrix{K}}; eps=1.2, min_iter=3, max_iter=10) where {K<:Semiring, J} + B = length(V̂s) + llhs = vcat(V̂s...) + N = size(llhs, 2) + n_spkrs = nfsms(ffsm) + states_per_fsm = [nstates(ffsm[i]) + 1 for i in 1:n_spkrs] # + virtual state + total_states = reduce(*, states_per_fsm) + # we assume that llhs is already expanded (see `expand`) + state_llhs = ffsm.Ĉ * llhs + @assert size(state_llhs, 1) == total_states "size(state_llhs): $(size(state_llhs, 1)) != $total_states" + state_llhs = reshape(state_llhs, vcat(states_per_fsm, [N])...) # S1 x S2 x ... x N + + messages = init_messages(ffsm, N) + iter = 0 + changes = nothing + + while true + iter += 1 + new_messages = lbp_step!(deepcopy(messages), ffsm, state_llhs) + # check the difference between messages + changes = compare_msgs(messages, new_messages; eps=eps) + messages = new_messages + # if all messages are same then break + if (isempty(changes) && iter >= min_iter) || iter >= max_iter + break + end + end + total_num_msgs = [size.(m, 2) for m in messages] |> sum |> sum + println("Finished in iter: $iter (max: $max_iter)") + println("Number of changed msgs before max_iter was reached: $(length(changes)) ($total_num_msgs)") + + m1, m2, m3 = messages + result = [] + ttl = zeros(K, B) + for (j, (m1j, m2j, m3j)) in enumerate(zip(m1, m2, m3)) + state_marginals = broadcast!(*, m1j, m1j, m2j, m3j) + pdf_marginals = ffsm[j].Ĉᵀ * state_marginals # B*P x N + pdf_marginals = permutedims(reshape(pdf_marginals, :, B, N), (2, 1, 3)) # B x P x N + sums = sum(pdf_marginals, dims=2) # B x 1 x N + broadcast!(/, pdf_marginals, pdf_marginals, sums) + broadcast!(+, ttl, ttl, dropdims(minimum(sums, dims=(2,3)), dims=(2,3))) + push!(result, (exp ∘ val).(pdf_marginals[:, 1:end-1, 1:end-1])) + end + + # TODO: return pdf marginals of the same size as llhs + result, val.(ttl) +end + +function lbp_step!(messages, ffsm::CompiledFactorialFSM{K, J}, llhs::AbstractArray{K, D}) where {K, J, D} + n_spkrs = ndims(llhs) - 1 + @assert n_spkrs == 2 "Currently we do not support more than 2 speakers!" + + m1, m2, m3 = messages + N = size(llhs, ndims(llhs)) + + for j in 1:n_spkrs + # this spkr's messages + m1j, m2j, m3j = m1[j], m2[j], m3[j] + fsm = ffsm[j] + T̂, T̂ᵀ = fsm.T̂, fsm.T̂ᵀ + + # other spkr's messages + k = n_spkrs - j + 1 + m2k, m3k = m2[k], m3[k] + buffer_k = similar(m2k[:, 1]) + llhs_perm = permutedims(llhs, [j, k, 3]) + + @views for n in 1:N + broadcast!(*, buffer_k, m2k[:, n], m3k[:, n]) + broadcast!(/, buffer_k, buffer_k, sum(buffer_k)) # Do NOT normalize + # (w/o norm it will converge faster) + mul!(m1j[:, n], llhs_perm[:, :, n], buffer_k) + #broadcast!(*, buffer, llhs_perm[:, :, n], buffer_k) + #sum!(m1j[:, n], buffer') + end + + m2j[:, 1] = fsm.α̂ + buffer = similar(m1j[:, 1]) # NOT OPTIMAL, m2j should be used instead + # but we have issue with Julia 1.7 -> 1.8 fixed it, but introduce another bugs + # check https://github.com/JuliaSparse/SparseArrays.jl/issues/251 + + @views for n in 2:N + broadcast!(*, buffer, m1j[:, n - 1], m2j[:, n - 1]) + mul!(m2j[:, n], T̂ᵀ, buffer) + end + + @views fill!(m3j[:, N], one(K)) + @views for n in N-1:-1:1 + broadcast!(*, buffer, m1j[:, n + 1], m3j[:, n + 1]) + mul!(m3j[:, n], T̂, buffer) + end + end + + return (m1=m1, m2=m2, m3=m3) +end diff --git a/test/runtests.jl b/test/runtests.jl index 935c5ff..1f5374b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,8 @@ using Semirings using SparseArrays using Test +import MarkovModels: lbp_step!, joint_smap + @testset verbose=true "FSMs" begin include("test_fsms.jl") end @@ -22,6 +24,10 @@ else @warn "CUDA is not functional skipping tests." end +@testset verbose=true "LBP Inference" begin + include("test_lbp_inference.jl") +end + #@testset verbose=true "algorithms" begin # include("test_algorithms.jl") #end diff --git a/test/test_lbp_inference.jl b/test/test_lbp_inference.jl new file mode 100644 index 0000000..7815a8b --- /dev/null +++ b/test/test_lbp_inference.jl @@ -0,0 +1,147 @@ +const N = 100 # number of frames +const S1 = 30 # number of states for spkr1 +const S2 = 33 # number of states for spkr2 +const T = Float32 +const SF = LogSemiring{T} + +#normalize2(x) = x ./ sum(x) +normalize2(x) = x + +function naive_lbp_step(messages, ffsm::CompiledFactorialFSM{T}, llhs::AbstractArray{T, 3}) where T + N = size(llhs, 3) + n_spkrs = 2 + m1, m2, m3 = deepcopy(messages) + + for j in 1:n_spkrs + + @views for n in 1:N + buffer = llhs[:, :, n] + + for k in 1:n_spkrs + k == j && continue + if k == 2 + buffer = permutedims(buffer, [2, 1]) + buffer = buffer .* normalize2(m2[k][:, n] .* m3[k][:, n]) + buffer = permutedims(buffer, [2, 1]) + elseif k == 1 + buffer = buffer .* normalize2(m2[k][:, n] .* m3[k][:, n]) + else + throw(ErrorException("Not available for more then 2 spkrs")) + end + end + m1[j][:, n] = sum(buffer, dims=[k for k in 1:n_spkrs if k != j]) + end + + fsm = ffsm[j] + m2[j][:, 1] = fsm.α̂ + @views for n in 2:N + m2[j][:, n] = ((m2[j][:, n - 1] .* m1[j][:, n - 1])' * fsm.T̂)' + end + + @views fill!(m3[j][:, N], one(T)) + @views for n in N-1:-1:1 + m3[j][:, n] = fsm.T̂ * (m3[j][:, n + 1] .* m1[j][:, n + 1]) + end + end + + return (m1=m1, m2=m2, m3=m3) +end + +make_lin_ffsm(SF, T, num_states_per_fsm...) = begin + fsms = FSM{SF}[] + smaps = AbstractSparseMatrix{SF}[] + P = maximum(num_states_per_fsm) + for S in num_states_per_fsm + α = sparse(vcat(one(T), zeros(T, S-2))) + T̂ = sparse(Bidiagonal([T(0.75) for _ in 1:S-1], [T(0.25) for _ in 1:S-2], :U)) + ω = sparse(vcat(zeros(T, S-2), [T(0.25)])) + labels = collect(1:S-1) + push!( + fsms, + FSM( + convert.(SF, log.(α)), + convert.(SF, log.(T̂)), + convert.(SF, log.(ω)), + labels + ) |> renorm + ) + push!( + smaps, + diagm(abs(P-S) => ones(SF, S))[1:S, :] |> sparse + ) + end + compile(FactorialFSM(fsms), smaps...) +end + +make_ffsm(SF, T, num_states_per_fsm...) = begin + fsms = FSM{SF}[] + smaps = AbstractSparseMatrix{SF}[] + P = maximum(num_states_per_fsm) + for S in num_states_per_fsm + α = sprand(T, S-1, 0.25) + T̂ = sprand(T, S-1, S-1, 0.95) + ω = sprand(T, S-1, 0.75) + labels = collect(1:S-1) + push!( + fsms, + FSM( + convert.(SF, log.(α)), + convert.(SF, log.(T̂)), + convert.(SF, log.(ω)), + labels + ) |> renorm + ) + push!( + smaps, + diagm(abs(P-S) => ones(SF, S))[1:S, :] |> sparse + ) + end + compile(FactorialFSM(fsms), smaps...) +end + +false_print(msg) = begin + println(msg) + println("") + false +end + +@testset "Linear Factorial FSM" begin + # Linear FSM + lffsm = make_lin_ffsm(SF, T, S1, S2) + llhs = convert.(SF, log.(rand(T, S1, S2, N))) + + m1 = [Array{SF}(undef, S, N) for S in [S1, S2]] + m2 = [ones(SF, S, N-1) / SF(S) for S in [S1, S2]] + m2 = [hcat(fsm.α̂, m) for (fsm, m) in zip(lffsm.fsms, m2)] + m3 = [ones(SF, S, N) / SF(S) for S in [S1, S2]] + + ref_m1, ref_m2, ref_m3 = naive_lbp_step([m1,m2,m3], lffsm, llhs) + hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), lffsm, llhs) + + for j in 1:2 + @test all(isapprox.(val.(ref_m1[j]), val.(hyp_m1[j]), nans=true)) + @test all(isapprox.(val.(ref_m2[j]), val.(hyp_m2[j]), nans=true)) + @test all(isapprox.(val.(ref_m3[j]), val.(hyp_m3[j]), nans=true)) + #@test all(ref_m2[j] .≈ hyp_m2[j]) || false_print((println.(ref_m2[j] .≈ hyp_m2[j]), println.(ref_m2[j]), println(""), println.( hyp_m2[j]))) + end +end + +@testset "random FactorialFSM" begin + ffsm = make_ffsm(SF, T, S1, S2) + llhs = convert.(SF, log.(rand(T, S1, S2, N))) + + m1 = [Array{SF}(undef, S, N) for S in [S1, S2]] + m2 = [ones(SF, S, N-1) for S in [S1, S2]] + m2 = [hcat(fsm.α̂, m) for (fsm, m) in zip(ffsm.fsms, m2)] + m3 = [ones(SF, S, N) for S in [S1, S2]] + + ref_m1, ref_m2, ref_m3 = naive_lbp_step([m1,m2,m3], ffsm, llhs) + hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), ffsm, llhs) + + for j in 1:2 + @test all(isapprox.(val.(ref_m1[j]), val.(hyp_m1[j]), nans=false)) + @test all(isapprox.(val.(ref_m2[j]), val.(hyp_m2[j]), nans=false)) + @test all(isapprox.(val.(ref_m3[j]), val.(hyp_m3[j]), nans=false)) + #@test all(ref_m2[j] .≈ hyp_m2[j]) || false_print((println.(ref_m2[j] .≈ hyp_m2[j]), println.(ref_m2[j]), println(""), println.( hyp_m2[j]))) + end +end