Skip to content

Commit a21a7d6

Browse files
authored
Merge pull request #1162 from srinjoy933/fix-tridiagonal-edge-cases
Fix out-of-bounds access for 1x1 tridiagonal matrices and add shape checks
2 parents a0471de + d54c9cf commit a21a7d6

2 files changed

Lines changed: 131 additions & 28 deletions

File tree

src/specialmatrices/stdlib_specialmatrices_tridiagonal.fypp

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
2424
!! tridiagonal matrix elements.
2525
type(tridiagonal_${s1}$_type) :: A
2626
!! Corresponding tridiagonal matrix.
27-
2827
call build_tridiagonal(dl, dv, du, A)
2928
end function
3029

@@ -36,7 +35,6 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
3635
!! Matrix dimension.
3736
type(tridiagonal_${s1}$_type) :: A
3837
!! Corresponding tridiagonal matrix.
39-
4038
call build_tridiagonal(dl, dv, du, n, A)
4139
end function
4240

@@ -63,7 +61,6 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
6361
!! Error handling
6462
type(tridiagonal_${s1}$_type) :: A
6563
!! Corresponding tridiagonal matrix.
66-
6764
call build_tridiagonal(dl, dv, du, n, A, err)
6865
end function
6966
#:endfor
@@ -220,16 +217,22 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
220217
#:else
221218
allocate(B(n, n), source=zero_${k1}$)
222219
#:endif
223-
B(1, 1) = A%dv(1)
224-
B(1, 2) = A%du(1)
225-
do concurrent (i=2:n-1)
226-
B(i, i-1) = A%dl(i-1)
227-
B(i, i) = A%dv(i)
228-
B(i, i+1) = A%du(i)
229-
enddo
230-
B(n, n-1) = A%dl(n-1)
231-
B(n, n) = A%dv(n)
220+
221+
if (n == 1) then
222+
B(1, 1) = A%dv(1)
223+
else
224+
B(1, 1) = A%dv(1)
225+
B(1, 2) = A%du(1)
226+
do concurrent (i=2:n-1)
227+
B(i, i-1) = A%dl(i-1)
228+
B(i, i) = A%dv(i)
229+
B(i, i+1) = A%du(i)
230+
enddo
231+
B(n, n-1) = A%dl(n-1)
232+
B(n, n) = A%dv(n)
233+
end if
232234
end associate
235+
233236
end function
234237
#:endfor
235238

@@ -282,6 +285,15 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
282285
type(tridiagonal_${s1}$_type), intent(in) :: A
283286
type(tridiagonal_${s1}$_type), intent(in) :: B
284287
type(tridiagonal_${s1}$_type) :: C
288+
289+
! Internal variables.
290+
type(linalg_state_type) :: err0
291+
292+
if (A%n /= B%n) then
293+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "tridiagonal matrices must have the same dimension to be added")
294+
call linalg_error_handling(err0)
295+
end if
296+
285297
C = tridiagonal(A%dl, A%dv, A%du)
286298
C%dl = C%dl + B%dl
287299
C%dv = C%dv + B%dv
@@ -292,11 +304,20 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
292304
type(tridiagonal_${s1}$_type), intent(in) :: A
293305
type(tridiagonal_${s1}$_type), intent(in) :: B
294306
type(tridiagonal_${s1}$_type) :: C
307+
308+
! Internal variables.
309+
type(linalg_state_type) :: err0
310+
311+
if (A%n /= B%n) then
312+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "tridiagonal matrices must have the same dimension to be subtracted")
313+
call linalg_error_handling(err0)
314+
end if
315+
295316
C = tridiagonal(A%dl, A%dv, A%du)
296317
C%dl = C%dl - B%dl
297318
C%dv = C%dv - B%dv
298319
C%du = C%du - B%du
299320
end function
300321
#:endfor
301322

302-
end submodule
323+
end submodule

test/linalg/test_linalg_specialmatrices.fypp

Lines changed: 97 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,12 @@ module test_specialmatrices
66
use stdlib_kinds
77
use stdlib_linalg, only: hermitian
88
use stdlib_linalg_state, only: linalg_state_type
9-
use stdlib_math, only: all_close
9+
use stdlib_math, only: all_close, is_close
1010
use stdlib_specialmatrices
1111
use stdlib_strings, only: to_string
1212
implicit none
13-
1413
contains
1514

