Skip to content

Commit d6b8572

Browse files
committed
add: handle 1x1 case, test for 1x1 and test for arithmetic operators
1 parent 32f31dc commit d6b8572

2 files changed

Lines changed: 75 additions & 10 deletions

File tree

src/specialmatrices/stdlib_specialmatrices_sym_tridiagonal.fypp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -211,15 +211,19 @@ submodule (stdlib_specialmatrices) sym_tridiagonal_matrices
211211
#:else
212212
allocate(B(n,n), source=zero_${k1}$)
213213
#:endif
214-
B(1,1) = A%dv(1)
215-
B(1,2) = A%du(1)
216-
do concurrent (i = 2: n - 1)
217-
B(i, i - 1) = A%du(i - 1)
218-
B(i, i) = A%dv(i)
219-
B(i, i + 1) = A%du(i)
220-
enddo
221-
B(n , n -1) = A%du(n - 1)
222-
B(n, n) = A%dv(n)
214+
if(n == 1) then
215+
B(1,1) = A%dv(1)
216+
else
217+
B(1,1) = A%dv(1)
218+
B(1,2) = A%du(1)
219+
do concurrent (i = 2: n - 1)
220+
B(i, i - 1) = A%du(i - 1)
221+
B(i, i) = A%dv(i)
222+
B(i, i + 1) = A%du(i)
223+
enddo
224+
B(n , n -1) = A%du(n - 1)
225+
B(n, n) = A%dv(n)
226+
end if
223227
end associate
224228
end function
225229
#:endfor

test/linalg/test_linalg_specialmatrices.fypp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ contains
2323
new_unittest('tridiagonal', test_tridiagonal), &
2424
new_unittest('tridiagonal error handling', test_tridiagonal_error_handling), &
2525
new_unittest('sym_tridiagonal', test_sym_tridiagonal), &
26-
new_unittest('sym_tridiagonal error handling', test_sym_tridiagonal_error_handling) &
26+
new_unittest('sym_tridiagonal error handling', test_sym_tridiagonal_error_handling), &
27+
new_unittest('symmetric tridiagonal 1x1 dense', test_sym_tridiagonal_1x1), &
28+
new_unittest('symmetric tridiagonal arithmetic', test_sym_tridiagonal_arithmetic) &
2729
]
2830
end subroutine
2931

@@ -307,6 +309,65 @@ contains
307309
#:endfor
308310
end subroutine
309311

312+
subroutine test_sym_tridiagonal_1x1(error)
313+
!> Error handling
314+
type(error_type), allocatable, intent(out) :: error
315+
#:for k1, t1, s1 in (KINDS_TYPES)
316+
block
317+
integer, parameter :: wp = ${k1}$
318+
type(sym_tridiagonal_${s1}$_type) :: A
319+
${t1}$, allocatable :: B(:, :)
320+
${t1}$, allocatable :: dv(:), du(:)
321+
${t1}$ :: C(1,1)
322+
323+
allocate(dv(1), du(0))
324+
dv = [5.0_wp]
325+
326+
A = sym_tridiagonal(du, dv)
327+
B = dense(A)
328+
C(1,1) = 5.0_wp
329+
330+
call check(error, all_close(B, C), .true., &
331+
"Symmetric tridiagonal dense function failed (n=1)")
332+
if (allocated(error)) return
333+
end block
334+
#:endfor
335+
end subroutine
336+
337+
subroutine test_sym_tridiagonal_arithmetic(error)
338+
!> Error handling
339+
type(error_type), allocatable, intent(out) :: error
340+
#:for k1, t1, s1 in (KINDS_TYPES)
341+
block
342+
integer, parameter :: wp = ${k1}$
343+
type(sym_tridiagonal_${s1}$_type) :: A, B, C
344+
${t1}$, allocatable :: dv(:), du(:)
345+
346+
dv = [1.0_wp, 5.0_wp, 9.0_wp, 13.0_wp]
347+
du = [2.0_wp, 6.0_wp, 10.0_wp]
348+
A = sym_tridiagonal(du, dv)
349+
350+
dv = [3.0_wp, 7.0_wp, 11.0_wp, 14.0_wp]
351+
du = [4.0_wp, 8.0_wp, 12.0_wp]
352+
B = sym_tridiagonal(du, dv)
353+
354+
C = A + B
355+
call check(error, all_close(dense(C), dense(A) + dense(B)), .true., &
356+
"Symmetric tridiagonal operator + failed")
357+
if (allocated(error)) return
358+
359+
C = A - B
360+
call check(error, all_close(dense(C), dense(A) - dense(B)), .true., &
361+
"Symmetric tridiagonal operator - failed")
362+
if (allocated(error)) return
363+
364+
C = 5.0_wp * A
365+
call check(error, all_close(dense(C), 5.0_wp * dense(A)), .true., &
366+
"Symmetric tridiagonal operator * failed")
367+
if (allocated(error)) return
368+
end block
369+
#:endfor
370+
end subroutine
310371
end module
311372

312373

0 commit comments

Comments
 (0)