Skip to content

Commit df3f72e

Browse files
authored
Add some missed types to WrappedArray (#95)
1 parent cf6f6d4 commit df3f72e

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/wrappers.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,12 @@ adapt_structure(to, A::LinearAlgebra.Diagonal) =
4848
LinearAlgebra.Diagonal(adapt(to, parent(A)))
4949
adapt_structure(to, A::LinearAlgebra.Tridiagonal) =
5050
LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))
51+
adapt_structure(to, A::LinearAlgebra.Bidiagonal) =
52+
LinearAlgebra.Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), A.uplo)
5153
adapt_structure(to, A::LinearAlgebra.Symmetric) =
5254
LinearAlgebra.Symmetric(adapt(to, parent(A)))
55+
adapt_structure(to, A::LinearAlgebra.Hermitian) =
56+
LinearAlgebra.Hermitian(adapt(to, parent(A)))
5357

5458

5559
# we generally don't support multiple layers of wrappers, but some occur often
@@ -100,8 +104,10 @@ const WrappedArray{T,N,Src,Dst} = Union{
100104
LinearAlgebra.UpperTriangular{T,<:Dst},
101105
LinearAlgebra.UnitUpperTriangular{T,<:Dst},
102106
LinearAlgebra.Diagonal{T,<:Dst},
107+
LinearAlgebra.Bidiagonal{T,<:Dst},
103108
LinearAlgebra.Tridiagonal{T,<:Dst},
104109
LinearAlgebra.Symmetric{T,<:Dst},
110+
LinearAlgebra.Hermitian{T,<:Dst},
105111

106112
WrappedReinterpretArray{T,N,<:Src},
107113
WrappedReshapedArray{T,N,<:Src},

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,15 @@ using LinearAlgebra
178178
@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomArray
179179
@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomArray
180180
@test_adapt CustomArray Symmetric(mat.arr) Symmetric(mat) AnyCustomArray
181-
181+
@test_adapt CustomArray Hermitian(mat.arr) Hermitian(mat) AnyCustomArray
182+
182183
@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomArray
183184

184185
dl = CustomArray{Float64,1}(rand(2))
185186
du = CustomArray{Float64,1}(rand(2))
186187
d = CustomArray{Float64,1}(rand(3))
187188
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomArray
188-
189+
@test_adapt CustomArray Bidiagonal(d.arr, du.arr, 'U') Bidiagonal(d, du, 'U') AnyCustomArray
189190
end
190191

191192

0 commit comments

Comments
 (0)