forked from JuliaDiff/ChainRulesCore.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaccumulation.jl
More file actions
92 lines (80 loc) · 2.88 KB
/
accumulation.jl
File metadata and controls
92 lines (80 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
    add!!(x, y)

Return `x + y`, possibly mutating `x` in place so that it holds that value.
Mutating avoids an allocation whenever `x` can be updated this way; this generic
fallback never mutates anything and simply computes `x + y`.
"""
function add!!(x, y)
    return x + y
end
"""
    add!!(x, t::InplaceableThunk)

The specialization of `add!!` for [`InplaceableThunk`](@ref) promises to only call
`t.add!` on `x` if `x` is suitably mutable, as decided by
[`is_inplaceable_destination`](@ref); otherwise it computes `x + t` out of place.

When `debug_mode()` is on, the in-place call is routed through `debug_add!`, which
throws a `BadInplaceException` if `t.add!(x)` does not return `x` itself.
"""
function add!!(x, t::InplaceableThunk)
    return if is_inplaceable_destination(x)
        if !debug_mode()
            t.add!(x)
        else
            # Same call as above, but verifies that `t.add!` really returned `x`.
            debug_add!(x, t)
        end
    else
        x + t
    end
end
# Materialize the thunk, then re-dispatch `add!!` on the unthunked value.
function add!!(x::AbstractArray, y::Thunk)
    return add!!(x, unthunk(y))
end
# Array + array of the same dimensionality: accumulate `y` into `x` with broadcasting
# when `x` may be mutated, otherwise fall back to the allocating `x + y`.
function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N}
    is_inplaceable_destination(x) || return x + y
    return x .+= y
end
"""
    is_inplaceable_destination(x) -> Bool

Return `true` if `x` is suitable for storing in-place accumulation of gradients.
For arrays this boils down to whether `x .= y` will work to mutate `x`, when `y` is an
appropriate differential.

Wrapper array types that overload `Base.parent`, and that are mutable exactly when their
parent array is, do not need to overload this. A `SubArray` must additionally never
select the same parent element twice (e.g. `view(A, [1, 1])`), because accumulating
into such a view with `.+=` would silently give the wrong answer.

Other types should overload this, as it defaults to `false`.
"""
is_inplaceable_destination(::Any) = false
is_inplaceable_destination(::Array) = true
is_inplaceable_destination(::SparseVector) = true
is_inplaceable_destination(::SparseMatrixCSC) = true
is_inplaceable_destination(::BitArray) = true
function is_inplaceable_destination(x::AbstractArray)
    p = parent(x)
    p === x && return false  # no parent
    # basically all wrapper types delegate `setindex!` to their `parent` after some
    # processing and so are mutable if their `parent` is.
    return is_inplaceable_destination(p)
end
function is_inplaceable_destination(x::SubArray)
    is_inplaceable_destination(parent(x)) || return false
    # With a repeated index, later writes of `.+=` overwrite earlier ones and the view
    # then re-reads the aliased element, so in-place accumulation would be wrong.
    return all(_is_nonaliasing_index, parentindices(x))
end

# Whether an index stored in a `SubArray` never selects the same parent position twice.
# Index kinds not recognised here are conservatively treated as possibly repeating.
_is_nonaliasing_index(::Integer) = true
_is_nonaliasing_index(i::AbstractArray) = allunique(i)
_is_nonaliasing_index(::Any) = false
# `Hermitian` and `Symmetric` are too fussy to deal with right now, so they are never
# treated as destinations for in-place accumulation, whatever their parent is.
# https://github.com/JuliaLang/julia/issues/38056
# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236
function is_inplaceable_destination(
    ::Union{LinearAlgebra.Hermitian,LinearAlgebra.Symmetric},
)
    return false
end
"""
    debug_add!(accumuland, t::InplaceableThunk)

Call `t.add!(accumuland)` and check that it handed back `accumuland` itself (compared
with `===`); if not, throw a `BadInplaceException`. Returns what `t.add!` returned.
"""
function debug_add!(accumuland, t::InplaceableThunk)
    result = t.add!(accumuland)
    result === accumuland || throw(BadInplaceException(t, accumuland, result))
    return result
end
"""
    BadInplaceException(ithunk, accumuland, returned_value)

Thrown by `debug_add!` when `ithunk.add!(accumuland)` returns something that is not
`accumuland` itself (compared with `===`).
"""
struct BadInplaceException <: Exception
    ithunk::InplaceableThunk  # the thunk whose `add!` was called
    accumuland  # the value passed to `ithunk.add!` to be updated in place
    returned_value  # what `ithunk.add!(accumuland)` actually returned
end
"""
    showerror(io::IO, err::BadInplaceException)

Print the thunk, the accumuland and the value `ithunk.add!` actually returned. Add a
note when the two values compare equal, since they are still not the same object.
"""
function Base.showerror(io::IO, err::BadInplaceException)
    println(io, "`add!!(accumuland, ithunk)` did not return an updated accumuland.")
    println(io, "ithunk = $(err.ithunk)")
    println(io, "accumuland = $(err.accumuland)")
    println(io, "returned_value = $(err.returned_value)")
    # `==` can return `missing` (or another non-`Bool`) for some types; only add the
    # note when it is definitely `true` so that printing the error can never throw.
    if (err.accumuland == err.returned_value) === true
        println(
            io,
            "Which in this case happened to be equal. But they are not the same object.",
        )
    end
    return nothing
end