diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index b3fc4d4448c..2973c4b84cd 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -144,6 +144,11 @@ def CFAdd : SubRoutine<(Op (Op $re1, $im1):$z1, (Op $re2, $im2):$z2), (FAdd $re1, $re2), (FAdd $im1, $im2) )>; +def CFSub : SubRoutine<(Op (Op $re1, $im1):$z1, (Op $re2, $im2):$z2), + (ArrayRet + (FSub $re1, $re2), + (FSub $im1, $im2) + )>; def CFMul_splat : SubRoutine<(Op $re1, $im1, $re2, $im2), (ArrayRet @@ -666,6 +671,17 @@ def : CallPattern<(Op $n, $x), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $n, $x), + ["cmplx_jn","cmplx_yn"], + [ + (InactiveArg), + // Reverse mode needs to return the conjugate + (CFMul (DiffeRet), (Conj (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x))))) + ], + (CFMul (Shadow $x), (CFMul (ConstantCFP<"0.5", "0"> $x), (CFSub (Call<(SameFunc), [ReadNone,NoUnwind]> (FSub $n, (ConstantFP<"1"> $n)), $x), (Call<(SameFunc), [ReadNone,NoUnwind]> (FAdd $n, (ConstantFP<"1"> $n)), $x)))), + [ReadNone, NoUnwind] + >; + def : CallPattern<(Op $x), ["erf","erff","erfl"], [