In an ideal world, a user would test an rrule by writing something like the following, and have it work all of the time:
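A minimal sketch of what that might look like, assuming CRTU's `test_rrule` entry point; `foo` and its `rrule` below are hypothetical stand-ins for user code:

```julia
using ChainRulesTestUtils
import ChainRulesCore: rrule, NoTangent

# Hypothetical user code: a function together with its rrule.
foo(x::Vector{Float64}) = 2 .* x
function rrule(::typeof(foo), x::Vector{Float64})
    foo_pullback(ȳ) = (NoTangent(), 2 .* ȳ)
    return foo(x), foo_pullback
end

# The ideal: one call, and it should just work, for any function and any input type.
test_rrule(foo, randn(3))
```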
By "work all of the time", I mean that the tests that we want to run to determine the correctness of an `rrule` implementation always run successfully, provided that the function is something that we know how to test (broadly speaking, the output is deterministic given the input), and for any input type that is either
- a primitive that we know about (`Real`, `Array`, etc.), or
- a composite type.
It's important that this works automatically because we want people to be testing their code using CRTU, and people like to define new types (including new AbstractArrays) and new functions. Unfortunately, I don't believe it's possible to automate in all cases, but the way in which it fails (AFAICT) is very specific, and I think we can document it and make it easy to resolve for users.
Roughly speaking, the list of functionality that always needs to work in order to achieve this is:
- `to_vec`
- `to_vec_tangent` (a new function)
- `rand_tangent`
- `test_approx`
`to_vec`, `to_vec_tangent` and `rand_tangent` can be made to "always work", but `test_approx` occasionally has a quirk that I don't believe we can automate.
The outcome is the following proposals:
- remove all (or at least most) `to_vec` implementations in favour of the generic `to_vec` implementation for `isstructtype` types, plus the necessary `to_vec` implementations for `isprimitivetype` types,
- introduce a `to_vec_tangent` function (better name welcome), which is like `to_vec`, but the returned closure returns a tangent rather than a primal,
- add a function called `remove_junk_data`, or something similar, which applies to primals and returns another object containing only the bits of the primal relevant for defining `isapprox`. Whenever we test rules, we then test the composition of `remove_junk_data` and the function being tested, rather than just the function. This enables us to define `test_approx` in a really generic manner.
I'll explain throughout this issue why I believe these are sensible proposals, and how they resolve things.
Additionally, while this proposal is independent from other proposed changes, it clearly favours a structural view of the world because I'm interested in automating things. See JuliaDiff/ChainRulesCore.jl#449 for a proposal for how we can do this without sacrificing usability, and how this leads to a precise definition for natural tangents.
I would be really interested to know if anyone thinks I've obviously missed something, or whether this sounds about right.
edit: I completely neglected constraint-related problems (e.g. if the tangent provided to `FiniteDifferences` needs to represent a positive definite matrix for some reason). AFAICT the things discussed are essentially orthogonal to the constraint problems though.
edit2: note: undefined references are not fun. For example, perfectly well-defined `Dict` objects can contain undefined memory. I think this probably comes under the heading of "junk" data, but it seems to cause problems for `to_vec` as it's currently defined. I wonder whether it could be generalised?
Example 1: Diagonal size mismatch
Consider testing a function `f` that returns a `Diagonal` (for concreteness, take `f = identity`). Let the output (co)tangent be

```julia
x = Diagonal(randn(2))
dx = Tangent{typeof(x)}(diag=randn(2))
```

then

```julia
FiniteDifferences.j′vp(central_fdm(5, 1), f, dx, x)
```

produces the error:
```
ERROR: DimensionMismatch("second dimension of A, 4, does not match length of x, 2")
Stacktrace:
 [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
   @ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:530
 [2] mul!
   @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:97 [inlined]
 [3] mul!
   @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
 [4] *(transA::Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
   @ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:87
 [5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
   @ FiniteDifferences ~/.julia/packages/FiniteDifferences/W3rQO/src/grad.jl:80
 [6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Tangent{Diagonal{Float64, Vector{Float64}}, NamedTuple{(:diag,), Tuple{Vector{Float64}}}}, x::Diagonal{Float64, Vector{Float64}})
   @ FiniteDifferences ~/.julia/packages/FiniteDifferences/W3rQO/src/grad.jl:73
 [7] top-level scope
   @ REPL[24]:1
```
Why is this example a problem?
Firstly, we presently require that rules accept either a natural or structural tangent. Due to the above, it's not currently possible to test functions which output a Diagonal with a Tangent tangent.
Secondly, there exist Diagonal matrices whose tangent cannot be represented by a Diagonal. Specifically, any Diagonal whose diag field doesn't provide a way to produce an AbstractVector as its tangent (i.e. for whatever reason, lacks a natural tangent). Consequently, in order for our testing facilities to handle any type, they must be able to work with structural tangents.
Finally, our current implementations special-case `to_vec` for lots of different arrays (`Diagonal`, `Symmetric`, etc.). This is a problem in itself, but moreover we're never entirely sure what the right thing to do is when we encounter a new array.
How to fix this problem
Remove the specialised to_vec methods for Diagonal and other struct AbstractArrays (UpperTriangular, Symmetric, etc), and instead just rely on the generic to_vec operation for structs.
Doing this immediately means that we can `to_vec` anything that is either
- a primitive that we've defined `to_vec` on, or
- any `struct` or `mutable struct`.
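To make the idea concrete, here is a rough, simplified sketch of what the generic struct fallback might look like. All names and details here are illustrative rather than the actual implementation; in particular, rebuilding via `T(fields...)` assumes a suitable constructor exists.

```julia
using LinearAlgebra

# Illustrative sketch: a couple of primitive `to_vec` methods, plus a generic
# fallback for structs that flattens each field recursively.
to_vec(x::Float64) = ([x], v -> v[1])
to_vec(x::Vector{Float64}) = (x, identity)

function to_vec(x::T) where {T}
    Base.isstructtype(T) || error("no to_vec method for $T")
    pieces = [to_vec(getfield(x, n)) for n in fieldnames(T)]
    lengths = [length(first(p)) for p in pieces]
    function structtype_from_vec(v::Vector{Float64})
        offset = 0
        fields = map(pieces, lengths) do (_, back), len
            field = back(v[offset+1:offset+len])
            offset += len
            field
        end
        # Assumes the constructor accepts the fields in declaration order.
        return T(fields...)
    end
    return reduce(vcat, first.(pieces); init=Float64[]), structtype_from_vec
end

# With no Diagonal-specific method, a `Diagonal` round-trips via its `diag` field:
v, from_vec = to_vec(Diagonal([1.0, 2.0]))  # v == [1.0, 2.0]
from_vec([3.0, 4.0])                        # == Diagonal([3.0, 4.0])
```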
This solution brings into focus a problem that we're currently solving on an ad-hoc basis in to_vec: "junk" data in e.g. the lower triangle of a Symmetric can wind up being used in approximate equality checks (and could in principle introduce non-determinism in an otherwise deterministic function, although I've yet to find an example of this in the wild), which makes no sense. We'll address this later.
Example 2: to_vec gives the wrong type sometimes
`to_vec` only knows about primals -- it knows nothing about tangents. The reason for this is that it was written when we knew nowhere near enough about tangents, in particular for arrays. The particular problem is in the call to `vec_to_x` on this line in `j′vp`. It attempts to convert a "flat" vector representation of a cotangent into a primal. While this works fine in some cases (a surprisingly large number, given how much mileage we've gotten out of `to_vec` over the years), we know that it doesn't work for all types.
Once you've removed the to_vec implementations for the various concrete subtypes of AbstractArray, you'll find that
```julia
FiniteDifferences.j′vp(central_fdm(5, 1), identity, dx, x)[1]
```

yields

```
2×2 Diagonal{Float64, Vector{Float64}}:
 9.97302    ⋅
  ⋅        -0.329386
```
Why is this a problem?
Having the pullback for `identity` return anything other than whatever cotangent it is provided seems highly undesirable to me, so I'm going to assume that our `rrule` for `identity` does just that. If that is the case, then the cotangent returned by the pullback produced by that `rrule` will be a `Tangent` if the input is a `Tangent`, not a `Diagonal`. This means that our current implementation is incorrect. While this particular example seems reasonably benign, to my mind it's not correct. However, even if you believe it's correct, it's clearly only correct because `Diagonal{Float64, Vector{Float64}}` happens to have nice natural tangents that happen to be produced by `to_vec`, rather than by design.
A more obviously incorrect / plainly-uninterpretable example is a `Symmetric` -- the `from_vec` output from `to_vec(::Symmetric)` will produce a `Symmetric` with an `uplo` field that is a `Char`. Since a `Char` isn't an appropriate tangent for a `Char` (it should be a `NoTangent`), this is plainly nonsensical if the goal is to obtain a tangent. If you wound up comparing this representation of the tangent against a `Tangent` output from AD, you would need to compare a `NoTangent` with this `Char`, which under any sensible definition would fail (I can't imagine a world in which I would wish to reside in which `NoTangent` is considered equal to a `Char`).
How to fix this problem
Introduce another function `to_vec_tangent` (a better name would be nice) which returns a closure that always returns an appropriate tangent representation (primitive for primitives, structural for composites). This would require roughly the same level of implementation effort as `to_vec`, and would mirror its structure almost entirely (specific methods for primitives, a generic method for all `isstructtype` types).
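A rough sketch of the shape this could take. This is illustrative only: a structural tangent is stood in for here by a `NamedTuple` of field tangents, where the real implementation would build a `ChainRulesCore.Tangent{T}`.

```julia
using LinearAlgebra

# Primitive methods mirror `to_vec`: the tangent of a Float64 is a Float64, etc.
to_vec_tangent(x::Float64) = ([x], v -> v[1])
to_vec_tangent(x::Vector{Float64}) = (x, identity)

# Generic fallback: the closure rebuilds a *structural tangent*, never a primal.
function to_vec_tangent(x::T) where {T}
    Base.isstructtype(T) || error("no to_vec_tangent method for $T")
    names = fieldnames(T)
    pieces = [to_vec_tangent(getfield(x, n)) for n in names]
    lengths = [length(first(p)) for p in pieces]
    function tangent_from_vec(v::Vector{Float64})
        offset = 0
        fields = map(pieces, lengths) do (_, back), len
            t = back(v[offset+1:offset+len])
            offset += len
            t
        end
        # Stand-in for `Tangent{T}(; NamedTuple{names}(Tuple(fields))...)`.
        return NamedTuple{names}(Tuple(fields))
    end
    return reduce(vcat, first.(pieces); init=Float64[]), tangent_from_vec
end

# The closure for a Diagonal now produces a field-wise tangent, not a Diagonal:
v, back = to_vec_tangent(Diagonal([1.0, 2.0]))
back([3.0, 4.0])  # (diag = [3.0, 4.0],)
```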
Example 3: Propagation of Junk Data
Consider the following:

```julia
x = Symmetric(randn(2, 2))
dx = Tangent{typeof(x)}(data=randn(2, 2))
FiniteDifferences.j′vp(central_fdm(5, 1), identity, dx, x)[1].data
```

This yields

```
2×2 Matrix{Float64}:
 -0.180472  -0.793039
  0.740994   0.900423
```
Note that the `data` field is the relevant bit of the output from `j′vp` here, because the consistent / correct interpretation of the thing that `FiniteDifferences` outputs is a `Tangent`, not a `Symmetric`, as discussed in the previous example. Observe that the lower triangle (element `(2, 1)`) will be used when `test_approx` is computed, because the generic definition of `test_approx` doesn't know about the specific semantics of `Symmetric`. Since the standard library makes no promises about the lower triangle of a `Symmetric`, it seems intuitive to me that we shouldn't have to worry about it in our gradient definitions. I'm happy to expand on this, but there's a good example here.
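To see why the junk matters, note that two `Symmetric`s can be equal as matrices while holding arbitrarily different junk in the unused triangle. A small illustrative demo:

```julia
using LinearAlgebra

# By default `Symmetric(A)` reads the upper triangle; the lower triangle of
# `data` is junk that the standard library makes no promises about.
S1 = Symmetric([1.0 2.0; -100.0 3.0])
S2 = Symmetric([1.0 2.0;  999.0 3.0])

S1 == S2                    # true: equal as matrices
S1.data == S2.data          # false: the junk differs
collect(S1) == collect(S2)  # true: `collect` materialises only the meaningful part
```

A naive field-wise comparison would therefore wrongly report these two equal matrices as different.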
The solution I believe is best is to ensure that the gradient w.r.t. irrelevant elements is always `0` by always testing

```julia
x -> remove_junk_data(f(x))
```

rather than just `f`. The function `remove_junk_data` would be defined such that it doesn't propagate any junk data (data which isn't relevant for equality computations). The implementations that I have so far are things like:
```julia
remove_junk_data(x::Number) = x
remove_junk_data(x::StridedArray) = map(remove_junk_data, x)
remove_junk_data(x::Symmetric{T, <:StridedArray{T}}) where {T} = collect(x)
remove_junk_data(x::UpperTriangular{T, <:StridedArray{T}}) where {T} = collect(x)
remove_junk_data(x::LowerTriangular{T, <:StridedArray{T}}) where {T} = collect(x)
function remove_junk_data(x::T) where {T}
    Base.isstructtype(T) || error("Expected a struct type")
    # Recurse into the field values themselves, not just their names.
    return map(n -> remove_junk_data(getfield(x, n)), fieldnames(T))
end
```
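A quick self-contained check of the intended behaviour, using a simplified subset of the methods above (`Number`, strided arrays and `Symmetric` only) and a hypothetical `f`:

```julia
using LinearAlgebra

# Simplified stand-ins for the sketch in the issue.
remove_junk_data(x::Number) = x
remove_junk_data(x::StridedArray) = map(remove_junk_data, x)
remove_junk_data(x::Symmetric{T, <:StridedArray{T}}) where {T} = collect(x)

# A function whose output carries junk in the lower triangle of `data`:
f(x) = Symmetric(x)

A = [1.0 2.0; -100.0 3.0]
B = [1.0 2.0;  999.0 3.0]

f(A).data == f(B).data                            # false: junk leaks through
remove_junk_data(f(A)) == remove_junk_data(f(B))  # true: junk is gone
```

Testing `x -> remove_junk_data(f(x))` rather than `f` therefore makes the comparison insensitive to the junk.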
Another option one could consider is trying to define equality properly on `Tangent`s. This isn't general though, because e.g. the `data` field of a `Tangent{Symmetric}` might itself be a `Tangent`, which has no conception of its own lower triangle. The benefit of composing with `remove_junk_data` is that we get to operate on primal types, whose semantics everyone is familiar with (the `data` field of a `Symmetric` definitely does know about triangles, because it's an `AbstractArray` and has `getindex` defined).
So we can instruct type authors (ourselves, for stdlib types) that if their types have any data that's essentially "junk", they must define a method of `remove_junk_data`, and accept that we'll have to expend some extra computation internally to differentiate `remove_junk_data` when testing (this can probably be optimised away in most cases, since it's the identity function in most cases).
Note that the generic fallback for composites means that we'll get overly restrictive tests by default, and type-authors have to opt-in to say that some bits of their type aren't important. This seems like the desirable way around to me -- I'd rather have tests yelling at me when they ought not to be, than them to fail to yell at me when they should.
Outcomes
Assuming that this pans out, this is a win-win for developers and users.
Developers get simpler, more robust, more straightforward to understand code with fewer edge cases -- the edge cases that remain have clear semantics and it's clear why they're necessary.
Users benefit from more predictable and reliable infrastructure.
The issue with this proposal is that it requires structural tangents to actually be taken seriously by everyone. Again, see JuliaDiff/ChainRulesCore.jl#449 for a discussion of how to make this more straightforward for all involved.