Commit eb39e15

EM for MRTVCModel

1 parent e603914

5 files changed: 95 additions & 43 deletions

docs/src/advanced.md (8 additions, 3 deletions)
@@ -19,7 +19,7 @@ For the EM algorithm, the updates in each iteration are
 \boldsymbol{\Gamma}_i^{(t + 1)} &= \frac{1}{r_i} \boldsymbol{\Gamma}_i^{(t)} ( \boldsymbol{R}^{(t)T} \boldsymbol{V}_i \boldsymbol{R}^{(t)} - \boldsymbol{M}_i^{(t)} ) \boldsymbol{\Gamma}_i^{(t)} + \boldsymbol{\Gamma}_i^{(t)},
 \end{aligned}
 ```
-where ``r_i = \text{rank}(\boldsymbol{V}_i)``. As seen, the updates for mean effects ``\boldsymbol{B}`` are the same for these two algorithms.
+where ``r_i = \text{rank}(\boldsymbol{V}_i)``. As seen, the updates for mean effects ``\boldsymbol{B}`` are the same for MM and EM algorithms.
 
 # Inference
 Standard errors for our estimates are calculated using the Fisher information matrix, where
@@ -42,7 +42,7 @@ In the setting of missing response, the adjusted MM updates in each interation a
 \boldsymbol{\Gamma}_i^{(t + 1)} &= \boldsymbol{L}_i^{-(t)T}[\boldsymbol{L}_i^{(t)T}\boldsymbol{\Gamma}_i^{(t)}(\boldsymbol{R}^{*(t)T}\boldsymbol{V}_i\boldsymbol{R}^{*(t)} + \boldsymbol{M}_i^{*(t)})\boldsymbol{\Gamma}_i^{(t)}\boldsymbol{L}_i^{(t)}]^{1/2} \boldsymbol{L}_i^{-(t)},
 \end{aligned}
 ```
-where ``\boldsymbol{Z}^{(t)}`` is the completed response matrix from conditional mean and ``\boldsymbol{M}_i^{*(t)} = (\boldsymbol{I}_d \otimes \boldsymbol{1}_n)^T [(\boldsymbol{1}_d \boldsymbol{1}_d^T \otimes \boldsymbol{V}_i) \odot (\boldsymbol{\Omega}^{-(t)} \boldsymbol{P}^T \boldsymbol{C}^{(t)}\boldsymbol{P}\boldsymbol{\Omega}^{-(t)})] (\boldsymbol{I}_d \otimes \boldsymbol{1}_n)``, while ``\boldsymbol{R}^{*(t)}`` is the ``n \times d`` matrix such that ``\text{vec}\ \boldsymbol{R}^{*(t)} = \boldsymbol{\Omega}^{-(t)} \text{vec}(\boldsymbol{Z}^{(t)} - \boldsymbol{X} \boldsymbol{B}^{(t)})``. Additionally, ``\boldsymbol{P}`` is the ``nd \times nd`` permutation matrix such that ``\boldsymbol{P} \cdot \text{vec}\ \boldsymbol{Y} = \begin{bmatrix} \boldsymbol{y}_{\text{obs}} \\ \boldsymbol{y}_{\text{mis}} \end{bmatrix}``, where ``\boldsymbol{y}_{\text{obs}}`` and ``\boldsymbol{y}_{\text{mis}}`` are vectors of observed and missing response values, respectively, in column-major order, and the block matrix ``\boldsymbol{C}^{(t)}`` is ``\boldsymbol{0}`` except for a lower-right block consisting of conditional variance. As seen, the two MM updates are of similar form.
+where ``\boldsymbol{Z}^{(t)}`` is the completed response matrix from conditional mean and ``\boldsymbol{M}_i^{*(t)} = (\boldsymbol{I}_d \otimes \boldsymbol{1}_n)^T [(\boldsymbol{1}_d \boldsymbol{1}_d^T \otimes \boldsymbol{V}_i) \odot (\boldsymbol{\Omega}^{-(t)} \boldsymbol{P}^T \boldsymbol{C}^{(t)}\boldsymbol{P}\boldsymbol{\Omega}^{-(t)})] (\boldsymbol{I}_d \otimes \boldsymbol{1}_n)``, while ``\boldsymbol{R}^{*(t)}`` is the ``n \times d`` matrix such that ``\text{vec}\ \boldsymbol{R}^{*(t)} = \boldsymbol{\Omega}^{-(t)} \text{vec}(\boldsymbol{Z}^{(t)} - \boldsymbol{X} \boldsymbol{B}^{(t)})``. Additionally, ``\boldsymbol{P}`` is the ``nd \times nd`` permutation matrix such that ``\boldsymbol{P} \cdot \text{vec}\ \boldsymbol{Y} = \begin{bmatrix} \boldsymbol{y}_{\text{obs}} \\ \boldsymbol{y}_{\text{mis}} \end{bmatrix}``, where ``\boldsymbol{y}_{\text{obs}}`` and ``\boldsymbol{y}_{\text{mis}}`` are vectors of observed and missing response values, respectively, in column-major order, and the block matrix ``\boldsymbol{C}^{(t)}`` is ``\boldsymbol{0}`` except for a lower-right block consisting of conditional variance. As seen, the MM updates are of similar form to the non-missing response case.
 
 # Special case: ``m = 2``
 When there are ``m = 2`` variance components such that ``\boldsymbol{\Omega} = \boldsymbol{\Gamma}_1 \otimes \boldsymbol{V}_1 + \boldsymbol{\Gamma}_2 \otimes \boldsymbol{V}_2``, repeated inversion of the ``nd \times nd`` matrix ``\boldsymbol{\Omega}`` per iteration can be avoided and reduced to one ``d \times d`` generalized eigen-decomposition per iteration. Without loss of generality, if we assume ``\boldsymbol{V}_2`` to be positive definite, the generalized eigen-decomposition of the matrix pair ``(\boldsymbol{V}_1, \boldsymbol{V}_2)`` yields generalized eigenvalues ``\boldsymbol{d} = (d_1, \dots, d_n)^T`` and generalized eigenvectors ``\boldsymbol{U}`` such that ``\boldsymbol{U}^T \boldsymbol{V}_1 \boldsymbol{U} = \boldsymbol{D} = \text{diag}(\boldsymbol{d})`` and ``\boldsymbol{U}^T \boldsymbol{V}_2 \boldsymbol{U} = \boldsymbol{I}_n``. Similarly, if we let the generalized eigen-decomposition of ``(\boldsymbol{\Gamma}_1^{(t)}, \boldsymbol{\Gamma}_2^{(t)})`` be ``(\boldsymbol{\Lambda}^{(t)}, \boldsymbol{\Phi}^{(t)})`` such that ``\boldsymbol{\Phi}^{(t)T} \boldsymbol{\Gamma}_1^{(t)} \boldsymbol{\Phi}^{(t)} = \boldsymbol{\Lambda}^{(t)} = \text{diag}(\boldsymbol{\lambda^{(t)}})`` and ``\boldsymbol{\Phi}^{(t)T} \boldsymbol{\Gamma}_2^{(t)} \boldsymbol{\Phi}^{(t)} = \boldsymbol{I}_d``, then the MM updates in each iteration become
@@ -57,7 +57,12 @@ When there are ``m = 2`` variance components such that ``\boldsymbol{\Omega} = \
 
 where ``\tilde{\boldsymbol{X}} = \boldsymbol{U}^T \boldsymbol{X}``, ``\tilde{\boldsymbol{Y}} = \boldsymbol{U}^T \boldsymbol{Y}``, ``\boldsymbol{L}_1^{(t)}`` is the Cholesky factor of ``\boldsymbol{\Phi}^{(t)}\text{diag}(\text{tr}(\boldsymbol{D}(\lambda_k^{(t)}\boldsymbol{D} + \boldsymbol{I}_n)^{-1}), k = 1,\dots, d)\boldsymbol{\Phi}^{(t)T}``, ``\boldsymbol{L}_2^{(t)}`` is the Cholesky factor of ``\boldsymbol{\Phi}^{(t)}\text{diag}(\text{tr}((\lambda_k^{(t)}\boldsymbol{D} + \boldsymbol{I}_n)^{-1}), k = 1,\dots, d)\boldsymbol{\Phi}^{(t)T}``, ``\boldsymbol{N}_1^{(t)} = \boldsymbol{D}^{1/2}\{[(\tilde{\boldsymbol{Y}} - \tilde{\boldsymbol{X}}\boldsymbol{B})\boldsymbol{\Phi}^{(t)}]\oslash(\boldsymbol{d}\boldsymbol{\lambda}^{(t)T} + \boldsymbol{1}_n\boldsymbol{1}_d^T) \} \boldsymbol{\Lambda}^{(t)}\boldsymbol{\Phi}^{-(t)}``, and ``\boldsymbol{N}_2^{(t)} = \{[(\tilde{\boldsymbol{Y}} - \tilde{\boldsymbol{X}}\boldsymbol{B})\boldsymbol{\Phi}^{(t)}]\oslash(\boldsymbol{d}\boldsymbol{\lambda}^{(t)T} + \boldsymbol{1}_n\boldsymbol{1}_d^T) \} \boldsymbol{\Phi}^{-(t)}``. ``\oslash`` denotes the Hadamard quotient.
 
-In this setting, the Fisher information matrix is equivalent to
+For the sake of completeness, we note that the EM updates become
+```math
+\boldsymbol{\Gamma}_i^{(t + 1)} = \frac{1}{r_i} ( \boldsymbol{N}_i^{(t)T} \boldsymbol{N}_i^{(t)} - \boldsymbol{\Gamma}_i^{(t)} \boldsymbol{L}_i^{(t)}\boldsymbol{L}_i^{(t)T} \boldsymbol{\Gamma}_i^{(t)} ) + \boldsymbol{\Gamma}_i^{(t)}.
+```
+
+Finally, in this setting, the Fisher information matrix is equivalent to
 ```math
 \begin{aligned}
 \text{E} \left[- \frac{\partial^2}{\partial(\text{vec}\ \boldsymbol{B})^T \partial(\text{vec}\ \boldsymbol{B})} \mathcal{L} \right] &= (\boldsymbol{\Phi}^{T}\otimes \tilde{\boldsymbol{X}})^T (\boldsymbol{\Lambda} \otimes \boldsymbol{D} + \boldsymbol{I}_d \otimes \boldsymbol{I}_n)^{-1} (\boldsymbol{\Phi}^{T}\otimes \tilde{\boldsymbol{X}}) \\
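The whole ``m = 2`` shortcut rests on the simultaneous diagonalization quoted in the docs, ``\boldsymbol{U}^T \boldsymbol{V}_1 \boldsymbol{U} = \boldsymbol{D}`` and ``\boldsymbol{U}^T \boldsymbol{V}_2 \boldsymbol{U} = \boldsymbol{I}_n``. A minimal standalone sanity check of those identities with toy random matrices (independent of the package; `eigen` on a symmetric pair solves the symmetric-definite generalized eigenproblem):

```julia
using LinearAlgebra, Random

Random.seed!(1)
n = 4
A = randn(n, n)
B = randn(n, n)
V1 = Symmetric(A * A')      # positive semidefinite
V2 = Symmetric(B * B' + I)  # positive definite, as the docs assume for V2
# generalized eigen-decomposition of the pair (V1, V2)
F = eigen(V1, V2)
U, d = F.vectors, F.values
# U simultaneously diagonalizes the pair: U' * V1 * U = D, U' * V2 * U = I
@assert norm(U' * V1 * U - Diagonal(d)) < 1e-6
@assert norm(U' * V2 * U - I) < 1e-6
```

This is why a single ``d \times d`` decomposition per iteration suffices: in the transformed coordinates, ``\boldsymbol{\Omega}`` becomes diagonal and inverting it is elementwise.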

src/MultiResponseVarianceComponentModels.jl (3 additions, 1 deletion)
@@ -318,6 +318,7 @@ struct MRTVCModel{T <: BlasReal} <: VCModel
     Φ :: Matrix{T}
     Λ :: Vector{T}
     logdetΣ2 :: Vector{T}
+    V_rank :: Vector{Int}
     # working arrays
     xtx :: Matrix{T} # Gram matrix X'X
     xty :: Matrix{T} # X'Y
@@ -450,6 +451,7 @@ function MRTVCModel(
     Φ = Matrix{T}(undef, d, d)
     Λ = Vector{T}(undef, d)
     logdetΣ2 = zeros(T, 1)
+    V_rank = [rank(V[k]) for k in 1:m]
     # working arrays
     xtx = transpose(Xmat) * Xmat
     xty = transpose(Xmat) * Y
@@ -471,7 +473,7 @@ function MRTVCModel(
     logl = zeros(T, 1)
     MRTVCModel{T}(
         Y, Ỹ, Xmat, X̃, V, U, D, logdetV2,
-        B, Σ, Φ, Λ, logdetΣ2,
+        B, Σ, Φ, Λ, logdetΣ2, V_rank,
         xtx, xty, ỸΦ, R̃, R̃Φ, N1tN1, N2tN2,
         storage_d_1, storage_d_2, storage_d_d_1, storage_d_d_2,
         storage_p_p, storage_pd, storage_pd_pd,
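The new `V_rank` field caches ``r_i = \text{rank}(\boldsymbol{V}_i)`` once at construction, since the EM update divides by ``r_i`` on every iteration. A sketch of the same comprehension on hypothetical toy kernel matrices (names and data here are illustrative, not the package's):

```julia
using LinearAlgebra, Random

Random.seed!(2)
n = 6
Z = randn(n, 2)                          # a rank-2 factor, so Z * Z' has rank 2
V = [Z * Z', Matrix{Float64}(I, n, n)]   # low-rank kernel plus full-rank identity
# same comprehension as the constructor: cache rank(V[k]) once
V_rank = [rank(V[k]) for k in 1:length(V)]
@assert V_rank == [2, 6]
```

Computing the ranks once up front avoids an SVD-based `rank` call inside the fitting loop.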

src/eigen.jl (58 additions, 38 deletions)
@@ -11,9 +11,9 @@ function fit!(
     # dimensions
     n, d, p, m = size(Y, 1), size(Y, 2), size(X, 2), length(V)
     if model.reml
-        @info "Running $algo algorithm for REML estimation"
+        @info "Running $algo algorithm with generalized eigen-decomposition for REML estimation"
     else
-        @info "Running $algo algorithm for ML estimation"
+        @info "Running $algo algorithm with generalized eigen-decomposition for ML estimation"
     end
     # record iterate history if requested
     history = ConvergenceHistory(partial = !log)
@@ -69,7 +69,7 @@ function fit!(
             update_B!(model)
             update_res!(model)
         end
-        update_Σ!(model)
+        update_Σ!(model; algo = algo)
         update_Φ!(model)
         mul!(model.R̃Φ, model.R̃, model.Φ)
         logl_prev = logl
@@ -107,6 +107,7 @@
 
 function update_Σ!(
     model :: MRTVCModel{T};
+    algo :: Symbol = :MM
     ) where T <: BlasReal
     n, d = size(model.Ỹ, 1), size(model.Ỹ, 2)
     fill!(model.storage_d_1, zero(T))
@@ -132,49 +133,68 @@ function update_Σ!(
     end
     mul!(model.N1tN1, transpose(model.R̃Φ), model.R̃Φ)
     Φinv = inv(model.Φ)
-    # update Σ[1]
-    lmul!(Diagonal(model.storage_d_1), model.N1tN1)
-    rmul!(model.N1tN1, Diagonal(model.storage_d_1))
-    vals, vecs = eigen!(Symmetric(model.N1tN1))
-    @inbounds for j in 1:length(vals)
-        if vals[j] > 0
-            v = sqrt(sqrt(vals[j]))
-            for i in 1:size(vecs, 1)
-                vecs[i, j] *= v
-            end
-        else
-            for i in 1:size(vecs, 1)
-                vecs[i, j] = 0
+    if algo == :MM
+        # update Σ[1]
+        lmul!(Diagonal(model.storage_d_1), model.N1tN1)
+        rmul!(model.N1tN1, Diagonal(model.storage_d_1))
+        vals, vecs = eigen!(Symmetric(model.N1tN1))
+        @inbounds for j in 1:length(vals)
+            if vals[j] > 0
+                v = sqrt(sqrt(vals[j]))
+                for i in 1:size(vecs, 1)
+                    vecs[i, j] *= v
+                end
+            else
+                for i in 1:size(vecs, 1)
+                    vecs[i, j] = 0
+                end
             end
         end
-    end
-    lmul!(Diagonal(one(T) ./ model.storage_d_1), vecs)
-    mul!(model.storage_d_d_1, transpose(Φinv), vecs)
-    mul!(model.Σ[1], model.storage_d_d_1, transpose(model.storage_d_d_1))
-    # update Σ[2]
-    lmul!(Diagonal(model.storage_d_2), model.N2tN2)
-    rmul!(model.N2tN2, Diagonal(model.storage_d_2))
-    vals, vecs = eigen!(Symmetric(model.N2tN2))
-    @inbounds for j in 1:length(vals)
-        if vals[j] > 0
-            v = sqrt(sqrt(vals[j]))
-            for i in 1:size(vecs, 1)
-                vecs[i, j] *= v
-            end
-        else
-            for i in 1:size(vecs, 1)
-                vecs[i, j] = 0
+        lmul!(Diagonal(one(T) ./ model.storage_d_1), vecs)
+        mul!(model.storage_d_d_1, transpose(Φinv), vecs)
+        mul!(model.Σ[1], model.storage_d_d_1, transpose(model.storage_d_d_1))
+        # update Σ[2]
+        lmul!(Diagonal(model.storage_d_2), model.N2tN2)
+        rmul!(model.N2tN2, Diagonal(model.storage_d_2))
+        vals, vecs = eigen!(Symmetric(model.N2tN2))
+        @inbounds for j in 1:length(vals)
+            if vals[j] > 0
+                v = sqrt(sqrt(vals[j]))
+                for i in 1:size(vecs, 1)
+                    vecs[i, j] *= v
+                end
+            else
+                for i in 1:size(vecs, 1)
+                    vecs[i, j] = 0
+                end
             end
         end
+        lmul!(Diagonal(one(T) ./ model.storage_d_2), vecs)
+        mul!(model.storage_d_d_1, transpose(Φinv), vecs)
+        mul!(model.Σ[2], model.storage_d_d_1, transpose(model.storage_d_d_1))
+        model.Σ
+    elseif algo == :EM
+        # update Σ[1]
+        @inbounds for j in 1:d
+            λj = model.Λ[j]
+            model.N1tN1[j, j] = model.N1tN1[j, j] - abs2(λj) * abs2(model.storage_d_1[j]) + model.V_rank[1] * λj
+        end
+        model.N1tN1 .= model.N1tN1 ./ model.V_rank[1]
+        mul!(model.storage_d_d_1, model.N1tN1, Φinv)
+        mul!(model.Σ[1], transpose(Φinv), model.storage_d_d_1)
+        # update Σ[2]
+        @inbounds for j in 1:d
+            λj = model.Λ[j]
+            model.N2tN2[j, j] = model.N2tN2[j, j] - abs2(model.storage_d_2[j]) + model.V_rank[2]
+        end
+        model.N2tN2 .= model.N2tN2 ./ model.V_rank[2]
+        mul!(model.storage_d_d_1, model.N2tN2, Φinv)
+        mul!(model.Σ[2], transpose(Φinv), model.storage_d_d_1)
     end
-    lmul!(Diagonal(one(T) ./ model.storage_d_2), vecs)
-    mul!(model.storage_d_d_1, transpose(Φinv), vecs)
-    mul!(model.Σ[2], model.storage_d_d_1, transpose(model.storage_d_d_1))
-    model.Σ
 end
 
 function update_Φ!(
-    model :: MRTVCModel{T};
+    model :: MRTVCModel{T}
    ) where T <: BlasReal
    copy!(model.storage_d_d_1, model.Σ[1])
    copy!(model.storage_d_d_2, model.Σ[2])
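The new `:EM` branch in `update_Σ!` is the diagonalized, in-place form of the documented update ``\boldsymbol{\Gamma}_i^{(t+1)} = \frac{1}{r_i}(\boldsymbol{N}_i^{T}\boldsymbol{N}_i - \boldsymbol{\Gamma}_i \boldsymbol{L}_i\boldsymbol{L}_i^{T}\boldsymbol{\Gamma}_i) + \boldsymbol{\Gamma}_i``. For intuition, a naive dense version of that update (a hypothetical helper, not exported by the package):

```julia
using LinearAlgebra, Random

# Naive, allocation-heavy EM covariance update
# Γ⁽ᵗ⁺¹⁾ = (N'N - Γ L L' Γ) / r + Γ; the package's update_Σ!
# exploits the generalized eigen-decomposition instead.
function em_update(Γ::AbstractMatrix, N::AbstractMatrix,
                   L::AbstractMatrix, r::Integer)
    (N' * N - Γ * (L * L') * Γ) ./ r + Γ
end

Random.seed!(3)
Γ = [2.0 0.5; 0.5 1.0]          # current variance component, symmetric
N = randn(8, 2)                  # stand-in for Nᵢ⁽ᵗ⁾
L = [1.0 0.0; 0.3 1.2]           # stand-in lower-triangular Cholesky factor Lᵢ⁽ᵗ⁾
Γnew = em_update(Γ, N, L, 8)
# N'N and Γ(LL')Γ are both symmetric, so the update preserves symmetry
@assert norm(Γnew - Γnew') < 1e-8
@assert size(Γnew) == (2, 2)
```

Note that, unlike the MM update, the EM update needs no eigen-decomposition of its own per component, only the diagonal correction seen in the branch above, which is why the new code path is cheaper per iteration.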

src/fit.jl (1 addition, 1 deletion)
@@ -11,7 +11,7 @@ reltol::Real relative tolerance for convergence; default 1e-6
 verbose::Bool display algorithmic information; default true
 init::Symbol initialization strategy; :default initializes by least squares, while
     :user uses user-supplied values at model.B and model.Σ
-algo::Symbol optimization algorithm; :MM (default) or :EM (for MRVCModel)
+algo::Symbol optimization algorithm; :MM (default) or :EM
 log::Bool record iterate history or not; default false
 ```
 
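The docstring's `:default` strategy initializes the mean effects by least squares. A self-contained sketch of that step with toy dimensions (data here are illustrative, not from the package):

```julia
using LinearAlgebra, Random

Random.seed!(4)
n, p, d = 50, 3, 2
X = randn(n, p)   # covariates
Y = randn(n, d)   # multivariate response
# least-squares initialization of the mean effects: B⁰ = (X'X)⁻¹ X'Y
B0 = (X' * X) \ (X' * Y)
@assert size(B0) == (p, d)
# normal equations hold: the residual is orthogonal to the columns of X
@assert norm(X' * (Y - X * B0)) < 1e-8
```

With `init = :user`, the solver instead starts from whatever the caller has stored in `model.B` and `model.Σ`.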

test/eigen_test.jl (25 additions, 0 deletions)
@@ -88,4 +88,29 @@ model = MRVCModel(Y, X, V, reml = true)
     # @test norm(model2.Σcov - model.Σcov) ≈ 8.216977255740531e-16
 end
 
+model2 = MRTVCModel(Y, X, V)
+model = MRVCModel(Y, X, V)
+
+@testset "fit! two component by MLE with EM" begin
+    MRVCModels.fit!(model2, algo = :EM, maxiter = 500)
+    MRVCModels.fit!(model, algo = :EM, maxiter = 500)
+    println("||B̂_MRTVCModel - B̂_MRVCModel|| = $(norm(model2.B - model.B))")
+    for k in 1:m
+        println("||Σ̂[$k]_MRTVCModel - Σ̂[$k]_MRVCModel|| = $(norm(model2.Σ[k] - model.Σ[k]))")
+    end
+    println("||logl_MRTVCModel - logl_MRVCModel|| = $(abs2(model2.logl[1] - model.logl[1]))")
+    println("||Bcov_MRTVCModel - Bcov_MRVCModel|| = $(norm(model2.Bcov - model.Bcov))")
+    println("||Σcov_MRTVCModel - Σcov_MRVCModel|| = $(norm(model2.Σcov - model.Σcov))")
+    println("||B_true - B̂|| = $(norm(B_true - model.B))")
+    for k in 1:m
+        println("||Σ_true[$k] - Σ̂[$k]|| = $(norm(Σ_true[k] - model.Σ[k]))")
+    end
+    # @test norm(model2.B - model.B) ≈ 7.349610254339635e-6
+    # @test norm(model2.Σ[1] - model.Σ[1]) ≈ 1.1257198865106154e-5
+    # @test norm(model2.Σ[2] - model.Σ[2]) ≈ 1.012150518153817e-6
+    # @test abs2(model2.logl[1] - model.logl[1]) ≈ 1.19209171506449e-8
+    # @test norm(model2.Bcov - model.Bcov) ≈ 2.9426471195867426e-6
+    # @test norm(model2.Σcov - model.Σcov) ≈ 5.009462139695597e-7
+end
+
 end
