Skip to content

Commit ba68dca

Browse files
authored
support constructors with non_differentiable (#243)
* support constructors with non_differentiable * Test at=non_differentiable on constructors * Make demo ADs not try to work on constructors * also test negative
1 parent 713fd5e commit ba68dca

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

test/demos/forwarddiffzero.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Test
1111
# Define the AD
1212

1313
# Note that we never directly define Dual Number Arithmetic on Dual numbers
14-
# instead it is automatically defined from the `frules`
14+
# instead it is automatically defined from the `frules`
1515
struct Dual <: Real
1616
primal::Float64
1717
partial::Float64
@@ -30,7 +30,8 @@ Base.to_power_type(x::Dual) = x
3030
function define_dual_overload(sig)
3131
sig = Base.unwrap_unionall(sig) # Not really handling most UnionAlls
3232
opT, argTs = Iterators.peel(sig.parameters)
33-
fieldcount(opT) == 0 || return # not handling functors
33+
opT isa Type{<:Type} && return # not handling constructors
34+
fieldcount(opT) == 0 || return # not handling functors
3435
all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops.
3536

3637
N = length(sig.parameters) - 1 # skip the op
@@ -65,7 +66,7 @@ function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
6566
end
6667

6768
# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call
68-
refresh_rules();
69+
refresh_rules();
6970

7071
@testset "ForwardDiffZero" begin
7172
foo(x) = x + x

test/demos/reversediffzero.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ Base.to_power_type(x::Tracked) = x
5959
function define_tracked_overload(sig)
6060
sig = Base.unwrap_unionall(sig) # not really handling most UnionAll
6161
opT, argTs = Iterators.peel(sig.parameters)
62-
fieldcount(opT) == 0 || return # not handling functors
62+
opT isa Type{<:Type} && return # not handling constructors
63+
fieldcount(opT) == 0 || return # not handling functors
6364
all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops.
6465

6566
N = length(sig.parameters) - 1 # skip the op
@@ -116,7 +117,7 @@ function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number)
116117
end
117118

118119
# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call
119-
refresh_rules();
120+
refresh_rules();
120121

121122
@testset "ReversedDiffZero" begin
122123
foo(x) = x + x

0 commit comments

Comments
 (0)