16-
1715
!> Collect all exported unit tests
1816
subroutine collect_suite(testsuite)
1917
!> Collection of tests
@@ -25,14 +23,16 @@ contains
2523
new_unittest('sym_tridiagonal', test_sym_tridiagonal), &
2624
new_unittest('sym_tridiagonal error handling', test_sym_tridiagonal_error_handling), &
2725
new_unittest('symmetric tridiagonal 1x1 dense', test_sym_tridiagonal_1x1), &
28-
new_unittest('symmetric tridiagonal arithmetic', test_sym_tridiagonal_arithmetic) &
26+
new_unittest('symmetric tridiagonal arithmetic', test_sym_tridiagonal_arithmetic), &
27+
new_unittest('tridiagonal 1x1 edge case', test_tridiagonal_1x1), &
28+
new_unittest('tridiagonal arithmetic', test_tridiagonal_arithmetic) &
2929
]
3030
end subroutine
3131

3232
subroutine test_tridiagonal(error)
3333
!> Error handling
3434
type(error_type), allocatable, intent(out) :: error
35-
#:for k1, t1, s1 in (KINDS_TYPES)
35+
#:for k1, t1, s1 in KINDS_TYPES
3636
block
3737
integer, parameter :: wp = ${k1}$
3838
integer, parameter :: n = 5
@@ -113,7 +113,7 @@ contains
113113
"spmv(fail): y = alpha*A*x + beta*y, alpha: "//to_string(alpha)//", beta: "//to_string(beta))
114114
if (allocated(error)) return
115115

116-
! Test y = alpha * A.T @ x beta * y for random values of alpha and beta
116+
! Test y = alpha * A.T @ x + beta * y for random values of alpha and beta
117117
y1 = 0.0_wp
118118
call random_number(alpha)
119119
call random_number(beta)
@@ -144,7 +144,7 @@ contains
144144
subroutine test_sym_tridiagonal(error)
145145
!> Error handling
146146
type(error_type), allocatable, intent(out) :: error
147-
#:for k1, t1, s1 in (KINDS_TYPES)
147+
#:for k1, t1, s1 in KINDS_TYPES
148148
block
149149
integer, parameter :: wp = ${k1}$
150150
integer, parameter :: n = 5
@@ -255,13 +255,14 @@ contains
255255
subroutine test_tridiagonal_error_handling(error)
256256
!> Error handling
257257
type(error_type), allocatable, intent(out) :: error
258-
#:for k1, t1, s1 in (KINDS_TYPES)
258+
#:for k1, t1, s1 in KINDS_TYPES
259259
block
260260
integer, parameter :: wp = ${k1}$
261261
integer, parameter :: n = 5
262262
type(tridiagonal_${s1}$_type) :: A
263-
${t1}$, allocatable :: dl(:), dv(:), du(:)
264263
type(linalg_state_type) :: state
264+
265+
${t1}$, allocatable :: dl(:), du(:), dv(:)
265266
integer :: i
266267

267268
!> Test constructor from arrays.
@@ -283,7 +284,7 @@ contains
283284
subroutine test_sym_tridiagonal_error_handling(error)
284285
!> Error handling
285286
type(error_type), allocatable, intent(out) :: error
286-
#:for k1, t1, s1 in (KINDS_TYPES)
287+
#:for k1, t1, s1 in KINDS_TYPES
287288
block
288289
integer, parameter :: wp = ${k1}$
289290
integer, parameter :: n = 5
@@ -312,7 +313,7 @@ contains
312313
subroutine test_sym_tridiagonal_1x1(error)
313314
!> Error handling
314315
type(error_type), allocatable, intent(out) :: error
315-
#:for k1, t1, s1 in (KINDS_TYPES)
316+
#:for k1, t1, s1 in KINDS_TYPES
316317
block
317318
integer, parameter :: wp = ${k1}$
318319
type(sym_tridiagonal_${s1}$_type) :: A
@@ -337,7 +338,7 @@ contains
337338
subroutine test_sym_tridiagonal_arithmetic(error)
338339
!> Error handling
339340
type(error_type), allocatable, intent(out) :: error
340-
#:for k1, t1, s1 in (KINDS_TYPES)
341+
#:for k1, t1, s1 in KINDS_TYPES
341342
block
342343
integer, parameter :: wp = ${k1}$
343344
type(sym_tridiagonal_${s1}$_type) :: A, B, C
@@ -368,8 +369,89 @@ contains
368369
end block
369370
#:endfor
370371
end subroutine
371-
end module
372372

