From 67bdfc38978a60b41ba4e3ab34f3d9c212c66a7e Mon Sep 17 00:00:00 2001 From: Kocour Martin Date: Thu, 28 Jul 2022 15:52:38 +0200 Subject: [PATCH 1/8] Minor fix --- src/inference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inference.jl b/src/inference.jl index 8e497cc..d19924e 100644 --- a/src/inference.jl +++ b/src/inference.jl @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT """ - expand(V::AbstractMatrix{K}, seqlength = size(lhs, 2)) where K + expand(V::AbstractMatrix{K}, seqlength = size(V, 2)) where K Expand the ``D x N`` matrix of likelihoods `V` to a ``D+1 x N+1`` matrix `V̂`. This function is to prepare the matrix of likelihood to From 48f9563de31578919e4cbcf8698184c47cdf3916 Mon Sep 17 00:00:00 2001 From: Kocour Martin Date: Mon, 5 Sep 2022 09:04:03 +0200 Subject: [PATCH 2/8] Update LBP messages --- src/lbp_inference.jl | 46 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/lbp_inference.jl diff --git a/src/lbp_inference.jl b/src/lbp_inference.jl new file mode 100644 index 0000000..8739224 --- /dev/null +++ b/src/lbp_inference.jl @@ -0,0 +1,46 @@ +struct FactorialFSM{K<:Semiring} + fsms::Vector{FSM{K}} + smaps::Vector{AbstractMatrix{K}} +end + +function FactorialFSM( + fsm1::FSM{K}, smap1::AbstractMatrix{K}, + fsm2::FSM{K}, smap2::AbstractMatrix{K} +) where K + FactorialFSM([fsm1, fsm2], [smap1, smap2]) +end + +function init_messages(ffsm::FactorialFSM{K}, N::Integer) where K + dims = [nstates(fsm) for fsm in ffsm.fsms] + nspkrs = length(dims) + + m1 = [Array{K}(undef, dims[j], N) for j in nspkrs] + m2 = [Array{K}(undef, dims[j], N) for j in nspkrs] + m3 = [Array{K}(undef, dims[j], N) for j in nspkrs] + + @views for j in 1:nspkrs + S = dims(j) # n_states in fsm for spkr j + m1j, m2j, m3j = m1[j], m2[j], m3[j] + fill!(m2[:, 2:end], one(K) / S) + m2[1, :] = + end + + (m1 = m1, m2 = m2, m3 = m3) +end + +function lbp_posteriors(ffsm::FactorialFSM{K}, llhs::AbstractArray{K, J}; 
eps=1e-4) where {K, J} + messages = init_messages(ffsm, llhs) + + while true + new_messages = lbp_step(messages) + # check the difference between messages + diffs = [.≈(new_m, m; atol=eps) for (new_m, m) in zip(new_messages, messages)] + # if all messages are same then break + all(all.(diffs)) && break + old_messages = new_messages + end +end + +function lbp_step(messages) + +end From 6e164215cc7edaacfe8ec7133e23ebc430cde6d5 Mon Sep 17 00:00:00 2001 From: Kocour Martin Date: Mon, 5 Sep 2022 09:13:03 +0200 Subject: [PATCH 3/8] Init messages --- src/lbp_inference.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lbp_inference.jl b/src/lbp_inference.jl index 8739224..d09f81c 100644 --- a/src/lbp_inference.jl +++ b/src/lbp_inference.jl @@ -19,10 +19,13 @@ function init_messages(ffsm::FactorialFSM{K}, N::Integer) where K m3 = [Array{K}(undef, dims[j], N) for j in nspkrs] @views for j in 1:nspkrs - S = dims(j) # n_states in fsm for spkr j + fsm = ffsm.fsms[j] + S = nstates(fsm) m1j, m2j, m3j = m1[j], m2[j], m3[j] - fill!(m2[:, 2:end], one(K) / S) - m2[1, :] = + fill!(m2[:, 2:end], one(K)) + fill!(m2[:, 1], fsm.α) + fill!(m3[:, 1:end-1], one(K)) + fill!(m3[:, end], one(K)) end (m1 = m1, m2 = m2, m3 = m3) From 54d8980c702f761dc20fb79478a0b273f2bd0075 Mon Sep 17 00:00:00 2001 From: Martin Kocour Date: Wed, 7 Sep 2022 11:53:48 +0200 Subject: [PATCH 4/8] LBP step + tests --- src/MarkovModels.jl | 7 +- src/lbp_inference.jl | 57 +++++++++++++--- test/runtests.jl | 6 ++ test/test_lbp_inference.jl | 136 +++++++++++++++++++++++++++++++++++++ 4 files changed, 197 insertions(+), 9 deletions(-) create mode 100644 test/test_lbp_inference.jl diff --git a/src/MarkovModels.jl b/src/MarkovModels.jl index b4bbc7d..b1635f6 100644 --- a/src/MarkovModels.jl +++ b/src/MarkovModels.jl @@ -40,7 +40,11 @@ export expand, αrecursion, βrecursion, - pdfposteriors + pdfposteriors, + + # LBP + FactorialFSM, + lbp_posteriors include("utils.jl") include("fsm.jl") @@ -49,6 
+53,7 @@ include("algorithms.jl") include("lmfsm.jl") include("linalg.jl") include("inference.jl") +include("lbp_inference.jl") #export maxstateposteriors diff --git a/src/lbp_inference.jl b/src/lbp_inference.jl index d09f81c..e8ece5a 100644 --- a/src/lbp_inference.jl +++ b/src/lbp_inference.jl @@ -22,28 +22,69 @@ function init_messages(ffsm::FactorialFSM{K}, N::Integer) where K fsm = ffsm.fsms[j] S = nstates(fsm) m1j, m2j, m3j = m1[j], m2[j], m3[j] - fill!(m2[:, 2:end], one(K)) - fill!(m2[:, 1], fsm.α) - fill!(m3[:, 1:end-1], one(K)) - fill!(m3[:, end], one(K)) + fill!(m2j[:, 2:end], one(K)) + fill!(m2j[:, 1], fsm.α) + fill!(m3j, one(K)) end (m1 = m1, m2 = m2, m3 = m3) end function lbp_posteriors(ffsm::FactorialFSM{K}, llhs::AbstractArray{K, J}; eps=1e-4) where {K, J} - messages = init_messages(ffsm, llhs) + N = size(llhs, ndims(llhs)) + messages = init_messages(ffsm, N) while true - new_messages = lbp_step(messages) + new_messages = lbp_step!(deepcopy(messages), ffsm, llhs) # check the difference between messages diffs = [.≈(new_m, m; atol=eps) for (new_m, m) in zip(new_messages, messages)] # if all messages are same then break all(all.(diffs)) && break - old_messages = new_messages + messages = new_messages end + + m1, m2, m3 = messages end -function lbp_step(messages) +function lbp_step!(messages, ffsm::FactorialFSM{K}, llhs::AbstractArray{K, 3}) where K + n_spkrs = ndims(llhs) - 1 + @assert n_spkrs == 2 "Currently we do not support more than 2 speakers!" 
+ + m1, m2, m3 = messages + N = size(llhs, ndims(llhs)) + + for j in 1:n_spkrs + # this spkr's messages + m1j, m2j, m3j = m1[j], m2[j], m3[j] + fsm = ffsm.fsms[j] + T̂, T̂ᵀ = fsm.T̂, fsm.T̂' # TODO maybe not optimal + + # other spkr's messages + k = n_spkrs - j + 1 + m2k, m3k = m2[k], m3[k] + buffer_k = similar(m2k[:, 1]) + llhs_perm = permutedims(llhs, [k, j, 3]) + buffer = similar(llhs_perm[:, :, 1]) + + @views for n in 1:N + broadcast!(*, buffer_k, m2k[:, n], m3k[:, n]) + broadcast!(*, buffer, llhs_perm[:, :, n], buffer_k) + sum!(m1j[:, n], buffer') + end + + buffer = similar(m1j[:, 1]) + m2j[:, 1] = fsm.α̂ + @views for n in 2:N + broadcast!(*, buffer, m1j[:, n - 1], m2j[:, n - 1]) + mul!(m2j[:, n], T̂ᵀ, buffer) + end + + @views fill!(m3j[:, N], one(K)) + @views for n in N-1:-1:1 + broadcast!(*, buffer, m1j[:, n + 1], m3j[:, n + 1]) + mul!(m3j[:, n], T̂, buffer) + end + end + return (m1=m1, m2=m2, m3=m3) end diff --git a/test/runtests.jl b/test/runtests.jl index 935c5ff..908bc6a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,8 @@ using Semirings using SparseArrays using Test +import MarkovModels: lbp_step! + @testset verbose=true "FSMs" begin include("test_fsms.jl") end @@ -22,6 +24,10 @@ else @warn "CUDA is not functional skipping tests." 
end +@testset verbose=true "LBP Inference" begin + include("test_lbp_inference.jl") +end + #@testset verbose=true "algorithms" begin # include("test_algorithms.jl") #end diff --git a/test/test_lbp_inference.jl b/test/test_lbp_inference.jl new file mode 100644 index 0000000..5a4b980 --- /dev/null +++ b/test/test_lbp_inference.jl @@ -0,0 +1,136 @@ +const N = 100 # number of frames +const S1 = 30 # number of states for spkr1 +const S2 = 33 # number of states for spkr2 +const T = Float32 +const SF = LogSemiring{T} + +function naive_lbp_step(messages, ffsm::FactorialFSM{T}, llhs::AbstractArray{T, 3}) where T + N = size(llhs, 3) + n_spkrs = 2 + m1, m2, m3 = deepcopy(messages) + + for j in 1:n_spkrs + + for n in 1:N + buffer = llhs[:, :, n] + + for k in 1:n_spkrs + k == j && continue + if k == 2 + buffer = permutedims(buffer, [2, 1]) + buffer = buffer .* (m2[k][:, n] .* m3[k][:, n]) + buffer = permutedims(buffer, [2, 1]) + elseif k == 1 + buffer = buffer .* (m2[k][:, n] .* m3[k][:, n]) + else + throw(ErrorException("Not available for more then 2 spkrs")) + end + end + m1[j][:, n] = sum(buffer, dims=[k for k in 1:n_spkrs if k != j]) + end + + fsm = ffsm.fsms[j] + m2[j][:, 1] = fsm.α̂ + for n in 2:N + m2[j][:, n] = ((m2[j][:, n - 1] .* m1[j][:, n - 1])' * fsm.T̂)' + end + + fill!(m3[j][:, 1], one(T)) + for n in N-1:-1:1 + m3[j][:, n] = fsm.T̂ * (m3[j][:, n + 1] .* m1[j][:, n + 1]) + end + end + + return (m1=m1, m2=m2, m3=m3) +end + +make_lin_ffsm(SF, T, num_states_per_fsm...) 
= begin + fsms = FSM{SF}[] + smaps = AbstractMatrix{SF}[] + for S in num_states_per_fsm + α = sparse(vcat(one(T), zeros(T, S-2))) + T̂ = sparse(Bidiagonal([T(0.75) for _ in 1:S-1], [T(0.25) for _ in 1:S-2], :U)) + ω = sparse(vcat(zeros(T, S-2), [T(0.25)])) + labels = collect(1:S-1) + push!( + fsms, + FSM( + convert.(SF, log.(α)), + convert.(SF, log.(T̂)), + convert.(SF, log.(ω)), + labels + ) |> renorm + ) + push!(smaps, ones(S,S)) + end + FactorialFSM(fsms, smaps) +end + +make_ffsm(SF, T, num_states_per_fsm...) = begin + fsms = FSM{SF}[] + smaps = AbstractMatrix{SF}[] + for S in num_states_per_fsm + α = sprand(T, S-1, 0.25) + T̂ = sprand(T, S-1, S-1, 0.95) + ω = sprand(T, S-1, 0.75) + labels = collect(1:S-1) + push!( + fsms, + FSM( + convert.(SF, log.(α)), + convert.(SF, log.(T̂)), + convert.(SF, log.(ω)), + labels + ) |> renorm + ) + push!(smaps, ones(S,S)) + end + FactorialFSM(fsms, smaps) +end + +false_print(msg) = begin + println(msg) + println("") + false +end + +@testset "random FactorialFSM" begin + ffsm = make_ffsm(SF, T, S1, S2) + llhs = convert.(SF, log.(rand(T, S1, S2, N))) + + m1 = [Array{SF}(undef, S, N) for S in [S1, S2]] + m2 = [ones(SF, S, N-1) for S in [S1, S2]] + m2 = [hcat(fsm.α̂, m) for (fsm, m) in zip(ffsm.fsms, m2)] + m3 = [ones(SF, S, N) for S in [S1, S2]] + + ref_m1, ref_m2, ref_m3 = naive_lbp_step([m1,m2,m3], ffsm, llhs) + hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), ffsm, llhs) + + for j in 1:2 + @test all(isapprox.(val.(ref_m1[j]), val.(hyp_m1[j]), nans=true)) + @test all(isapprox.(val.(ref_m2[j]), val.(hyp_m2[j]), nans=true)) + @test all(isapprox.(val.(ref_m3[j]), val.(hyp_m3[j]), nans=true)) + #@test all(ref_m2[j] .≈ hyp_m2[j]) || false_print((println.(ref_m2[j] .≈ hyp_m2[j]), println.(ref_m2[j]), println(""), println.( hyp_m2[j]))) + end +end + +@testset "Linear Factorial FSM" begin + # Linear FSM + lffsm = make_lin_ffsm(SF, T, S1, S2) + llhs = convert.(SF, log.(rand(T, S1, S2, N))) + + m1 = [Array{SF}(undef, S, 
N) for S in [S1, S2]] + m2 = [ones(SF, S, N-1) / SF(S) for S in [S1, S2]] + m2 = [hcat(fsm.α̂, m) for (fsm, m) in zip(lffsm.fsms, m2)] + m3 = [ones(SF, S, N) / SF(S) for S in [S1, S2]] + + ref_m1, ref_m2, ref_m3 = naive_lbp_step([m1,m2,m3], lffsm, llhs) + hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), lffsm, llhs) + + for j in 1:2 + @test all(isapprox.(val.(ref_m1[j]), val.(hyp_m1[j]), nans=true)) + @test all(isapprox.(val.(ref_m2[j]), val.(hyp_m2[j]), nans=true)) + @test all(isapprox.(val.(ref_m3[j]), val.(hyp_m3[j]), nans=true)) + #@test all(ref_m2[j] .≈ hyp_m2[j]) || false_print((println.(ref_m2[j] .≈ hyp_m2[j]), println.(ref_m2[j]), println(""), println.( hyp_m2[j]))) + end +end From efc1e05da7cbc3084751333f06cbd2fd914d2705 Mon Sep 17 00:00:00 2001 From: Martin Kocour Date: Wed, 7 Sep 2022 17:50:15 +0200 Subject: [PATCH 5/8] Improve the test suite --- src/lbp_inference.jl | 5 +++-- test/test_lbp_inference.jl | 38 +++++++++++++++++++------------------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/lbp_inference.jl b/src/lbp_inference.jl index e8ece5a..c0025e1 100644 --- a/src/lbp_inference.jl +++ b/src/lbp_inference.jl @@ -57,7 +57,7 @@ function lbp_step!(messages, ffsm::FactorialFSM{K}, llhs::AbstractArray{K, 3}) w # this spkr's messages m1j, m2j, m3j = m1[j], m2[j], m3[j] fsm = ffsm.fsms[j] - T̂, T̂ᵀ = fsm.T̂, fsm.T̂' # TODO maybe not optimal + T̂, T̂ᵀ = fsm.T̂, permutedims(fsm.T̂, [2,1]) # TODO maybe not optimal # other spkr's messages k = n_spkrs - j + 1 @@ -72,8 +72,9 @@ function lbp_step!(messages, ffsm::FactorialFSM{K}, llhs::AbstractArray{K, 3}) w sum!(m1j[:, n], buffer') end - buffer = similar(m1j[:, 1]) m2j[:, 1] = fsm.α̂ + buffer = similar(m1j[:, 1]) # NOT OPTIMAL, m2j should be used instead but we have issue with Julia 1.7 -> 1.8 fixed it, but introduce another bugs + # check https://github.com/JuliaSparse/SparseArrays.jl/issues/251 @views for n in 2:N broadcast!(*, buffer, m1j[:, n - 1], m2j[:, n - 1]) 
mul!(m2j[:, n], T̂ᵀ, buffer) diff --git a/test/test_lbp_inference.jl b/test/test_lbp_inference.jl index 5a4b980..f355fff 100644 --- a/test/test_lbp_inference.jl +++ b/test/test_lbp_inference.jl @@ -11,7 +11,7 @@ function naive_lbp_step(messages, ffsm::FactorialFSM{T}, llhs::AbstractArray{T, for j in 1:n_spkrs - for n in 1:N + @views for n in 1:N buffer = llhs[:, :, n] for k in 1:n_spkrs @@ -31,12 +31,12 @@ function naive_lbp_step(messages, ffsm::FactorialFSM{T}, llhs::AbstractArray{T, fsm = ffsm.fsms[j] m2[j][:, 1] = fsm.α̂ - for n in 2:N + @views for n in 2:N m2[j][:, n] = ((m2[j][:, n - 1] .* m1[j][:, n - 1])' * fsm.T̂)' end - fill!(m3[j][:, 1], one(T)) - for n in N-1:-1:1 + @views fill!(m3[j][:, N], one(T)) + @views for n in N-1:-1:1 m3[j][:, n] = fsm.T̂ * (m3[j][:, n + 1] .* m1[j][:, n + 1]) end end @@ -94,17 +94,18 @@ false_print(msg) = begin false end -@testset "random FactorialFSM" begin - ffsm = make_ffsm(SF, T, S1, S2) +@testset "Linear Factorial FSM" begin + # Linear FSM + lffsm = make_lin_ffsm(SF, T, S1, S2) llhs = convert.(SF, log.(rand(T, S1, S2, N))) m1 = [Array{SF}(undef, S, N) for S in [S1, S2]] - m2 = [ones(SF, S, N-1) for S in [S1, S2]] - m2 = [hcat(fsm.α̂, m) for (fsm, m) in zip(ffsm.fsms, m2)] - m3 = [ones(SF, S, N) for S in [S1, S2]] + m2 = [ones(SF, S, N-1) / SF(S) for S in [S1, S2]] + m2 = [hcat(fsm.α̂, m) for (fsm, m) in zip(lffsm.fsms, m2)] + m3 = [ones(SF, S, N) / SF(S) for S in [S1, S2]] - ref_m1, ref_m2, ref_m3 = naive_lbp_step([m1,m2,m3], ffsm, llhs) - hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), ffsm, llhs) + ref_m1, ref_m2, ref_m3 = naive_lbp_step([m1,m2,m3], lffsm, llhs) + hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), lffsm, llhs) for j in 1:2 @test all(isapprox.(val.(ref_m1[j]), val.(hyp_m1[j]), nans=true)) @@ -114,18 +115,17 @@ end end end -@testset "Linear Factorial FSM" begin - # Linear FSM - lffsm = make_lin_ffsm(SF, T, S1, S2) +@testset "random FactorialFSM" begin + ffsm = 
make_ffsm(SF, T, S1, S2) llhs = convert.(SF, log.(rand(T, S1, S2, N))) m1 = [Array{SF}(undef, S, N) for S in [S1, S2]] - m2 = [ones(SF, S, N-1) / SF(S) for S in [S1, S2]] - m2 = [hcat(fsm.α̂, m) for (fsm, m) in zip(lffsm.fsms, m2)] - m3 = [ones(SF, S, N) / SF(S) for S in [S1, S2]] + m2 = [ones(SF, S, N-1) for S in [S1, S2]] + m2 = [hcat(fsm.α̂, m) for (fsm, m) in zip(ffsm.fsms, m2)] + m3 = [ones(SF, S, N) for S in [S1, S2]] - ref_m1, ref_m2, ref_m3 = naive_lbp_step([m1,m2,m3], lffsm, llhs) - hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), lffsm, llhs) + ref_m1, ref_m2, ref_m3 = naive_lbp_step([m1,m2,m3], ffsm, llhs) + hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), ffsm, llhs) for j in 1:2 @test all(isapprox.(val.(ref_m1[j]), val.(hyp_m1[j]), nans=true)) From ea54c7ca5d896a299fab3b4fbe97210ed46fc0f2 Mon Sep 17 00:00:00 2001 From: Martin Kocour Date: Thu, 8 Sep 2022 19:23:50 +0200 Subject: [PATCH 6/8] pdfposteriors for non-batched FSMs --- src/inference.jl | 24 +++++-- src/lbp_inference.jl | 130 +++++++++++++++++++++++++++++-------- test/runtests.jl | 2 +- test/test_lbp_inference.jl | 19 ++++-- 4 files changed, 135 insertions(+), 40 deletions(-) diff --git a/src/inference.jl b/src/inference.jl index d19924e..9369688 100644 --- a/src/inference.jl +++ b/src/inference.jl @@ -58,17 +58,27 @@ function βrecursion(T̂::AbstractMatrix{K}, lhs::AbstractMatrix{K}) where K B end +""" + pdfposteriors(fsm::FSM{K}, V̂s, Ĉs) where K + +Compute pdf posteriors. + +args: + fsm - batched FSM + V̂s - vector of pdf likelihoods of shape (n_pdfs x N) + Ĉs - vector of state-pdf mappings +""" function pdfposteriors(fsm::FSM{K}, V̂s, Ĉs) where K - V̂ = vcat(V̂s...) + V̂ = vcat(V̂s...) # B*n_pdfs x N V̂k = copyto!(similar(V̂, K), V̂) - Ĉ = blockdiag(Ĉs...) - ĈV̂ = (Ĉ * V̂k) + Ĉ = blockdiag(Ĉs...) 
# B*n_states x B*n_pdfs + ĈV̂ = (Ĉ * V̂k) # B*n_states x N state_A = αrecursion(fsm.α̂, fsm.T̂', ĈV̂) state_B = βrecursion(fsm.T̂, ĈV̂) - state_AB = broadcast!(*, state_A, state_A, state_B) - AB = Ĉ' * state_AB - Ẑ = permutedims(reshape(AB, :, length(V̂s), size(V̂, 2)), (2, 1, 3)) - sums = sum(Ẑ, dims = 2) + state_AB = broadcast!(*, state_A, state_A, state_B) # B*n_states x N + AB = Ĉ' * state_AB # B*n_pdfs x N + Ẑ = permutedims(reshape(AB, :, length(V̂s), size(V̂, 2)), (2, 1, 3)) # B x n_pdfs x N + sums = sum(Ẑ, dims = 2) # B x 1 x N Ẑ = broadcast!(/, Ẑ, Ẑ, sums) ttl = dropdims(minimum(sums, dims = (2, 3)), dims = (2, 3)) (exp ∘ val).(Ẑ[:, 1:end-1, 1:end-1]), val.(ttl) diff --git a/src/lbp_inference.jl b/src/lbp_inference.jl index c0025e1..2072714 100644 --- a/src/lbp_inference.jl +++ b/src/lbp_inference.jl @@ -1,49 +1,124 @@ struct FactorialFSM{K<:Semiring} fsms::Vector{FSM{K}} - smaps::Vector{AbstractMatrix{K}} + smap::Vector{AbstractSparseMatrix{K}} end function FactorialFSM( - fsm1::FSM{K}, smap1::AbstractMatrix{K}, - fsm2::FSM{K}, smap2::AbstractMatrix{K} -) where K + fsm1::FSM{K}, smap1::AbstractSparseMatrix{K}, + fsm2::FSM{K}, smap2::AbstractSparseMatrix{K} +) where K FactorialFSM([fsm1, fsm2], [smap1, smap2]) end +nfsms(ffsm::FactorialFSM{K}) where K = length(ffsm.fsms) +getindex(ffsm::FactorialFSM{K}, key::Integer) where K = ffsm.fsms[key] + +function joint_smap(smap1::AbstractSparseMatrix{K}, smap2::AbstractSparseMatrix{K}) where K + S1, P1 = size(smap1) + S2, P2 = size(smap2) + @assert P1 == P2 + + I, J, V = [], [], [] + for (i1, j1, v2) in zip(findnz(smap1)) + for (i2, j2, v2) in zip(findnz(smap2)) + push!(I, (i2-1) * S1 + i1) + push!(J, (j2-1) * P1 + j1) + push!(V, one(K)) + end + end + sparse(I, J, V, S1*S2, P1*P2) +end + function init_messages(ffsm::FactorialFSM{K}, N::Integer) where K - dims = [nstates(fsm) for fsm in ffsm.fsms] - nspkrs = length(dims) + S = [nstates(fsm) + 1 for fsm in ffsm.fsms] # + virtual state - m1 = [Array{K}(undef, dims[j], 
N) for j in nspkrs] - m2 = [Array{K}(undef, dims[j], N) for j in nspkrs] - m3 = [Array{K}(undef, dims[j], N) for j in nspkrs] + m1 = [Array{K}(undef, s, N) for s in S] + m2 = [Array{K}(undef, s, N) for s in S] + m3 = [Array{K}(undef, s, N) for s in S] - @views for j in 1:nspkrs - fsm = ffsm.fsms[j] + @views for j in 1:length(S) + fsm = ffsm[j] S = nstates(fsm) m1j, m2j, m3j = m1[j], m2[j], m3[j] - fill!(m2j[:, 2:end], one(K)) - fill!(m2j[:, 1], fsm.α) - fill!(m3j, one(K)) + fill!(m2j[:, 2:end], one(K) / K(S)) + m2j[:, 1] = fsm.α̂ + fill!(m3j, one(K) / K(S)) end (m1 = m1, m2 = m2, m3 = m3) end -function lbp_posteriors(ffsm::FactorialFSM{K}, llhs::AbstractArray{K, J}; eps=1e-4) where {K, J} - N = size(llhs, ndims(llhs)) +function compare_msgs(old_mgs, new_msg; eps=1e-4) + n_spkrs = length(old_mgs[:m1]) + changes = [] + for n_msg in 1:3 + om, nm = old_mgs[Symbol("m", n_msg)], new_msg[Symbol("m", n_msg)] + + for j in 1:n_spkrs + diff = isapprox.(val.(om[j]), val.(nm[j]); atol=eps) + diff_N = findall(vec(.!all(diff, dims=1))) + n_diff_N = length(diff_N) + append!(changes, zip(repeat([j], n_diff_N), repeat([n_msg], n_diff_N), diff_N)) + end + end + changes +end + +""" + pdfposteriors(ffsm::FactorialFSM{K}, llhs::AbstractMatrix{K}; eps=1e-4, max_iter=10) where {K<:Semiring} + +Compute pdf posteriors given loglikehood `llhs` and FactorialFSM `ffsm`. +Likehoods `llhs` are represented as 2D matrix of shape +`(S1*S2, N)`, where `S1` is number of states in 1st FSM and `S2` is number +of states in 2nd FSM in `ffsm`. 
+ +args: + ffsm - FactorialFSM od 2 FSMs with states `S1` and `S2` + llhs - loglikehoods with size `(S1*S2, N)` +""" +function pdfposteriors(ffsm::FactorialFSM{K}, llhs::AbstractMatrix{K}; eps=1e-4, max_iter=10) where {K<:Semiring} + n_spkrs = nfsms(ffsm) + states_per_fsm = [nstates(ffsm[i]) + 1 for i in 1:n_spkrs] # + virtual state + N = size(llhs, 2) + total_states = reduce(*, states_per_fsm) + # we assume that llhs is already expanded (see `expand`) + state_llhs = joint_smap(ffsm.smap...) * llhs + @assert size(state_llhs, 1) == total_states + state_llhs = reshape(state_llhs, vcat(states_per_fsm, [N])...) # S1 x S2 x ... x N + messages = init_messages(ffsm, N) + iter = 0 + changes = nothing while true - new_messages = lbp_step!(deepcopy(messages), ffsm, llhs) + iter += 1 + new_messages = lbp_step!(deepcopy(messages), ffsm, state_llhs) # check the difference between messages - diffs = [.≈(new_m, m; atol=eps) for (new_m, m) in zip(new_messages, messages)] - # if all messages are same then break - all(all.(diffs)) && break + changes = compare_msgs(messages, new_messages; eps=eps) messages = new_messages + # if all messages are same then break + if isempty(changes) || iter >= max_iter + break + end end + total_num_msgs = [size.(m, 2) for m in messages] |> sum |> sum + println("Finished in iter: $iter (max: $max_iter)") + println("Number of changed msgs before max_iter was reached: $(length(changes)) ($total_num_msgs)") m1, m2, m3 = messages + result = [] + ttl = zero(K) + for (j, (m1j, m2j, m3j)) in enumerate(zip(m1, m2, m3)) + state_marginals = broadcast!(*, m1j, m1j, m2j, m3j) + pdf_marginals = ffsm.smap[j]' * state_marginals + sums = sum(pdf_marginals, dims=1) # 1 x N + broadcast!(/, pdf_marginals, pdf_marginals, sums) + ttl += minimum(sums) + push!(result, pdf_marginals) + end + + # TODO: return pdf marginals of the same size as llhs + result, ttl end function lbp_step!(messages, ffsm::FactorialFSM{K}, llhs::AbstractArray{K, 3}) where K @@ -56,25 +131,28 @@ 
function lbp_step!(messages, ffsm::FactorialFSM{K}, llhs::AbstractArray{K, 3}) w for j in 1:n_spkrs # this spkr's messages m1j, m2j, m3j = m1[j], m2[j], m3[j] - fsm = ffsm.fsms[j] + fsm = ffsm[j] T̂, T̂ᵀ = fsm.T̂, permutedims(fsm.T̂, [2,1]) # TODO maybe not optimal # other spkr's messages k = n_spkrs - j + 1 m2k, m3k = m2[k], m3[k] buffer_k = similar(m2k[:, 1]) - llhs_perm = permutedims(llhs, [k, j, 3]) - buffer = similar(llhs_perm[:, :, 1]) + llhs_perm = permutedims(llhs, [j, k, 3]) @views for n in 1:N broadcast!(*, buffer_k, m2k[:, n], m3k[:, n]) - broadcast!(*, buffer, llhs_perm[:, :, n], buffer_k) - sum!(m1j[:, n], buffer') + broadcast!(/, buffer_k, buffer_k, sum(buffer_k)) + mul!(m1j[:, n], llhs_perm[:, :, n], buffer_k) + #broadcast!(*, buffer, llhs_perm[:, :, n], buffer_k) + #sum!(m1j[:, n], buffer') end m2j[:, 1] = fsm.α̂ - buffer = similar(m1j[:, 1]) # NOT OPTIMAL, m2j should be used instead but we have issue with Julia 1.7 -> 1.8 fixed it, but introduce another bugs + buffer = similar(m1j[:, 1]) # NOT OPTIMAL, m2j should be used instead + # but we have issue with Julia 1.7 -> 1.8 fixed it, but introduce another bugs # check https://github.com/JuliaSparse/SparseArrays.jl/issues/251 + @views for n in 2:N broadcast!(*, buffer, m1j[:, n - 1], m2j[:, n - 1]) mul!(m2j[:, n], T̂ᵀ, buffer) diff --git a/test/runtests.jl b/test/runtests.jl index 908bc6a..1f5374b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ using Semirings using SparseArrays using Test -import MarkovModels: lbp_step! 
+import MarkovModels: lbp_step!, joint_smap @testset verbose=true "FSMs" begin include("test_fsms.jl") diff --git a/test/test_lbp_inference.jl b/test/test_lbp_inference.jl index f355fff..8b5cf88 100644 --- a/test/test_lbp_inference.jl +++ b/test/test_lbp_inference.jl @@ -4,6 +4,9 @@ const S2 = 33 # number of states for spkr2 const T = Float32 const SF = LogSemiring{T} +#normalize2(x) = x ./ sum(x) +normalize2(x) = x + function naive_lbp_step(messages, ffsm::FactorialFSM{T}, llhs::AbstractArray{T, 3}) where T N = size(llhs, 3) n_spkrs = 2 @@ -18,10 +21,10 @@ function naive_lbp_step(messages, ffsm::FactorialFSM{T}, llhs::AbstractArray{T, k == j && continue if k == 2 buffer = permutedims(buffer, [2, 1]) - buffer = buffer .* (m2[k][:, n] .* m3[k][:, n]) + buffer = buffer .* normalize2(m2[k][:, n] .* m3[k][:, n]) buffer = permutedims(buffer, [2, 1]) elseif k == 1 - buffer = buffer .* (m2[k][:, n] .* m3[k][:, n]) + buffer = buffer .* normalize2(m2[k][:, n] .* m3[k][:, n]) else throw(ErrorException("Not available for more then 2 spkrs")) end @@ -46,7 +49,8 @@ end make_lin_ffsm(SF, T, num_states_per_fsm...) = begin fsms = FSM{SF}[] - smaps = AbstractMatrix{SF}[] + smaps = AbstractSparseMatrix{SF}[] + P = maximum(num_states_per_fsm) for S in num_states_per_fsm α = sparse(vcat(one(T), zeros(T, S-2))) T̂ = sparse(Bidiagonal([T(0.75) for _ in 1:S-1], [T(0.25) for _ in 1:S-2], :U)) @@ -61,14 +65,17 @@ make_lin_ffsm(SF, T, num_states_per_fsm...) = begin labels ) |> renorm ) - push!(smaps, ones(S,S)) + push!( + smaps, + diagm(abs(P-S) => ones(SF, S))[1:S, :] |> sparse + ) end FactorialFSM(fsms, smaps) end make_ffsm(SF, T, num_states_per_fsm...) = begin fsms = FSM{SF}[] - smaps = AbstractMatrix{SF}[] + smaps = AbstractSparseMatrix{SF}[] for S in num_states_per_fsm α = sprand(T, S-1, 0.25) T̂ = sprand(T, S-1, S-1, 0.95) @@ -83,7 +90,7 @@ make_ffsm(SF, T, num_states_per_fsm...) 
= begin labels ) |> renorm ) - push!(smaps, ones(S,S)) + push!(smaps, diagm(ones(SF, S)) |> sparse) end FactorialFSM(fsms, smaps) end From 5f333a032608cba165c2f2facde609d80073bbcb Mon Sep 17 00:00:00 2001 From: Martin Kocour Date: Thu, 8 Sep 2022 21:00:17 +0200 Subject: [PATCH 7/8] CompiledFactorialFSM has blockdiagonal matrices --- src/MarkovModels.jl | 4 +- src/lbp_inference.jl | 95 ++++++++++++++++++++++++++++++-------- test/test_lbp_inference.jl | 14 ++++-- 3 files changed, 88 insertions(+), 25 deletions(-) diff --git a/src/MarkovModels.jl b/src/MarkovModels.jl index b1635f6..b12cc85 100644 --- a/src/MarkovModels.jl +++ b/src/MarkovModels.jl @@ -38,13 +38,15 @@ export # Inference expand, + compile, αrecursion, βrecursion, pdfposteriors, + CompiledFSM, # LBP FactorialFSM, - lbp_posteriors + CompiledFactorialFSM include("utils.jl") include("fsm.jl") diff --git a/src/lbp_inference.jl b/src/lbp_inference.jl index 2072714..e2b0823 100644 --- a/src/lbp_inference.jl +++ b/src/lbp_inference.jl @@ -1,26 +1,82 @@ +struct CompiledFSM{K<:Semiring} + α̂ + T̂ + T̂ᵀ + Ĉ + Ĉᵀ +end +compile(fsm::FSM, Ĉ::AbstractMatrix) = CompiledFSM{eltype(fsm.α̂)}(fsm.α̂, fsm.T̂, copy(fsm.T̂'), Ĉ, copy(Ĉ')) + +function Adapt.adapt_structure(::Type{<:CuArray}, cfsm::CompiledFSM{K}) where K + T̂ = CuSparseMatrixCSC(cfsm.T̂) + T̂ᵀ = CuSparseMatrixCSC(cfsm.T̂ᵀ) + Ĉ = CuSparseMatrixCSC(cfsm.Ĉ) + Ĉᵀ = CuSparseMatrixCSC(cfsm.Ĉᵀ) + CompiledFSM{K}( + CuSparseVector(cfsm.α̂), + CuSparseMatrixCSR(T̂ᵀ.colPtr, T̂ᵀ.rowVal, T̂ᵀ.nzVal, T̂.dims), + CuSparseMatrixCSR(T̂.colPtr, T̂.rowVal, T̂.nzVal, T̂ᵀ.dims), + CuSparseMatrixCSR(Ĉᵀ.colPtr, Ĉᵀ.rowVal, Ĉᵀ.nzVal, Ĉ.dims), + CuSparseMatrixCSR(Ĉ.colPtr, Ĉ.rowVal, Ĉ.nzVal, Ĉᵀ.dims), + ) +end + struct FactorialFSM{K<:Semiring} fsms::Vector{FSM{K}} - smap::Vector{AbstractSparseMatrix{K}} end -function FactorialFSM( - fsm1::FSM{K}, smap1::AbstractSparseMatrix{K}, - fsm2::FSM{K}, smap2::AbstractSparseMatrix{K} -) where K - FactorialFSM([fsm1, fsm2], [smap1, smap2]) 
+struct CompiledFactorialFSM{K<:Semiring, J} + fsms::Vector{CompiledFSM{K}} + Ĉ::AbstractSparseMatrix{K} + Ĉᵀ::AbstractSparseMatrix{K} +end + +function compile(ffsm::FactorialFSM{K}, Ĉs::AbstractMatrix{K}...) where K + J = length(ffsm.fsms) + @assert length(Ĉs) == J "Number of state maps `Ĉs` has to be same as number of FSMs in FactorialFSM" + Ĉ = foldl(joint_smap, sparse.(Ĉs)) + CompiledFactorialFSM{K, J}( + map(zip(ffsm.fsms, Ĉs)) do (fsm, smap) compile(fsm, smap) end, + #[compile(fsm, smap) for (fsm, smap) in zip(ffsm.fsms, Ĉs)], + Ĉ, + copy(Ĉ') + ) +end + +Base.getindex(cffsm::CompiledFactorialFSM{K}, key::Integer) where K = cffsm.fsms[key] +nfsms(cffsm::CompiledFactorialFSM{K, J}) where {K, J} = J + +function Adapt.adapt_structure(T::Type{<:CuArray}, cffsm::CompiledFactorialFSM{K}) where K + fsms = adapt_structure.(T, cffsm.fsms) + Ĉ = CuSparseMatrixCSC(cffsm.Ĉ) + Ĉᵀ = CuSparseMatrixCSC(cffsm.Ĉᵀ) + CompiledFactorialFSM( + fsms, + CuSparseMatrixCSR(Ĉᵀ.colPtr, Ĉᵀ.rowVal, Ĉᵀ.nzVal, Ĉ.dims), + CuSparseMatrixCSR(Ĉ.colPtr, Ĉ.rowVal, Ĉ.nzVal, Ĉᵀ.dims), + ) end -nfsms(ffsm::FactorialFSM{K}) where K = length(ffsm.fsms) -getindex(ffsm::FactorialFSM{K}, key::Integer) where K = ffsm.fsms[key] +function batch(cffsm1::CompiledFactorialFSM{K, J}, cffsms::CompiledFactorialFSM{K, J}...) where {K, J} + fsms = Vector{CompiledFSM{K}}[] + for j in 1:J + push!(fsms, batch(cffsm1.fsms[j], map(cffsm -> cffsm.fsms[j], cffsms)...)) + end + CompiledFactorialFSM{K, J}( + fsms, + blockdiag(cffsm1.Ĉ, map(cffsm -> cffsm.Ĉ, cffsms)...), + blockdiag(cffsm1.Ĉᵀ, map(cffsm -> cffsm.Ĉᵀ, cffsms)...) + ) +end function joint_smap(smap1::AbstractSparseMatrix{K}, smap2::AbstractSparseMatrix{K}) where K S1, P1 = size(smap1) S2, P2 = size(smap2) @assert P1 == P2 - I, J, V = [], [], [] - for (i1, j1, v2) in zip(findnz(smap1)) - for (i2, j2, v2) in zip(findnz(smap2)) + I, J, V = Int[], Int[], K[] + for (i1, j1, v2) in zip(findnz(smap1)...) + for (i2, j2, v2) in zip(findnz(smap2)...) 
push!(I, (i2-1) * S1 + i1) push!(J, (j2-1) * P1 + j1) push!(V, one(K)) @@ -39,7 +95,7 @@ function init_messages(ffsm::FactorialFSM{K}, N::Integer) where K @views for j in 1:length(S) fsm = ffsm[j] S = nstates(fsm) - m1j, m2j, m3j = m1[j], m2[j], m3[j] + m1j, m2j, m3j = m1[j], m2[j], m3[j] fill!(m2j[:, 2:end], one(K) / K(S)) m2j[:, 1] = fsm.α̂ fill!(m3j, one(K) / K(S)) @@ -76,13 +132,13 @@ args: ffsm - FactorialFSM od 2 FSMs with states `S1` and `S2` llhs - loglikehoods with size `(S1*S2, N)` """ -function pdfposteriors(ffsm::FactorialFSM{K}, llhs::AbstractMatrix{K}; eps=1e-4, max_iter=10) where {K<:Semiring} +function pdfposteriors(ffsm::CompiledFactorialFSM{K, J}, llhs::AbstractMatrix{K}; eps=1e-4, max_iter=10) where {K<:Semiring, J} + N = size(llhs, 2) n_spkrs = nfsms(ffsm) states_per_fsm = [nstates(ffsm[i]) + 1 for i in 1:n_spkrs] # + virtual state - N = size(llhs, 2) total_states = reduce(*, states_per_fsm) # we assume that llhs is already expanded (see `expand`) - state_llhs = joint_smap(ffsm.smap...) * llhs + state_llhs = ffsm.Ĉ * llhs @assert size(state_llhs, 1) == total_states state_llhs = reshape(state_llhs, vcat(states_per_fsm, [N])...) # S1 x S2 x ... 
x N @@ -110,7 +166,7 @@ function pdfposteriors(ffsm::FactorialFSM{K}, llhs::AbstractMatrix{K}; eps=1e-4, ttl = zero(K) for (j, (m1j, m2j, m3j)) in enumerate(zip(m1, m2, m3)) state_marginals = broadcast!(*, m1j, m1j, m2j, m3j) - pdf_marginals = ffsm.smap[j]' * state_marginals + pdf_marginals = ffsm[j].Ĉᵀ * state_marginals sums = sum(pdf_marginals, dims=1) # 1 x N broadcast!(/, pdf_marginals, pdf_marginals, sums) ttl += minimum(sums) @@ -121,7 +177,7 @@ function pdfposteriors(ffsm::FactorialFSM{K}, llhs::AbstractMatrix{K}; eps=1e-4, result, ttl end -function lbp_step!(messages, ffsm::FactorialFSM{K}, llhs::AbstractArray{K, 3}) where K +function lbp_step!(messages, ffsm::CompiledFactorialFSM{K, J}, llhs::AbstractArray{K, D}) where {K, J, D} n_spkrs = ndims(llhs) - 1 @assert n_spkrs == 2 "Currently we do not support more than 2 speakers!" @@ -132,7 +188,7 @@ function lbp_step!(messages, ffsm::FactorialFSM{K}, llhs::AbstractArray{K, 3}) w # this spkr's messages m1j, m2j, m3j = m1[j], m2[j], m3[j] fsm = ffsm[j] - T̂, T̂ᵀ = fsm.T̂, permutedims(fsm.T̂, [2,1]) # TODO maybe not optimal + T̂, T̂ᵀ = fsm.T̂, fsm.T̂ᵀ # other spkr's messages k = n_spkrs - j + 1 @@ -142,7 +198,8 @@ function lbp_step!(messages, ffsm::FactorialFSM{K}, llhs::AbstractArray{K, 3}) w @views for n in 1:N broadcast!(*, buffer_k, m2k[:, n], m3k[:, n]) - broadcast!(/, buffer_k, buffer_k, sum(buffer_k)) + #broadcast!(/, buffer_k, buffer_k, sum(buffer_k)) # Do NOT normalize + # (w/o norm it will converge faster) mul!(m1j[:, n], llhs_perm[:, :, n], buffer_k) #broadcast!(*, buffer, llhs_perm[:, :, n], buffer_k) #sum!(m1j[:, n], buffer') diff --git a/test/test_lbp_inference.jl b/test/test_lbp_inference.jl index 8b5cf88..fec48aa 100644 --- a/test/test_lbp_inference.jl +++ b/test/test_lbp_inference.jl @@ -7,7 +7,7 @@ const SF = LogSemiring{T} #normalize2(x) = x ./ sum(x) normalize2(x) = x -function naive_lbp_step(messages, ffsm::FactorialFSM{T}, llhs::AbstractArray{T, 3}) where T +function naive_lbp_step(messages, 
ffsm::CompiledFactorialFSM{T}, llhs::AbstractArray{T, 3}) where T N = size(llhs, 3) n_spkrs = 2 m1, m2, m3 = deepcopy(messages) @@ -32,7 +32,7 @@ function naive_lbp_step(messages, ffsm::FactorialFSM{T}, llhs::AbstractArray{T, m1[j][:, n] = sum(buffer, dims=[k for k in 1:n_spkrs if k != j]) end - fsm = ffsm.fsms[j] + fsm = ffsm[j] m2[j][:, 1] = fsm.α̂ @views for n in 2:N m2[j][:, n] = ((m2[j][:, n - 1] .* m1[j][:, n - 1])' * fsm.T̂)' @@ -70,12 +70,13 @@ make_lin_ffsm(SF, T, num_states_per_fsm...) = begin diagm(abs(P-S) => ones(SF, S))[1:S, :] |> sparse ) end - FactorialFSM(fsms, smaps) + compile(FactorialFSM(fsms), smaps...) end make_ffsm(SF, T, num_states_per_fsm...) = begin fsms = FSM{SF}[] smaps = AbstractSparseMatrix{SF}[] + P = maximum(num_states_per_fsm) for S in num_states_per_fsm α = sprand(T, S-1, 0.25) T̂ = sprand(T, S-1, S-1, 0.95) @@ -90,9 +91,12 @@ make_ffsm(SF, T, num_states_per_fsm...) = begin labels ) |> renorm ) - push!(smaps, diagm(ones(SF, S)) |> sparse) + push!( + smaps, + diagm(abs(P-S) => ones(SF, S))[1:S, :] |> sparse + ) end - FactorialFSM(fsms, smaps) + compile(FactorialFSM(fsms), smaps...) 
end false_print(msg) = begin From 3910afe60823952c52e166425e395bd7e07f20ee Mon Sep 17 00:00:00 2001 From: Martin Kocour Date: Thu, 8 Sep 2022 22:07:55 +0200 Subject: [PATCH 8/8] LBP infererence for FactorialFSM --- src/lbp_inference.jl | 30 +++++++++++++++++------------- test/test_lbp_inference.jl | 6 +++--- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/lbp_inference.jl b/src/lbp_inference.jl index e2b0823..7b492ff 100644 --- a/src/lbp_inference.jl +++ b/src/lbp_inference.jl @@ -6,6 +6,7 @@ struct CompiledFSM{K<:Semiring} Ĉᵀ end compile(fsm::FSM, Ĉ::AbstractMatrix) = CompiledFSM{eltype(fsm.α̂)}(fsm.α̂, fsm.T̂, copy(fsm.T̂'), Ĉ, copy(Ĉ')) +nstates(fsm::CompiledFSM) = length(fsm.α̂) - 1 function Adapt.adapt_structure(::Type{<:CuArray}, cfsm::CompiledFSM{K}) where K T̂ = CuSparseMatrixCSC(cfsm.T̂) @@ -85,7 +86,7 @@ function joint_smap(smap1::AbstractSparseMatrix{K}, smap2::AbstractSparseMatrix{ sparse(I, J, V, S1*S2, P1*P2) end -function init_messages(ffsm::FactorialFSM{K}, N::Integer) where K +function init_messages(ffsm::CompiledFactorialFSM{K, J}, N::Integer) where {K,J} S = [nstates(fsm) + 1 for fsm in ffsm.fsms] # + virtual state m1 = [Array{K}(undef, s, N) for s in S] @@ -104,7 +105,7 @@ function init_messages(ffsm::FactorialFSM{K}, N::Integer) where K (m1 = m1, m2 = m2, m3 = m3) end -function compare_msgs(old_mgs, new_msg; eps=1e-4) +function compare_msgs(old_mgs, new_msg; eps=1.2) n_spkrs = length(old_mgs[:m1]) changes = [] for n_msg in 1:3 @@ -130,16 +131,18 @@ of states in 2nd FSM in `ffsm`. 
args: ffsm - FactorialFSM od 2 FSMs with states `S1` and `S2` - llhs - loglikehoods with size `(S1*S2, N)` + V̂s - B-element vector of expanded loglikelihoods with size `(S1*S2, N)` (see `expand`) """ -function pdfposteriors(ffsm::CompiledFactorialFSM{K, J}, llhs::AbstractMatrix{K}; eps=1e-4, max_iter=10) where {K<:Semiring, J} +function pdfposteriors(ffsm::CompiledFactorialFSM{K, J}, V̂s::Vector{<:AbstractMatrix{K}}; eps=1.2, min_iter=3, max_iter=10) where {K<:Semiring, J} + B = length(V̂s) + llhs = vcat(V̂s...) N = size(llhs, 2) n_spkrs = nfsms(ffsm) states_per_fsm = [nstates(ffsm[i]) + 1 for i in 1:n_spkrs] # + virtual state total_states = reduce(*, states_per_fsm) # we assume that llhs is already expanded (see `expand`) state_llhs = ffsm.Ĉ * llhs - @assert size(state_llhs, 1) == total_states + @assert size(state_llhs, 1) == total_states "size(state_llhs): $(size(state_llhs, 1)) != $total_states" state_llhs = reshape(state_llhs, vcat(states_per_fsm, [N])...) # S1 x S2 x ... x N messages = init_messages(ffsm, N) @@ -153,7 +156,7 @@ function pdfposteriors(ffsm::CompiledFactorialFSM{K, J}, llhs::AbstractMatrix{K} changes = compare_msgs(messages, new_messages; eps=eps) messages = new_messages # if all messages are same then break - if isempty(changes) || iter >= max_iter + if (isempty(changes) && iter >= min_iter) || iter >= max_iter break end end @@ -163,18 +166,19 @@ function pdfposteriors(ffsm::CompiledFactorialFSM{K, J}, llhs::AbstractMatrix{K} m1, m2, m3 = messages result = [] - ttl = zero(K) + ttl = zeros(K, B) for (j, (m1j, m2j, m3j)) in enumerate(zip(m1, m2, m3)) state_marginals = broadcast!(*, m1j, m1j, m2j, m3j) - pdf_marginals = ffsm[j].Ĉᵀ * state_marginals - sums = sum(pdf_marginals, dims=1) # 1 x N + pdf_marginals = ffsm[j].Ĉᵀ * state_marginals # B*P x N + pdf_marginals = permutedims(reshape(pdf_marginals, :, B, N), (2, 1, 3)) # B x P x N + sums = sum(pdf_marginals, dims=2) # B x 1 x N broadcast!(/, pdf_marginals, pdf_marginals, sums) - ttl += 
minimum(sums) - push!(result, pdf_marginals) + broadcast!(+, ttl, ttl, dropdims(minimum(sums, dims=(2,3)), dims=(2,3))) + push!(result, (exp ∘ val).(pdf_marginals[:, 1:end-1, 1:end-1])) end # TODO: return pdf marginals of the same size as llhs - result, ttl + result, val.(ttl) end function lbp_step!(messages, ffsm::CompiledFactorialFSM{K, J}, llhs::AbstractArray{K, D}) where {K, J, D} @@ -198,7 +202,7 @@ function lbp_step!(messages, ffsm::CompiledFactorialFSM{K, J}, llhs::AbstractArr @views for n in 1:N broadcast!(*, buffer_k, m2k[:, n], m3k[:, n]) - #broadcast!(/, buffer_k, buffer_k, sum(buffer_k)) # Do NOT normalize + broadcast!(/, buffer_k, buffer_k, sum(buffer_k)) # Normalize # (w/o norm it will converge faster) mul!(m1j[:, n], llhs_perm[:, :, n], buffer_k) #broadcast!(*, buffer, llhs_perm[:, :, n], buffer_k) diff --git a/test/test_lbp_inference.jl b/test/test_lbp_inference.jl index fec48aa..7815a8b 100644 --- a/test/test_lbp_inference.jl +++ b/test/test_lbp_inference.jl @@ -139,9 +139,9 @@ end hyp_m1, hyp_m2, hyp_m3 = lbp_step!(deepcopy((m1=m1, m2=m2, m3=m3)), ffsm, llhs) for j in 1:2 - @test all(isapprox.(val.(ref_m1[j]), val.(hyp_m1[j]), nans=true)) - @test all(isapprox.(val.(ref_m2[j]), val.(hyp_m2[j]), nans=true)) - @test all(isapprox.(val.(ref_m3[j]), val.(hyp_m3[j]), nans=true)) + @test all(isapprox.(val.(ref_m1[j]), val.(hyp_m1[j]), nans=false)) + @test all(isapprox.(val.(ref_m2[j]), val.(hyp_m2[j]), nans=false)) + @test all(isapprox.(val.(ref_m3[j]), val.(hyp_m3[j]), nans=false)) #@test all(ref_m2[j] .≈ hyp_m2[j]) || false_print((println.(ref_m2[j] .≈ hyp_m2[j]), println.(ref_m2[j]), println(""), println.( hyp_m2[j]))) end end