Skip to content

Commit

Permalink
More fine grained bundle preservation (rust-lang#897)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Oct 12, 2022
1 parent 7d33211 commit 680ed95
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ class GradientUtils : public CacheUtility {
SmallVector<OperandBundleDef, 2> OrigDefs;
orig->getOperandBundlesAsDefs(OrigDefs);
SmallVector<OperandBundleDef, 2> Defs;
bool anyPrimal = false;
bool anyShadow = false;
for (auto ty : types) {
if (ty == ValueType::Primal || ty == ValueType::Both)
anyPrimal = true;
if (ty == ValueType::Shadow || ty == ValueType::Both)
anyShadow = true;
}
for (auto bund : OrigDefs) {
// Only handle jl_roots tag (for now).
if (bund.getTag() != "jl_roots") {
Expand All @@ -219,11 +227,13 @@ class GradientUtils : public CacheUtility {
// primals and shadows
// assert(bund.inputs().size() == types.size());
for (auto inp : bund.inputs()) {
Value *newv = getNewFromOriginal(inp);
if (lookup)
newv = lookupM(newv, Builder2, available);
bunds.push_back(newv);
if (!isConstantValue(inp)) {
if (anyPrimal) {
Value *newv = getNewFromOriginal(inp);
if (lookup)
newv = lookupM(newv, Builder2, available);
bunds.push_back(newv);
}
if (anyShadow && !isConstantValue(inp)) {
Value *shadow = invertPointerM(inp, Builder2);
if (lookup)
shadow = lookupM(shadow, Builder2);
Expand Down

0 comments on commit 680ed95

Please sign in to comment.