Skip to content

Commit 815d3d8

Browse files
authored
Preserve input types for various rules (#89)
* Make Float32 stable for both arguments * revert :NaN change * remove float guard * Removed spurious NaN
1 parent 489e294 commit 815d3d8

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/rules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
8585
@define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) )
8686
@define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) )
8787
@define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) )
88-
@define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN
88+
@define_diffrule Base.ldexp(x, y) = :( oftype($x, exp2($y)) ), :NaN
8989

9090
@define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) )
9191
@define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) )
@@ -296,14 +296,14 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
296296
@define_diffrule LogExpFunctions.logmxp1(x) = :((1 - $x) / $x)
297297

298298
# binary
299-
@define_diffrule LogExpFunctions.xlogy(x, y) =
299+
@define_diffrule LogExpFunctions.xlogy(x, y) =
300300
:(log($y)),
301301
:(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : z)
302302
@define_diffrule LogExpFunctions.logaddexp(x, y) =
303303
:(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y)))
304304
@define_diffrule LogExpFunctions.logsubexp(x, y) =
305305
:(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)),
306306
:(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z))
307-
@define_diffrule LogExpFunctions.xlog1py(x, y) =
307+
@define_diffrule LogExpFunctions.xlog1py(x, y) =
308308
:(log1p($y)),
309309
:(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z)

0 commit comments

Comments
 (0)