Skip to content

Commit

Permalink
Support "key" argument on min() and max() builtins (#505)
Browse files Browse the repository at this point in the history
* Support "key" argument on min() and max() builtins

* Delay overload selection when arguments are not known (delayed dispatch)

* Delay 'is None' for 'Optional[T]' until type is known

* Fix union overload selection

* Add static string slicing

* Fix itertools.accumulate

* Fix list comprehension optimization ( minitech:imports-in-list-comprehensions )

* Fix match or patterns

* Fix tests and faulty static tuple issue

* Fix OpenMP reductions with new min/max functions

* Fix domination of dominated bindings; Fix hasattr overloads; Fix arg=None handling

* Fix empty return handling; Mark generators with an attribute

* Fix #487

* Fix test

* Fix IR pass

---------

Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>
  • Loading branch information
arshajii and inumanag authored Jan 13, 2024
1 parent 8a0064a commit d23c8c7
Show file tree
Hide file tree
Showing 16 changed files with 334 additions and 98 deletions.
23 changes: 10 additions & 13 deletions codon/cir/transform/parallel/openmp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,15 @@ struct Reduction {
case Kind::XOR:
result = *lhs ^ *arg;
break;
case Kind::MIN: {
auto *tup = util::makeTuple({lhs, arg});
auto *fn = M->getOrRealizeFunc("min", {tup->getType()}, {}, builtinModule);
seqassertn(fn, "min function not found");
result = util::call(fn, {tup});
break;
}
case Kind::MIN:
case Kind::MAX: {
auto name = (kind == Kind::MIN ? "min" : "max");
auto *tup = util::makeTuple({lhs, arg});
auto *fn = M->getOrRealizeFunc("max", {tup->getType()}, {}, builtinModule);
seqassertn(fn, "max function not found");
result = util::call(fn, {tup});
auto *none = (*M->getNoneType())();
auto *fn = M->getOrRealizeFunc(name, {tup->getType(), none->getType()}, {},
builtinModule);
seqassertn(fn, "{} function not found", name);
result = util::call(fn, {tup, none});
break;
}
default:
Expand Down Expand Up @@ -432,6 +429,7 @@ struct ReductionIdentifier : public util::Operator {
auto *ptrType = cast<types::PointerType>(shared->getType());
seqassertn(ptrType, "expected shared var to be of pointer type");
auto *type = ptrType->getBase();
auto *noneType = M->getOptionalType(M->getNoneType());

// double-check the call
if (!util::isCallOf(v, Module::SETITEM_MAGIC_NAME,
Expand All @@ -454,7 +452,8 @@ struct ReductionIdentifier : public util::Operator {
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true))
continue;
} else {
if (!util::isCallOf(item, rf.name, {M->getTupleType({type, type})}, type,
if (!util::isCallOf(item, rf.name, {M->getTupleType({type, type}), noneType},
type,
/*method=*/false))
continue;
}
Expand Down Expand Up @@ -1183,9 +1182,7 @@ struct GPULoopBodyStubReplacer : public util::Operator {

std::vector<Value *> newArgs;
for (auto *arg : *replacement) {
// std::cout << "A: " << *arg << std::endl;
if (getVarFromOutlinedArg(arg)->getId() == loopVar->getId()) {
// std::cout << "(loop var)" << std::endl;
newArgs.push_back(idx);
} else {
newArgs.push_back(util::tupleGet(args, next++));
Expand Down
12 changes: 10 additions & 2 deletions codon/cir/transform/pythonic/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ struct GeneratorSumTransformer : public util::Operator {
auto *M = v->getModule();
auto *newReturn = M->Nr<ReturnInstr>(M->Nr<VarValue>(accumulator));
see(newReturn);
v->replaceAll(util::series(v->getValue(), newReturn));
if (v->getValue()) {
v->replaceAll(util::series(v->getValue(), newReturn));
} else {
v->replaceAll(newReturn);
}
}

void handle(YieldInInstr *v) override { valid = false; }
Expand Down Expand Up @@ -97,7 +101,11 @@ struct GeneratorAnyAllTransformer : public util::Operator {
auto *M = v->getModule();
auto *newReturn = M->Nr<ReturnInstr>(M->getBool(!any));
see(newReturn);
v->replaceAll(util::series(v->getValue(), newReturn));
if (v->getValue()) {
v->replaceAll(util::series(v->getValue(), newReturn));
} else {
v->replaceAll(newReturn);
}
}

void handle(YieldInInstr *v) override { valid = false; }
Expand Down
1 change: 1 addition & 0 deletions codon/parser/ast/stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ const std::string Attr::CVarArg = ".__vararg__";
const std::string Attr::Method = ".__method__";
const std::string Attr::Capture = ".__capture__";
const std::string Attr::HasSelf = ".__hasself__";
const std::string Attr::IsGenerator = ".__generator__";
const std::string Attr::Extend = "extend";
const std::string Attr::Tuple = "tuple";
const std::string Attr::Test = "std.internal.attributes.test";
Expand Down
1 change: 1 addition & 0 deletions codon/parser/ast/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ struct Attr {
const static std::string Method;
const static std::string Capture;
const static std::string HasSelf;
const static std::string IsGenerator;
// Class attributes
const static std::string Extend;
const static std::string Tuple;
Expand Down
2 changes: 1 addition & 1 deletion codon/parser/visitors/simplify/access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void SimplifyVisitor::visit(IdExpr *expr) {
if (!checked) {
// Prepend access with __internal__.undef([var]__used__, "[var name]")
auto checkStmt = N<ExprStmt>(N<CallExpr>(
N<DotExpr>("__internal__", "undef"),
N<IdExpr>("__internal__.undef"),
N<IdExpr>(fmt::format("{}.__used__", val->canonicalName)),
N<StringExpr>(ctx->cache->reverseIdentifierLookup[val->canonicalName])));
if (!ctx->isConditionalExpr) {
Expand Down
2 changes: 2 additions & 0 deletions codon/parser/visitors/simplify/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace codon::ast {
void SimplifyVisitor::visit(YieldExpr *expr) {
if (!ctx->inFunction())
E(Error::FN_OUTSIDE_ERROR, expr, "yield");
ctx->getBase()->attributes->set(Attr::IsGenerator);
}

/// Transform lambdas. Capture outer expressions.
Expand All @@ -45,6 +46,7 @@ void SimplifyVisitor::visit(YieldStmt *stmt) {
if (!ctx->inFunction())
E(Error::FN_OUTSIDE_ERROR, stmt, "yield");
transform(stmt->expr);
ctx->getBase()->attributes->set(Attr::IsGenerator);
}

/// Transform `yield from` statements.
Expand Down
32 changes: 20 additions & 12 deletions codon/parser/visitors/typecheck/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,27 @@ void TypecheckVisitor::visit(AssignStmt *stmt) {
if (auto changed = in(ctx->cache->replacements, lhs)) {
while (auto s = in(ctx->cache->replacements, lhs))
lhs = changed->first, changed = s;
if (stmt->rhs && changed->second) {
// Mark the dominating binding as used: `var.__used__ = True`
auto u =
N<AssignStmt>(N<IdExpr>(fmt::format("{}.__used__", lhs)), N<BoolExpr>(true));
u->setUpdate();
prependStmts->push_back(transform(u));
} else if (changed->second && !stmt->rhs) {
// This assignment was a declaration only. Just mark the dominating binding as
// used: `var.__used__ = True`
stmt->lhs = N<IdExpr>(fmt::format("{}.__used__", lhs));
stmt->rhs = N<BoolExpr>(true);
if (changed->second) { // has __used__ binding
if (stmt->rhs) {
// Mark the dominating binding as used: `var.__used__ = True`
auto u = N<AssignStmt>(N<IdExpr>(fmt::format("{}.__used__", lhs)),
N<BoolExpr>(true));
u->setUpdate();
prependStmts->push_back(transform(u));
} else {
// This assignment was a declaration only. Just mark the dominating binding as
// used: `var.__used__ = True`
stmt->lhs = N<IdExpr>(fmt::format("{}.__used__", lhs));
stmt->rhs = N<BoolExpr>(true);
}
}
seqassert(stmt->rhs, "bad domination statement: '{}'", stmt->toString());

if (endswith(lhs, ".__used__") || !stmt->rhs) {
// unneeded declaration (unnecessary used or binding)
resultStmt = transform(N<SuiteStmt>());
return;
}

// Change this to the update and follow the update logic
stmt->setUpdate();
transformUpdate(stmt);
Expand Down
80 changes: 50 additions & 30 deletions codon/parser/visitors/typecheck/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ void TypecheckVisitor::visit(EllipsisExpr *expr) {
/// See @c transformCallArgs , @c getCalleeFn , @c callReorderArguments ,
/// @c typecheckCallArgs , @c transformSpecialCall and @c wrapExpr for more details.
void TypecheckVisitor::visit(CallExpr *expr) {
if (expr->expr->isId("__internal__.undef") && expr->args.size() == 2 &&
expr->args[0].value->getId()) {
auto val = expr->args[0].value->getId()->value;
val = val.substr(0, val.size() - 9);
if (auto changed = in(ctx->cache->replacements, val)) {
while (auto s = in(ctx->cache->replacements, val))
val = changed->first, changed = s;
if (!changed->second) {
// TODO: add no-op expr
resultExpr = transform(N<BoolExpr>(false));
return;
}
}
}

// Transform and expand arguments. Return early if it cannot be done yet
if (!transformCallArgs(expr->args))
return;
Expand Down Expand Up @@ -319,7 +334,7 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
}
ExprPtr e = N<TupleExpr>(extra);
e->setAttr(ExprAttr::StarArgument);
if (!expr->expr->isId("hasattr:0"))
if (!expr->expr->isId("hasattr"))
e = transform(e);
if (partial) {
part.args = e;
Expand Down Expand Up @@ -373,8 +388,16 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
E(Error::CALL_RECURSIVE_DEFAULT, expr,
ctx->cache->rev(calleeFn->ast->args[si].name));
ctx->defaultCallDepth.insert(es);
args.push_back(
{realName, transform(clone(calleeFn->ast->args[si].defaultValue))});

if (calleeFn->ast->args[si].defaultValue->getNone() &&
!calleeFn->ast->args[si].type) {
args.push_back(
{realName, transform(N<CallExpr>(N<InstantiateExpr>(
N<IdExpr>("Optional"), N<IdExpr>("NoneType"))))});
} else {
args.push_back(
{realName, transform(clone(calleeFn->ast->args[si].defaultValue))});
}
ctx->defaultCallDepth.erase(es);
}
} else {
Expand Down Expand Up @@ -562,7 +585,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
return {true, transformIsInstance(expr)};
} else if (val == "staticlen") {
return {true, transformStaticLen(expr)};
} else if (startswith(val, "hasattr:")) {
} else if (val == "hasattr") {
return {true, transformHasAttr(expr)};
} else if (val == "getattr") {
return {true, transformGetAttr(expr)};
Expand Down Expand Up @@ -812,38 +835,35 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) {
auto typ = expr->args[0].value->getType()->getClass();
if (!typ)
return nullptr;

auto member = expr->expr->type->getFunc()
->funcGenerics[0]
.type->getStatic()
->evaluate()
.getString();
std::vector<std::pair<std::string, TypePtr>> args{{"", typ}};
if (expr->expr->isId("hasattr:0")) {
// Case: the first hasattr overload allows passing argument types via *args
auto tup = expr->args[1].value->getTuple();
seqassert(tup, "not a tuple");
for (auto &a : tup->items) {
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.push_back({"", a->getType()});
}
auto kwtup = expr->args[2].value->origExpr->getCall();
seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(),
"expected call: {}", expr->args[2].value->origExpr);
auto kw = expr->args[2].value->origExpr->getCall();
auto kwCls =
in(ctx->cache->classes, expr->args[2].value->getType()->getClass()->name);
seqassert(kwCls, "cannot find {}",
expr->args[2].value->getType()->getClass()->name);
for (size_t i = 0; i < kw->args.size(); i++) {
auto &a = kw->args[i].value;
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.push_back({kwCls->fields[i].name, a->getType()});
}

// Case: passing argument types via *args
auto tup = expr->args[1].value->getTuple();
seqassert(tup, "not a tuple");
for (auto &a : tup->items) {
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.emplace_back("", a->getType());
}
auto kwtup = expr->args[2].value->origExpr->getCall();
seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(),
"expected call: {}", expr->args[2].value->origExpr);
auto kw = expr->args[2].value->origExpr->getCall();
auto kwCls =
in(ctx->cache->classes, expr->args[2].value->getType()->getClass()->name);
seqassert(kwCls, "cannot find {}", expr->args[2].value->getType()->getClass()->name);
for (size_t i = 0; i < kw->args.size(); i++) {
auto &a = kw->args[i].value;
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.emplace_back(kwCls->fields[i].name, a->getType());
}

if (typ->getUnion()) {
Expand Down
20 changes: 12 additions & 8 deletions codon/parser/visitors/typecheck/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ void TypecheckVisitor::visit(YieldExpr *expr) {
/// Also partialize functions if they are being returned.
/// See @c wrapExpr for more details.
void TypecheckVisitor::visit(ReturnStmt *stmt) {
if (transform(stmt->expr)) {
if (!stmt->expr && ctx->getRealizationBase()->type &&
ctx->getRealizationBase()->type->getFunc()->ast->hasAttr(Attr::IsGenerator)) {
stmt->setDone();
} else {
if (!stmt->expr) {
stmt->expr = N<CallExpr>(N<IdExpr>("NoneType"));
}
transform(stmt->expr);
// Wrap expression to match the return type
if (!ctx->getRealizationBase()->returnType->getUnbound())
if (!wrapExpr(stmt->expr, ctx->getRealizationBase()->returnType)) {
Expand All @@ -44,26 +51,23 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) {
}

unify(ctx->getRealizationBase()->returnType, stmt->expr->type);
} else {
// Just set the expr for the translation stage. However, do not unify the return
// type! This might be a `return` in a generator.
stmt->expr = transform(N<CallExpr>(N<IdExpr>("NoneType")));
}

// If we are not within conditional block, ignore later statements in this function.
// Useful with static if statements.
if (!ctx->blockLevel)
ctx->returnEarly = true;

if (stmt->expr->isDone())
if (!stmt->expr || stmt->expr->isDone())
stmt->setDone();
}

/// Typecheck yield statements. Empty yields assume `NoneType`.
void TypecheckVisitor::visit(YieldStmt *stmt) {
stmt->expr = transform(stmt->expr ? stmt->expr : N<CallExpr>(N<IdExpr>("NoneType")));
unify(ctx->getRealizationBase()->returnType,
ctx->instantiateGeneric(ctx->getType("Generator"), {stmt->expr->type}));

auto t = ctx->instantiateGeneric(ctx->getType("Generator"), {stmt->expr->type});
unify(ctx->getRealizationBase()->returnType, t);

if (stmt->expr->isDone())
stmt->setDone();
Expand Down
7 changes: 4 additions & 3 deletions codon/parser/visitors/typecheck/op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,9 +823,10 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
te.push_back(N<DotExpr>(clone(var), classFields[i].name));
}
ExprPtr e = transform(
N<StmtExpr>(std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(N<IdExpr>(TYPE_TUPLE), "__new__"), te)));
auto s = ctx->generateTuple(te.size());
ExprPtr e =
transform(N<StmtExpr>(std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(N<IdExpr>(s), "__new__"), te)));
return {true, e};
}
}
Expand Down
Loading

0 comments on commit d23c8c7

Please sign in to comment.