@@ -6,14 +6,12 @@ module test_specialmatrices
66 use stdlib_kinds
77 use stdlib_linalg, only: hermitian
88 use stdlib_linalg_state, only: linalg_state_type
9- use stdlib_math, only: all_close
9+ use stdlib_math, only: all_close, is_close
1010 use stdlib_specialmatrices
1111 use stdlib_strings, only: to_string
1212 implicit none
13-
1413contains
1514
16-
1715 !> Collect all exported unit tests
1816 subroutine collect_suite(testsuite)
1917 !> Collection of tests
@@ -25,14 +23,16 @@ contains
2523 new_unittest('sym_tridiagonal', test_sym_tridiagonal), &
2624 new_unittest('sym_tridiagonal error handling', test_sym_tridiagonal_error_handling), &
2725 new_unittest('symmetric tridiagonal 1x1 dense', test_sym_tridiagonal_1x1), &
28- new_unittest('symmetric tridiagonal arithmetic', test_sym_tridiagonal_arithmetic) &
26+ new_unittest('symmetric tridiagonal arithmetic', test_sym_tridiagonal_arithmetic), &
27+ new_unittest('tridiagonal 1x1 edge case', test_tridiagonal_1x1), &
28+ new_unittest('tridiagonal arithmetic', test_tridiagonal_arithmetic) &
2929 ]
3030 end subroutine
3131
3232 subroutine test_tridiagonal(error)
3333 !> Error handling
3434 type(error_type), allocatable, intent(out) :: error
35- #:for k1, t1, s1 in ( KINDS_TYPES)
35+ #:for k1, t1, s1 in KINDS_TYPES
3636 block
3737 integer, parameter :: wp = ${k1}$
3838 integer, parameter :: n = 5
@@ -113,7 +113,7 @@ contains
113113 "spmv(fail): y = alpha*A*x + beta*y, alpha: "//to_string(alpha)//", beta: "//to_string(beta))
114114 if (allocated(error)) return
115115
116- ! Test y = alpha * A.T @ x beta * y for random values of alpha and beta
116+ ! Test y = alpha * A.T @ x + beta * y for random values of alpha and beta
117117 y1 = 0.0_wp
118118 call random_number(alpha)
119119 call random_number(beta)
@@ -144,7 +144,7 @@ contains
144144 subroutine test_sym_tridiagonal(error)
145145 !> Error handling
146146 type(error_type), allocatable, intent(out) :: error
147- #:for k1, t1, s1 in ( KINDS_TYPES)
147+ #:for k1, t1, s1 in KINDS_TYPES
148148 block
149149 integer, parameter :: wp = ${k1}$
150150 integer, parameter :: n = 5
@@ -255,13 +255,14 @@ contains
255255 subroutine test_tridiagonal_error_handling(error)
256256 !> Error handling
257257 type(error_type), allocatable, intent(out) :: error
258- #:for k1, t1, s1 in ( KINDS_TYPES)
258+ #:for k1, t1, s1 in KINDS_TYPES
259259 block
260260 integer, parameter :: wp = ${k1}$
261261 integer, parameter :: n = 5
262262 type(tridiagonal_${s1}$_type) :: A
263- ${t1}$, allocatable :: dl(:), dv(:), du(:)
264263 type(linalg_state_type) :: state
264+
265+ ${t1}$, allocatable :: dl(:), du(:), dv(:)
265266 integer :: i
266267
267268 !> Test constructor from arrays.
@@ -283,7 +284,7 @@ contains
283284 subroutine test_sym_tridiagonal_error_handling(error)
284285 !> Error handling
285286 type(error_type), allocatable, intent(out) :: error
286- #:for k1, t1, s1 in ( KINDS_TYPES)
287+ #:for k1, t1, s1 in KINDS_TYPES
287288 block
288289 integer, parameter :: wp = ${k1}$
289290 integer, parameter :: n = 5
@@ -312,7 +313,7 @@ contains
312313 subroutine test_sym_tridiagonal_1x1(error)
313314 !> Error handling
314315 type(error_type), allocatable, intent(out) :: error
315- #:for k1, t1, s1 in ( KINDS_TYPES)
316+ #:for k1, t1, s1 in KINDS_TYPES
316317 block
317318 integer, parameter :: wp = ${k1}$
318319 type(sym_tridiagonal_${s1}$_type) :: A
@@ -337,7 +338,7 @@ contains
337338 subroutine test_sym_tridiagonal_arithmetic(error)
338339 !> Error handling
339340 type(error_type), allocatable, intent(out) :: error
340- #:for k1, t1, s1 in ( KINDS_TYPES)
341+ #:for k1, t1, s1 in KINDS_TYPES
341342 block
342343 integer, parameter :: wp = ${k1}$
343344 type(sym_tridiagonal_${s1}$_type) :: A, B, C
@@ -368,8 +369,89 @@ contains
368369 end block
369370 #:endfor
370371 end subroutine
371- end module
372372
373+ subroutine test_tridiagonal_1x1(error)
374+ !> Test 1x1 matrix edge case for dense conversion
375+ type(error_type), allocatable, intent(out) :: error
376+ #:for k1, t1, s1 in KINDS_TYPES
377+ block
378+ integer, parameter :: wp = ${k1}$
379+ type(tridiagonal_${s1}$_type) :: A
380+ ${t1}$, allocatable :: Amat(:,:)
381+
382+ #:if t1.startswith('complex')
383+ ${t1}$, parameter :: dl(0) = [${t1}$ ::]
384+ ${t1}$, parameter :: du(0) = [${t1}$ ::]
385+ ${t1}$, parameter :: dv(1) = [cmplx(5.0_wp, 0.0_wp, kind=wp)]
386+ #:else
387+ ${t1}$, parameter :: dl(0) = [${t1}$ ::]
388+ ${t1}$, parameter :: du(0) = [${t1}$ ::]
389+ ${t1}$, parameter :: dv(1) = [5.0_wp]
390+ #:endif
391+
392+ A = tridiagonal(dl, dv, du)
393+ Amat = dense(A)
394+
395+ ! Check if the 1x1 matrix converted properly at runtime without segfaulting
396+ call check(error, size(Amat, 1) == 1, .true.)
397+ if (allocated(error)) return
398+ call check(error, size(Amat, 2) == 1, .true.)
399+ if (allocated(error)) return
400+ call check(error, is_close(Amat(1,1), 5.0_wp), .true.)
401+ if (allocated(error)) return
402+ end block
403+ #:endfor
404+ end subroutine
405+
406+ subroutine test_tridiagonal_arithmetic(error)
407+ !> Test arithmetic operations and optimization
408+ type(error_type), allocatable, intent(out) :: error
409+ #:for k1, t1, s1 in KINDS_TYPES
410+ block
411+ integer, parameter :: wp = ${k1}$
412+ integer, parameter :: n = 3
413+ type(tridiagonal_${s1}$_type) :: A, B, C
414+
415+ #:if t1.startswith('complex')
416+ ${t1}$, parameter :: dl1(n-1) = [cmplx(1.0_wp, 0.0_wp, kind=wp), cmplx(1.0_wp, 0.0_wp, kind=wp)]
417+ ${t1}$, parameter :: dv1(n) = [cmplx(2.0_wp, 0.0_wp, kind=wp), cmplx(2.0_wp, 0.0_wp, kind=wp), cmplx(2.0_wp, 0.0_wp, kind=wp)]
418+ ${t1}$, parameter :: du1(n-1) = [cmplx(3.0_wp, 0.0_wp, kind=wp), cmplx(3.0_wp, 0.0_wp, kind=wp)]
419+
420+ ${t1}$, parameter :: dl2(n-1) = [cmplx(4.0_wp, 0.0_wp, kind=wp), cmplx(4.0_wp, 0.0_wp, kind=wp)]
421+ ${t1}$, parameter :: dv2(n) = [cmplx(5.0_wp, 0.0_wp, kind=wp), cmplx(5.0_wp, 0.0_wp, kind=wp), cmplx(5.0_wp, 0.0_wp, kind=wp)]
422+ ${t1}$, parameter :: du2(n-1) = [cmplx(6.0_wp, 0.0_wp, kind=wp), cmplx(6.0_wp, 0.0_wp, kind=wp)]
423+ #:else
424+ ${t1}$, parameter :: dl1(n-1) = [1.0_wp, 1.0_wp]
425+ ${t1}$, parameter :: dv1(n) = [2.0_wp, 2.0_wp, 2.0_wp]
426+ ${t1}$, parameter :: du1(n-1) = [3.0_wp, 3.0_wp]
427+
428+ ${t1}$, parameter :: dl2(n-1) = [4.0_wp, 4.0_wp]
429+ ${t1}$, parameter :: dv2(n) = [5.0_wp, 5.0_wp, 5.0_wp]
430+ ${t1}$, parameter :: du2(n-1) = [6.0_wp, 6.0_wp]
431+ #:endif
432+
433+ A = tridiagonal(dl1, dv1, du1)
434+ B = tridiagonal(dl2, dv2, du2)
435+
436+ ! Addition test - use dense() to bypass private component restrictions
437+ C = A + B
438+ call check(error, all_close(dense(C), dense(A) + dense(B)), .true.)
439+ if (allocated(error)) return
440+
441+ ! Subtraction test
442+ C = A - B
443+ call check(error, all_close(dense(C), dense(A) - dense(B)), .true.)
444+ if (allocated(error)) return
445+
446+ ! Scalar multiplication test
447+ C = 3.0_wp * A
448+ call check(error, all_close(dense(C), 3.0_wp * dense(A)), .true.)
449+ if (allocated(error)) return
450+ end block
451+ #:endfor
452+ end subroutine
453+
454+ end module
373455
374456program tester
375457 use, intrinsic :: iso_fortran_env, only : error_unit
@@ -383,7 +465,7 @@ program tester
383465 stat = 0
384466
385467 testsuites = [ &
386- new_testsuite("specialmatrices ", collect_suite) &
468+ new_testsuite("special_matrices ", collect_suite) &
387469 ]
388470
389471 do is = 1, size(testsuites)
@@ -395,4 +477,4 @@ program tester
395477 write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
396478 error stop
397479 end if
398- end program
480+ end program
0 commit comments