From 4b3df28992d4cb8b84918c98066318e92c16d4fc Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Sat, 21 Feb 2026 13:54:00 -0500 Subject: [PATCH 1/2] Implement matrix and vector interfaces This properly defines `supports_vector_interface` and `supports_matrix_interface` for `GradVector` and `GradgenOperator`, respectively, and implement the full required interface, as checked by `check_operator` and `check_state`. --- src/grad_vector.jl | 12 ++-- src/gradgen_operator.jl | 9 ++- src/linalg.jl | 111 +++++++++++++++++++++++++++++-- test/test_interface.jl | 143 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 262 insertions(+), 13 deletions(-) diff --git a/src/grad_vector.jl b/src/grad_vector.jl index f93e0c4..fe4d529 100644 --- a/src/grad_vector.jl +++ b/src/grad_vector.jl @@ -1,5 +1,6 @@ import QuantumControl.QuantumPropagators: _exp_prop_convert_state -import QuantumControl.QuantumPropagators.Interfaces: supports_inplace +import QuantumControl.QuantumPropagators.Interfaces: + supports_inplace, supports_vector_interface @doc raw"""Extended state-vector for the dynamic gradient. @@ -68,8 +69,8 @@ in-place operations. Returns `Ψ̃`. """ -function resetgradvec!(Ψ̃::GradVector) - if supports_inplace(Ψ̃) +function resetgradvec!(Ψ̃::T) where {T<:GradVector} + if supports_inplace(T) for i in eachindex(Ψ̃.grad_states) fill!(Ψ̃.grad_states[i], 0.0) end @@ -89,4 +90,7 @@ end _exp_prop_convert_state(::GradVector) = Vector{ComplexF64} -supports_inplace(Ψ̃::GradVector) = supports_inplace(Ψ̃.state) +supports_inplace(::Type{GradVector{N,T}}) where {N,T} = supports_inplace(T) + +supports_vector_interface(::Type{GradVector{N,T}}) where {N,T} = + supports_vector_interface(T) diff --git a/src/gradgen_operator.jl b/src/gradgen_operator.jl index 53dd1c2..896ac9b 100644 --- a/src/gradgen_operator.jl +++ b/src/gradgen_operator.jl @@ -2,7 +2,8 @@ using Random: GLOBAL_RNG import QuantumControl.QuantumPropagators: _exp_prop_convert_operator import QuantumControl.QuantumPropagators.Controls: get_controls import QuantumControl.QuantumPropagators.SpectralRange: random_state -import QuantumControl.QuantumPropagators.Interfaces: supports_inplace +import QuantumControl.QuantumPropagators.Interfaces: + supports_inplace, supports_matrix_interface """Static generator for the dynamic gradient. @@ -40,4 +41,8 @@ end _exp_prop_convert_operator(::GradgenOperator) = Matrix{ComplexF64} -supports_inplace(::GradgenOperator) = true +supports_inplace(::Type{GradgenOperator{N,GT,CGT}}) where {N,GT,CGT} = + (supports_inplace(GT) && supports_inplace(CGT)) + +supports_matrix_interface(::Type{<:GradgenOperator{N,GT,CGT}}) where {N,GT,CGT} = + supports_matrix_interface(GT) && supports_matrix_interface(CGT) diff --git a/src/linalg.jl b/src/linalg.jl index 177c604..7336c37 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -12,6 +12,11 @@ function LinearAlgebra.mul!(Φ::GradVector, G::GradgenOperator, Ψ::GradVector, end +function LinearAlgebra.mul!(Φ::GradVector, G::GradgenOperator, Ψ::GradVector) + return LinearAlgebra.mul!(Φ, G, Ψ, true, false) +end + + function LinearAlgebra.lmul!(c, Ψ::GradVector) LinearAlgebra.lmul!(c, Ψ.state) for i ∈ eachindex(Ψ.grad_states) @@ -48,6 +53,11 @@ function LinearAlgebra.dot(Ψ::GradVector, Φ::GradVector) end +function LinearAlgebra.dot(Ψ::GradVector, G::GradgenOperator, Φ::GradVector) + return LinearAlgebra.dot(Ψ, G * Φ) +end + + LinearAlgebra.ishermitian(G::GradgenOperator) = false @@ -75,6 +85,11 @@ function Base.length(Ψ::GradVector) end +function Base.size(Ψ::GradVector{num_controls,T}) where {num_controls,T} + return ((num_controls + 1) * length(Ψ.state),) +end + + function Base.size(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT} return (num_controls + 1) .* size(O.G) end @@ -89,17 +104,105 @@ end function Base.similar(Ψ::GradVector{num_controls,T}) where {num_controls,T} - return GradVector{num_controls,T}(similar(Ψ.state), [similar(ϕ) for ϕ ∈ Ψ.grad_states]) + state_sim = similar(Ψ.state) + grad_states_sim = [similar(ϕ) for ϕ ∈ Ψ.grad_states] + return GradVector{num_controls,typeof(state_sim)}(state_sim, grad_states_sim) +end + +Base.similar(Ψ::GradVector, ::Type{S}) where {S} = Vector{S}(undef, length(Ψ)) + +Base.similar(Ψ::GradVector, dims::Tuple{Vararg{Int}}) = Array{eltype(Ψ)}(undef, dims) + +# These definitions of `similar` exist to make ExponentialUtilities happy, but +# it's not clear at all that `similar` with a custom shape really makes sense +Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int,Int}) where {T} = + Matrix{T}(undef, dims...) + +Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int}) where {T} = + Vector{T}(undef, dims[1]) + +function Base.getindex(Ψ::GradVector{num_controls,T}, k::Int) where {num_controls,T} + N = length(Ψ.state) + L = num_controls + block = (k - 1) ÷ N + 1 + local_k = (k - 1) % N + 1 + if block <= L + return Ψ.grad_states[block][local_k] + else + return Ψ.state[local_k] + end +end + +function Base.setindex!(Ψ::GradVector{num_controls,T}, v, k::Int) where {num_controls,T} + N = length(Ψ.state) + L = num_controls + block = (k - 1) ÷ N + 1 + local_k = (k - 1) % N + 1 + if block <= L + Ψ.grad_states[block][local_k] = v + else + Ψ.state[local_k] = v + end + return Ψ end -function Base.similar(G::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT} - return GradgenOperator{num_controls,GT,CGT}(similar(G.G), similar(G.control_deriv_ops)) +function Base.iterate(Ψ::GradVector, k = 1) + k > length(Ψ) && return nothing + return (Ψ[k], k + 1) end -function Base.eltype(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT} +# As for an `Operator`, we implement `similar` to return a standard `Array` +# because `GradgenOperator` does not `setindex!`, so it's arguable not a +# "mutable array"even if its components are mutable. +Base.similar(G::GradgenOperator) = Array{eltype(G)}(undef, size(G)) + +Base.similar(O::GradgenOperator, ::Type{S}) where {S} = Array{S}(undef, size(O)) +Base.similar(O::GradgenOperator, dims::Tuple{Vararg{Int}}) = Array{eltype(O)}(undef, dims) +Base.similar(O::GradgenOperator, ::Type{S}, dims::Tuple{Vararg{Int}}) where {S} = + Array{S}(undef, dims) + +function Base.eltype( + ::Type{GradgenOperator{num_controls,GT,CGT}} +) where {num_controls,GT,CGT} return promote_type(eltype(GT), eltype(CGT)) end +function Base.getindex( + O::GradgenOperator{num_controls,GT,CGT}, + row::Int, + col::Int +) where {num_controls,GT,CGT} + T = eltype(O) + N, M = size(O.G) + L = num_controls + block_row = (row - 1) ÷ N + 1 + block_col = (col - 1) ÷ M + 1 + local_row = (row - 1) % N + 1 + local_col = (col - 1) % M + 1 + if block_row == block_col + return convert(T, O.G[local_row, local_col]) + elseif block_col == L + 1 && block_row <= L + return convert(T, O.control_deriv_ops[block_row][local_row, local_col]) + else + return zero(T) + end +end + +Base.length(O::GradgenOperator) = prod(size(O)) + +function Base.iterate(O::GradgenOperator, k = 1) + n = length(O) + k > n && return nothing + n_rows = size(O, 1) + i = (k - 1) % n_rows + 1 + j = (k - 1) ÷ n_rows + 1 + return (O[i, j], k + 1) +end + +function Base.eltype(::Type{GradVector{num_controls,T}}) where {num_controls,T} + return eltype(T) +end + function Base.copyto!(dest::GradgenOperator, src::GradgenOperator) copyto!(dest.G, src.G) copyto!(dest.control_deriv_ops, src.control_deriv_ops) diff --git a/test/test_interface.jl b/test/test_interface.jl index 4cc054b..f2b80d6 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -3,10 +3,11 @@ using QuantumPropagators.Generators: hamiltonian using QuantumPropagators.Controls: get_controls using QuantumControlTestUtils.RandomObjects: random_matrix, random_state_vector using QuantumControl.Interfaces: check_generator -using QuantumPropagators.Interfaces: check_state -using QuantumGradientGenerators: GradGenerator, GradVector +using QuantumPropagators.Interfaces: + check_state, check_operator, supports_matrix_interface, supports_vector_interface +using QuantumGradientGenerators: GradGenerator, GradVector, GradgenOperator using StaticArrays: SVector, SMatrix -using LinearAlgebra: norm +using LinearAlgebra: norm, dot, mul!, I @testset "GradVector Interface" begin @@ -75,3 +76,139 @@ end @test check_generator(G̃_of_t; state = Ψ̃, tlist, for_gradient_optimization = false) end + + +@testset "GradgenOperator Matrix Interface" begin + + N = 5 + L = 2 + G = Matrix{ComplexF64}(I, N, N) + mu = [rand(ComplexF64, N, N) for _ = 1:L] + op = GradgenOperator{L,Matrix{ComplexF64},Matrix{ComplexF64}}(G, mu) + state = GradVector(rand(ComplexF64, N), L) + + # supports_matrix_interface reports true for matrix-backed GradgenOperator + @test supports_matrix_interface(typeof(op)) + + # check_operator passes the full matrix interface check including for_expval + @test check_operator(op; state, for_expval = true) + + # getindex is consistent with the dense Array representation + dense = Array(op) + @test all(op[i, j] ≈ dense[i, j] for i = 1:size(op, 1), j = 1:size(op, 2)) + + # length + @test length(op) == prod(size(op)) + + # iterate visits elements in column-major order, consistent with vec(Array(op)) + @test all(collect(op) .≈ vec(dense)) + + # 3-arg mul! agrees with 5-arg mul!(Phi, G, Psi, 1, 0) + Psi = GradVector(rand(ComplexF64, N), L) + Phi1 = GradVector(zeros(ComplexF64, N), L) + Phi2 = GradVector(zeros(ComplexF64, N), L) + mul!(Phi1, op, Psi) + mul!(Phi2, op, Psi, true, false) + @test norm(Phi1 - Phi2) < 1e-14 + + # 3-arg dot(Psi, op, Phi) matches dot(Psi, op * Phi) + Psi2 = GradVector(rand(ComplexF64, N), L) + @test dot(state, op, Psi2) ≈ dot(state, op * Psi2) + + # similar(op) returns a dense Array of the same eltype and size (matching Operator pattern) + op_sim = similar(op) + @test op_sim isa Array{eltype(op)} + @test size(op_sim) == size(op) + + # similar(op, S) returns a dense Array of type S with matching size + @test similar(op, Float64) isa Array{Float64} + @test size(similar(op, Float64)) == size(op) + + # similar(op, dims) returns a dense Array with given dims + @test similar(op, (3, 4)) isa Array{eltype(op)} + @test size(similar(op, (3, 4))) == (3, 4) + + # similar(op, S, dims) returns a dense Array of type S with given dims + @test similar(op, Float64, (3, 4)) isa Array{Float64} + @test size(similar(op, Float64, (3, 4))) == (3, 4) + +end + + +@testset "GradVector Vector Interface" begin + + N = 5 + L = 2 + Psi = rand(ComplexF64, N) + gradvec = GradVector(Psi, L) + + # supports_vector_interface is true for Vector-backed GradVector + @test supports_vector_interface(typeof(gradvec)) + + # check_state passes full vector interface check + @test check_state(gradvec) + + # size is 1D with total length + @test size(gradvec) == (N * (L + 1),) + @test size(gradvec) == (length(gradvec),) + + # getindex is consistent with convert_gradvec_to_dense layout: + # [grad_states[1]; grad_states[2]; ...; grad_states[L]; state] + dense = convert(Vector{ComplexF64}, gradvec) + @test all(gradvec[k] == dense[k] for k = 1:length(gradvec)) + + # iterate visits elements consistent with getindex + @test all(collect(gradvec) .== dense) + + # setindex! round-trips through getindex + gradvec2 = GradVector(copy(Psi), L) + for k = 1:length(gradvec2) + gradvec2[k] = gradvec[k] + end + @test all(gradvec2[k] == gradvec[k] for k = 1:length(gradvec)) + + # similar(gradvec, S) returns a mutable Vector{S} with same length + @test similar(gradvec, ComplexF32) isa Vector{ComplexF32} + @test length(similar(gradvec, ComplexF32)) == length(gradvec) + + # similar(gradvec, dims) returns a plain Array with same eltype and given dims + @test similar(gradvec, (3, 4)) isa Array{eltype(gradvec)} + @test size(similar(gradvec, (3, 4))) == (3, 4) + +end + + +@testset "GradVector Vector Interface (Static)" begin + + N = 5 + L = 2 + Psi = SVector{N,ComplexF64}(rand(ComplexF64, N)) + gradvec = GradVector(Psi, L) + + # SVector-backed GradVector: supports_vector_interface follows the component type + @test supports_vector_interface(typeof(gradvec)) + + # check_state passes (SVector is inplace=false, so setindex! is not checked) + @test check_state(gradvec) + + # getindex is consistent with the dense layout + dense = convert(Vector{ComplexF64}, gradvec) + @test all(gradvec[k] == dense[k] for k = 1:length(gradvec)) + +end + + +@testset "GradVector without Vector Interface" begin + + N = 5 + L = 2 + # Matrix is not an AbstractVector, so supports_vector_interface returns false + Psi = rand(ComplexF64, N, N) + gradvec = GradVector(Psi, L) + + @test !supports_vector_interface(typeof(gradvec)) + + # check_state still passes via the basic (non-vector) state interface + @test check_state(gradvec) + +end From db43ecaf6d2ebc38cd908450e1ca189b8c096822 Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Sun, 22 Feb 2026 00:15:11 -0500 Subject: [PATCH 2/2] Guard linalg for objects not declaring matrix/vector interface --- src/linalg.jl | 213 ++++++++++++++++++++++++++++++++--------- test/test_interface.jl | 35 ++++++- 2 files changed, 203 insertions(+), 45 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 7336c37..f374ef9 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -80,48 +80,45 @@ function Base.copy(Ψ::GradVector{num_controls,T}) where {num_controls,T} end -function Base.length(Ψ::GradVector) +# === Vector interface for GradVector === +# +# The following methods are part of the vector interface and are only +# meaningful when `supports_vector_interface` is true for the state type T. +# Each method delegates to a private `_name(::Val{supports}, ...)` function: +# the Val{true} method contains the implementation, and the Val{false} method +# throws an error. + +function _length(::Val{true}, Ψ::GradVector) return length(Ψ.state) * (1 + length(Ψ.grad_states)) end - -function Base.size(Ψ::GradVector{num_controls,T}) where {num_controls,T} - return ((num_controls + 1) * length(Ψ.state),) +function _length(::Val{false}, Ψ::GradVector) + error("$(typeof(Ψ)) does not support the vector interface") end - -function Base.size(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT} - return (num_controls + 1) .* size(O.G) +function Base.length(Ψ::T) where {T<:GradVector} + return _length(Val(supports_vector_interface(T)), Ψ) end -function Base.size( - O::GradgenOperator{num_controls,GT,CGT}, - dim::Integer -) where {num_controls,GT,CGT} - return (num_controls + 1) * size(O.G, dim) +function _size(::Val{true}, Ψ::GradVector{num_controls,T}) where {num_controls,T} + return ((num_controls + 1) * length(Ψ.state),) end - -function Base.similar(Ψ::GradVector{num_controls,T}) where {num_controls,T} - state_sim = similar(Ψ.state) - grad_states_sim = [similar(ϕ) for ϕ ∈ Ψ.grad_states] - return GradVector{num_controls,typeof(state_sim)}(state_sim, grad_states_sim) +function _size(::Val{false}, Ψ::GradVector) + error("$(typeof(Ψ)) does not support the vector interface") end -Base.similar(Ψ::GradVector, ::Type{S}) where {S} = Vector{S}(undef, length(Ψ)) - -Base.similar(Ψ::GradVector, dims::Tuple{Vararg{Int}}) = Array{eltype(Ψ)}(undef, dims) +function Base.size(Ψ::T) where {T<:GradVector} + return _size(Val(supports_vector_interface(T)), Ψ) +end -# These definitions of `similar` exist to make ExponentialUtilities happy, but -# it's not clear at all that `similar` with a custom shape really makes sense -Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int,Int}) where {T} = - Matrix{T}(undef, dims...) -Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int}) where {T} = - Vector{T}(undef, dims[1]) - -function Base.getindex(Ψ::GradVector{num_controls,T}, k::Int) where {num_controls,T} +function _getindex( + ::Val{true}, + Ψ::GradVector{num_controls,T}, + k::Int +) where {num_controls,T} N = length(Ψ.state) L = num_controls block = (k - 1) ÷ N + 1 @@ -133,7 +130,21 @@ function Base.getindex(Ψ::GradVector{num_controls,T}, k::Int) where {num_contro end end -function Base.setindex!(Ψ::GradVector{num_controls,T}, v, k::Int) where {num_controls,T} +function _getindex(::Val{false}, Ψ::GradVector, k::Int) + error("$(typeof(Ψ)) does not support the vector interface") +end + +function Base.getindex(Ψ::T, k::Int) where {T<:GradVector} + return _getindex(Val(supports_vector_interface(T)), Ψ, k) +end + + +function _setindex!( + ::Val{true}, + Ψ::GradVector{num_controls,T}, + v, + k::Int +) where {num_controls,T} N = length(Ψ.state) L = num_controls block = (k - 1) ÷ N + 1 @@ -146,14 +157,105 @@ function Base.setindex!(Ψ::GradVector{num_controls,T}, v, k::Int) where {num_co return Ψ end -function Base.iterate(Ψ::GradVector, k = 1) +function _setindex!(::Val{false}, Ψ::GradVector, v, k::Int) + error("$(typeof(Ψ)) does not support the vector interface") +end + +function Base.setindex!(Ψ::T, v, k::Int) where {T<:GradVector} + return _setindex!(Val(supports_vector_interface(T)), Ψ, v, k) +end + + +function _iterate(::Val{true}, Ψ::GradVector, k) k > length(Ψ) && return nothing return (Ψ[k], k + 1) end +function _iterate(::Val{false}, Ψ::GradVector, k) + error("$(typeof(Ψ)) does not support the vector interface") +end + +function Base.iterate(Ψ::T, k = 1) where {T<:GradVector} + return _iterate(Val(supports_vector_interface(T)), Ψ, k) +end + + +function Base.similar(Ψ::GradVector{num_controls,T}) where {num_controls,T} + state_sim = similar(Ψ.state) + grad_states_sim = [similar(ϕ) for ϕ ∈ Ψ.grad_states] + return GradVector{num_controls,typeof(state_sim)}(state_sim, grad_states_sim) +end + +# similar(Ψ, S) calls length(Ψ), which will error if !supports_vector_interface +Base.similar(Ψ::GradVector, ::Type{S}) where {S} = Vector{S}(undef, length(Ψ)) + +# similar(Ψ, dims) calls eltype(Ψ) but not length/size, so no vector interface needed +Base.similar(Ψ::GradVector, dims::Tuple{Vararg{Int}}) = Array{eltype(Ψ)}(undef, dims) + +# These definitions of `similar` exist to make ExponentialUtilities happy, but +# it's not clear at all that `similar` with a custom shape really makes sense +Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int,Int}) where {T} = + Matrix{T}(undef, dims...) + +Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int}) where {T} = + Vector{T}(undef, dims[1]) + + +function Base.fill!(Ψ::GradVector, v) + Base.fill!(Ψ.state, v) + for i = 1:length(Ψ.grad_states) + Base.fill!(Ψ.grad_states[i], v) + end + return Ψ +end + + +# === Matrix interface for GradgenOperator === +# +# The following methods are part of the matrix interface and are only +# meaningful when `supports_matrix_interface` is true for both component types. +# Each method delegates to a private `_name(::Val{supports}, ...)` function: +# the Val{true} method contains the implementation, and the Val{false} method +# throws an error. + +function _size( + ::Val{true}, + O::GradgenOperator{num_controls,GT,CGT} +) where {num_controls,GT,CGT} + return (num_controls + 1) .* size(O.G) +end + +function _size(::Val{false}, O::GradgenOperator) + error("$(typeof(O)) does not support the matrix interface") +end + +function Base.size(O::T) where {T<:GradgenOperator} + return _size(Val(supports_matrix_interface(T)), O) +end + + +function _size( + ::Val{true}, + O::GradgenOperator{num_controls,GT,CGT}, + dim::Integer +) where {num_controls,GT,CGT} + return (num_controls + 1) * size(O.G, dim) +end + +function _size(::Val{false}, O::GradgenOperator, dim::Integer) + error("$(typeof(O)) does not support the matrix interface") +end + +function Base.size(O::T, dim::Integer) where {T<:GradgenOperator} + return _size(Val(supports_matrix_interface(T)), O, dim) +end + + # As for an `Operator`, we implement `similar` to return a standard `Array` -# because `GradgenOperator` does not `setindex!`, so it's arguable not a -# "mutable array"even if its components are mutable. +# because `GradgenOperator` does not `setindex!`, so it's arguably not a +# "mutable array" even if its components are mutable. +# similar(O) and similar(O, S) call size(O), which will error if +# !supports_matrix_interface. The dims-based variants need no guard. Base.similar(G::GradgenOperator) = Array{eltype(G)}(undef, size(G)) Base.similar(O::GradgenOperator, ::Type{S}) where {S} = Array{S}(undef, size(O)) @@ -161,13 +263,16 @@ Base.similar(O::GradgenOperator, dims::Tuple{Vararg{Int}}) = Array{eltype(O)}(un Base.similar(O::GradgenOperator, ::Type{S}, dims::Tuple{Vararg{Int}}) where {S} = Array{S}(undef, dims) + function Base.eltype( ::Type{GradgenOperator{num_controls,GT,CGT}} ) where {num_controls,GT,CGT} return promote_type(eltype(GT), eltype(CGT)) end -function Base.getindex( + +function _getindex( + ::Val{true}, O::GradgenOperator{num_controls,GT,CGT}, row::Int, col::Int @@ -188,9 +293,29 @@ function Base.getindex( end end -Base.length(O::GradgenOperator) = prod(size(O)) +function _getindex(::Val{false}, O::GradgenOperator, row::Int, col::Int) + error("$(typeof(O)) does not support the matrix interface") +end + +function Base.getindex(O::T, row::Int, col::Int) where {T<:GradgenOperator} + return _getindex(Val(supports_matrix_interface(T)), O, row, col) +end + + +function _length(::Val{true}, O::GradgenOperator) + return prod(size(O)) +end + +function _length(::Val{false}, O::GradgenOperator) + error("$(typeof(O)) does not support the matrix interface") +end + +function Base.length(O::T) where {T<:GradgenOperator} + return _length(Val(supports_matrix_interface(T)), O) +end + -function Base.iterate(O::GradgenOperator, k = 1) +function _iterate(::Val{true}, O::GradgenOperator, k) n = length(O) k > n && return nothing n_rows = size(O, 1) @@ -199,6 +324,15 @@ function Base.iterate(O::GradgenOperator, k = 1) return (O[i, j], k + 1) end +function _iterate(::Val{false}, O::GradgenOperator, k) + error("$(typeof(O)) does not support the matrix interface") +end + +function Base.iterate(O::T, k = 1) where {T<:GradgenOperator} + return _iterate(Val(supports_matrix_interface(T)), O, k) +end + + function Base.eltype(::Type{GradVector{num_controls,T}}) where {num_controls,T} return eltype(T) end @@ -209,15 +343,6 @@ function Base.copyto!(dest::GradgenOperator, src::GradgenOperator) end -function Base.fill!(Ψ::GradVector, v) - Base.fill!(Ψ.state, v) - for i = 1:length(Ψ.grad_states) - Base.fill!(Ψ.grad_states[i], v) - end - return Ψ -end - - function Base.zero(Ψ::GradVector{num_controls,T}) where {num_controls,T} return GradVector{num_controls,T}(zero(Ψ.state), [zero(ϕ) for ϕ ∈ Ψ.grad_states]) end diff --git a/test/test_interface.jl b/test/test_interface.jl index f2b80d6..5f037bf 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -103,7 +103,7 @@ end # iterate visits elements in column-major order, consistent with vec(Array(op)) @test all(collect(op) .≈ vec(dense)) - # 3-arg mul! agrees with 5-arg mul!(Phi, G, Psi, 1, 0) + # 3-arg mul! agrees with 5-arg mul!(Phi, G, Psi, true, false) Psi = GradVector(rand(ComplexF64, N), L) Phi1 = GradVector(zeros(ComplexF64, N), L) Phi2 = GradVector(zeros(ComplexF64, N), L) @@ -211,4 +211,37 @@ end # check_state still passes via the basic (non-vector) state interface @test check_state(gradvec) + # Vector interface methods must throw an error when not supported + @test_throws "does not support the vector interface" gradvec[1] + @test_throws "does not support the vector interface" (gradvec[1] = 0.0) + @test_throws "does not support the vector interface" size(gradvec) + @test_throws "does not support the vector interface" length(gradvec) + @test_throws "does not support the vector interface" iterate(gradvec) + +end + + + +# A wrapper type with no supports_matrix_interface declaration (defaults to false) +struct NonMatrixOp + data::Matrix{ComplexF64} +end + +@testset "GradgenOperator without Matrix Interface" begin + + N = 5 + L = 2 + G = NonMatrixOp(rand(ComplexF64, N, N)) + mu = [NonMatrixOp(rand(ComplexF64, N, N)) for _ = 1:L] + op = GradgenOperator{L,NonMatrixOp,NonMatrixOp}(G, mu) + + @test !supports_matrix_interface(typeof(op)) + + # Matrix interface methods must throw an error when not supported + @test_throws "does not support the matrix interface" op[1, 1] + @test_throws "does not support the matrix interface" size(op) + @test_throws "does not support the matrix interface" size(op, 1) + @test_throws "does not support the matrix interface" length(op) + @test_throws "does not support the matrix interface" iterate(op) + end