373+
subroutine test_tridiagonal_1x1(error)
374+
!> Test 1x1 matrix edge case for dense conversion
375+
type(error_type), allocatable, intent(out) :: error
376+
#:for k1, t1, s1 in KINDS_TYPES
377+
block
378+
integer, parameter :: wp = ${k1}$
379+
type(tridiagonal_${s1}$_type) :: A
380+
${t1}$, allocatable :: Amat(:,:)
381+
382+
#:if t1.startswith('complex')
383+
${t1}$, parameter :: dl(0) = [${t1}$ ::]
384+
${t1}$, parameter :: du(0) = [${t1}$ ::]
385+
${t1}$, parameter :: dv(1) = [cmplx(5.0_wp, 0.0_wp, kind=wp)]
386+
#:else
387+
${t1}$, parameter :: dl(0) = [${t1}$ ::]
388+
${t1}$, parameter :: du(0) = [${t1}$ ::]
389+
${t1}$, parameter :: dv(1) = [5.0_wp]
390+
#:endif
391+
392+
A = tridiagonal(dl, dv, du)
393+
Amat = dense(A)
394+
395+
! Check if the 1x1 matrix converted properly at runtime without segfaulting
396+
call check(error, size(Amat, 1) == 1, .true.)
397+
if (allocated(error)) return
398+
call check(error, size(Amat, 2) == 1, .true.)
399+
if (allocated(error)) return
400+
call check(error, is_close(Amat(1,1), 5.0_wp), .true.)
401+
if (allocated(error)) return
402+
end block
403+
#:endfor
404+
end subroutine
405+
406+
subroutine test_tridiagonal_arithmetic(error)
407+
!> Test arithmetic operations and optimization
408+
type(error_type), allocatable, intent(out) :: error
409+
#:for k1, t1, s1 in KINDS_TYPES
410+
block
411+
integer, parameter :: wp = ${k1}$
412+
integer, parameter :: n = 3
413+
type(tridiagonal_${s1}$_type) :: A, B, C
414+
415+
#:if t1.startswith('complex')
416+
${t1}$, parameter :: dl1(n-1) = [cmplx(1.0_wp, 0.0_wp, kind=wp), cmplx(1.0_wp, 0.0_wp, kind=wp)]
417+
${t1}$, parameter :: dv1(n) = [cmplx(2.0_wp, 0.0_wp, kind=wp), cmplx(2.0_wp, 0.0_wp, kind=wp), cmplx(2.0_wp, 0.0_wp, kind=wp)]
418+
${t1}$, parameter :: du1(n-1) = [cmplx(3.0_wp, 0.0_wp, kind=wp), cmplx(3.0_wp, 0.0_wp, kind=wp)]
419+
420+
${t1}$, parameter :: dl2(n-1) = [cmplx(4.0_wp, 0.0_wp, kind=wp), cmplx(4.0_wp, 0.0_wp, kind=wp)]
421+
${t1}$, parameter :: dv2(n) = [cmplx(5.0_wp, 0.0_wp, kind=wp), cmplx(5.0_wp, 0.0_wp, kind=wp), cmplx(5.0_wp, 0.0_wp, kind=wp)]
422+
${t1}$, parameter :: du2(n-1) = [cmplx(6.0_wp, 0.0_wp, kind=wp), cmplx(6.0_wp, 0.0_wp, kind=wp)]
423+
#:else
424+
${t1}$, parameter :: dl1(n-1) = [1.0_wp, 1.0_wp]
425+
${t1}$, parameter :: dv1(n) = [2.0_wp, 2.0_wp, 2.0_wp]
426+
${t1}$, parameter :: du1(n-1) = [3.0_wp, 3.0_wp]
427+
428+
${t1}$, parameter :: dl2(n-1) = [4.0_wp, 4.0_wp]
429+
${t1}$, parameter :: dv2(n) = [5.0_wp, 5.0_wp, 5.0_wp]
430+
${t1}$, parameter :: du2(n-1) = [6.0_wp, 6.0_wp]
431+
#:endif
432+
433+
A = tridiagonal(dl1, dv1, du1)
434+
B = tridiagonal(dl2, dv2, du2)
435+
436+
! Addition test - use dense() to bypass private component restrictions
437+
C = A + B
438+
call check(error, all_close(dense(C), dense(A) + dense(B)), .true.)
439+
if (allocated(error)) return
440+
441+
! Subtraction test
442+
C = A - B
443+
call check(error, all_close(dense(C), dense(A) - dense(B)), .true.)
444+
if (allocated(error)) return
445+
446+
! Scalar multiplication test
447+
C = 3.0_wp * A
448+
call check(error, all_close(dense(C), 3.0_wp * dense(A)), .true.)
449+
if (allocated(error)) return
450+
end block
451+
#:endfor
452+
end subroutine
453+
454+
end module
373455

374456
program tester
375457
use, intrinsic :: iso_fortran_env, only : error_unit
@@ -383,7 +465,7 @@ program tester
383465
stat = 0
384466

385467
testsuites = [ &
386-
new_testsuite("specialmatrices", collect_suite) &
468+
new_testsuite("special_matrices", collect_suite) &
387469
]
388470

389471
do is = 1, size(testsuites)
@@ -395,4 +477,4 @@ program tester
395477
write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
396478
error stop
397479
end if
398-
end program
480+
end program

0 commit comments

Comments
 (0)