Skip to content

Commit cdd76ab

Browse files
committed
more test
1 parent 42eca54 commit cdd76ab

8 files changed

Lines changed: 110 additions & 24 deletions

File tree

docs/src/advanced.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ In this setting, the Fisher information matrix is equivalent to
6868
where ``\boldsymbol{W}_{ij}`` is the ``d \times d`` matrix that has entries
6969
```math
7070
\begin{aligned}
71-
(\boldsymbol{W}_{11})_{kl} &= \text{tr}(\boldsymbol{D}^2(\boldsymbol{\lambda}_k \boldsymbol{D} + \boldsymbol{I}_n)^{-1}(\boldsymbol{\lambda}_l \boldsymbol{D} + \boldsymbol{I}_n)^{-1}) \\
72-
(\boldsymbol{W}_{12})_{kl} &= \text{tr}(\boldsymbol{D}(\boldsymbol{\lambda}_k \boldsymbol{D} + \boldsymbol{I}_n)^{-1}(\boldsymbol{\lambda}_l \boldsymbol{D} + \boldsymbol{I}_n)^{-1}) \\
73-
(\boldsymbol{W}_{22})_{kl} &= \text{tr}((\boldsymbol{\lambda}_k \boldsymbol{D} + \boldsymbol{I}_n)^{-1}(\boldsymbol{\lambda}_l \boldsymbol{D} + \boldsymbol{I}_n)^{-1}).
71+
(\boldsymbol{W}_{11})_{kl} &= \text{tr}(\boldsymbol{D}^2(\lambda_k \boldsymbol{D} + \boldsymbol{I}_n)^{-1}(\lambda_l \boldsymbol{D} + \boldsymbol{I}_n)^{-1}) \\
72+
(\boldsymbol{W}_{12})_{kl} &= \text{tr}(\boldsymbol{D}(\lambda_k \boldsymbol{D} + \boldsymbol{I}_n)^{-1}(\lambda_l \boldsymbol{D} + \boldsymbol{I}_n)^{-1}) \\
73+
(\boldsymbol{W}_{22})_{kl} &= \text{tr}((\lambda_k \boldsymbol{D} + \boldsymbol{I}_n)^{-1}(\lambda_l \boldsymbol{D} + \boldsymbol{I}_n)^{-1}).
7474
\end{aligned}
7575
```

src/MultiResponseVarianceComponentModels.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ export VCModel,
2929
lrt,
3030
h2,
3131
rg,
32-
# multivariate_calculus.jl
32+
# mvcalculus.jl
3333
kron_axpy!,
3434
kron_reduction!,
3535
vech,
@@ -528,7 +528,7 @@ function Base.show(io::IO, model::VCModel)
528528
printstyled(io, "$m"; color = :yellow)
529529
end
530530

531-
include("multivariate_calculus.jl")
531+
include("mvcalculus.jl")
532532
include("reml.jl")
533533
include("fit.jl")
534534
include("eigen.jl")

src/eigen.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ function fit!(
8484
break
8585
end
8686
if abs(logl - logl_prev) < reltol * (abs(logl_prev) + 1)
87-
@info "Updates converged!"
8887
copyto!(model.logl, logl)
8988
IterativeSolvers.setconv(history, true)
9089
if model.se

src/fit.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ function fit!(
123123
break
124124
end
125125
if abs(logl - logl_prev) < reltol * (abs(logl_prev) + 1)
126-
@info "Updates converged!"
127126
copyto!(model.logl, logl)
128127
IterativeSolvers.setconv(history, true)
129128
if model.se

src/reml.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,17 @@ function project_null(
44
V :: Vector{<:AbstractMatrix{T}}
55
) where {T <: Real}
66
n, p, m = size(X, 1), size(X, 2), length(V)
7-
if isempty(X)
8-
Y, V, Matrix{T}(I, n, n)
9-
else
10-
# basis of N(Xᵗ)
11-
Xᵗ = Matrix{T}(undef, size(X, 2), size(X, 1))
12-
transpose!(Xᵗ, X)
13-
A = nullspace(Xᵗ)
14-
s = size(A, 2)
15-
= transpose(A) * Y
16-
= Vector{Matrix{T}}(undef, m)
17-
storage = zeros(n, s)
18-
for i in 1:m
19-
mul!(storage, V[i], A)
20-
Ṽ[i] = BLAS.gemm('T', 'N', A, storage)
21-
end
22-
Ỹ, Ṽ, A
7+
# basis of N(Xᵗ)
8+
Xᵗ = Matrix{T}(undef, p, n)
9+
transpose!(Xᵗ, X)
10+
A = nullspace(Xᵗ)
11+
s = size(A, 2)
12+
= transpose(A) * Y
13+
= Vector{Matrix{T}}(undef, m)
14+
storage = zeros(n, s)
15+
for i in 1:m
16+
mul!(storage, V[i], A)
17+
Ṽ[i] = BLAS.gemm('T', 'N', A, storage)
2318
end
19+
Ỹ, Ṽ, A
2420
end

test/eigen_test.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
module EigenTest
2+
3+
using MultiResponseVarianceComponentModels
4+
using BenchmarkTools, LinearAlgebra, Profile, Random, StatsBase, Test
5+
6+
const MRVCModels = MultiResponseVarianceComponentModels
7+
rng = MersenneTwister(456)
8+
9+
n, p, d, m = 855, 3, 4, 2
10+
X = [ones(n) randn(rng, n, p - 1)] # design matrix including intercept
11+
# V[1] has entries i * (n - j + 1) for j ≥ i, then scaled to be a correlation matrix
12+
# V[2] is identity
13+
V = Vector{Matrix{Float64}}(undef, m)
14+
V = Vector{Matrix{Float64}}(undef, m)
15+
V[1] = [j i ? i * (n - j + 1) : j * (n - i + 1) for i in 1:n, j in 1:n]
16+
StatsBase.cov2cor!(V[1], [sqrt(V[1][i, i]) for i in 1:n])
17+
V[2] = Matrix(UniformScaling(1.0), n, n)
18+
# true parameter values
19+
B_true = 2 * rand(p, d) # uniform on [0, 2]
20+
Σ_true = [
21+
Matrix(UniformScaling(0.2), d, d),
22+
Matrix(UniformScaling(0.6), d, d)
23+
]
24+
Ω_true = zeros(n * d, n * d)
25+
for k in 1:m
26+
Ω_true .+= kron(Σ_true[k], V[k])
27+
end
28+
y = vec(X * B_true) + cholesky(Symmetric(Ω_true)).L * randn(rng, n * d)
29+
Y = reshape(y, n, d)
30+
31+
@testset "constructor two component" begin
32+
model2 = MRTVCModel(Y, X[:, 1], V)
33+
model2 = MRTVCModel(Y[:, 1], X, V)
34+
model2 = MRTVCModel(Y[:, 1], X[:, 1], V)
35+
model2 = MRTVCModel(Y, V)
36+
model2 = MRTVCModel(Y[:, 1], V)
37+
model2 = MRTVCModel(Y, X, V, se = false)
38+
model2 = MRTVCModel(Y, X, V)
39+
end
40+
41+
model2 = MRTVCModel(Y, X, V)
42+
model = MRVCModel(Y, X, V)
43+
44+
@testset "fit! two component by MLE with MM" begin
45+
MRVCModels.fit!(model2, algo = :MM, verbose = false, maxiter = 100)
46+
MRVCModels.fit!(model, algo = :MM, verbose = false, maxiter = 100)
47+
println("||B̂_MRTVCModel - B̂_MRVCModel|| = $(norm(model2.B - model.B))")
48+
for k in 1:m
49+
println("||Σ̂[$k]_MRTVCModel - Σ̂[$k]_MRVCModel|| = $(norm(model2.Σ[k] - model.Σ[k]))")
50+
end
51+
println("||logl_MRTVCModel - logl_MRVCModel|| = $(abs2(model2.logl[1] - model.logl[1]))")
52+
println("||Bcov_MRTVCModel - Bcov_MRVCModel|| = $(norm(model2.Bcov - model.Bcov))")
53+
println("||Σcov_MRTVCModel - Σcov_MRVCModel|| = $(norm(model2.Σcov - model.Σcov))")
54+
println("||B_true - B̂|| = $(norm(B_true - model.B))")
55+
for k in 1:m
56+
println("||Σ_true[$k] - Σ̂[$k]|| = $(norm(Σ_true[k] - model.Σ[k]))")
57+
end
58+
# @test norm(model2.B - model.B) < 2e-4 # 6.986221757875547e-5
59+
# @test norm(model2.Σ[1] - model.Σ[1]) < 2e-4 # 0.0001772074443345503
60+
# @test norm(model2.Σ[2] - model.Σ[2]) < 2e-4 # 1.6274369068536495e-5
61+
# @test abs2(model2.logl[1] - model.logl[1]) < 2e-4 # 4.7903903529347e-8
62+
# @test norm(model2.Bcov - model.Bcov) < 2e-4 # 4.423117207077623e-5
63+
# @test norm(model2.Σcov - model.Σcov) < 2e-4 # 9.284043626500212e-6
64+
end
65+
66+
model2 = MRTVCModel(Y, X, V, reml = true)
67+
model = MRVCModel(Y, X, V, reml = true)
68+
69+
@testset "fit! two component by REML with MM" begin
70+
MRVCModels.fit!(model2, algo = :MM, verbose = false, maxiter = 100)
71+
MRVCModels.fit!(model, algo = :MM, verbose = false, maxiter = 100)
72+
println("||B̂_MRTVCModel - B̂_MRVCModel|| = $(norm(model2.B_reml - model.B_reml))")
73+
for k in 1:m
74+
println("||Σ̂[$k]_MRTVCModel - Σ̂[$k]_MRVCModel|| = $(norm(model2.Σ[k] - model.Σ[k]))")
75+
end
76+
println("||logl_MRTVCModel - logl_MRVCModel|| = $(abs2(model2.logl[1] - model.logl[1]))")
77+
println("||Bcov_MRTVCModel - Bcov_MRVCModel|| = $(norm(model2.Bcov_reml - model.Bcov_reml))")
78+
println("||Σcov_MRTVCModel - Σcov_MRVCModel|| = $(norm(model2.Σcov - model.Σcov))")
79+
println("||B_true - B̂|| = $(norm(B_true - model.B_reml))")
80+
for k in 1:m
81+
println("||Σ_true[$k] - Σ̂[$k]|| = $(norm(Σ_true[k] - model.Σ[k]))")
82+
end
83+
# @test norm(model2.B_reml - model.B_reml) < 9e-14 # 8.886934507200367e-14
84+
# @test norm(model2.Σ[1] - model.Σ[1]) < 9e-14 # 1.9800578221407427e-14
85+
# @test norm(model2.Σ[2] - model.Σ[2]) < 9e-14 # 2.9834222365790734e-15
86+
# @test abs2(model2.logl[1] - model.logl[1]) < 9e-14 # 1.987301421658649e-22
87+
# @test norm(model2.Bcov_reml - model.Bcov_reml) < 9e-14 # 5.04991691887946e-15
88+
# @test norm(model2.Σcov - model.Σcov) < 9e-14 # 1.0311308785032728e-15
89+
end
90+
91+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
include("fit_test.jl")
1+
include("eigen_test.jl")
2+
# include("fit_test.jl")
23
# include("mvcalculus_test.jl")

0 commit comments

Comments
 (0)