diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 778105168..3f3ad641b 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -57,7 +57,7 @@ struct GeneralizationConstraint struct IterableConstraint { TypePackId iterator; - TypePackId variables; + std::vector variables; const AstNode* nextAstFragment; DenseHashMap* astForInNextTypes; @@ -192,13 +192,7 @@ struct HasIndexerConstraint TypeId indexType; }; -struct AssignConstraint -{ - TypeId lhsType; - TypeId rhsType; -}; - -// assign lhsType propName rhsType +// assignProp lhsType propName rhsType // // Assign a value of type rhsType into the named property of lhsType. @@ -212,6 +206,12 @@ struct AssignPropConstraint /// populate astTypes during constraint resolution. Nothing should ever /// block on it. TypeId propType; + + // When we generate constraints, we increment the remaining prop count on + // the table if we are able. This flag informs the solver as to whether or + // not it should in turn decrement the prop count when this constraint is + // dispatched. + bool decrementPropCount = false; }; struct AssignIndexConstraint @@ -226,13 +226,13 @@ struct AssignIndexConstraint TypeId propType; }; -// resultType ~ unpack sourceTypePack +// resultTypes ~ unpack sourceTypePack // // Similar to PackSubtypeConstraint, but with one important difference: If the // sourcePack is blocked, this constraint blocks. struct UnpackConstraint { - TypePackId resultPack; + std::vector resultPack; TypePackId sourcePack; }; @@ -254,7 +254,7 @@ struct ReducePackConstraint using ConstraintV = Variant; + AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 3e1861ea5..b540b82f0 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -118,6 +118,8 @@ struct ConstraintGenerator std::function prepareModuleScope; std::vector requireCycles; + DenseHashMap> localTypes{nullptr}; + DcrLogger* logger; ConstraintGenerator(ModulePtr module, NotNull normalizer, NotNull moduleResolver, NotNull builtinTypes, @@ -354,6 +356,8 @@ struct ConstraintGenerator */ void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); + bool recordPropertyAssignment(TypeId ty); + // Record the fact that a particular local has a particular type in at least // one of its states. void recordInferredBinding(AstLocal* local, TypeId ty); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 58361dde6..902dd15dc 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -142,7 +142,6 @@ struct ConstraintSolver std::pair> tryDispatchSetIndexer( NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds); - bool tryDispatch(const AssignConstraint& c, NotNull constraint); bool tryDispatch(const AssignPropConstraint& c, NotNull constraint); bool tryDispatch(const AssignIndexConstraint& c, NotNull constraint); @@ -158,8 +157,7 @@ struct ConstraintSolver bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); // for a, ... in next_function, t, ... do - bool tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); + bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force); std::pair, std::optional> lookupTableProp(NotNull constraint, TypeId subjectType, const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false); @@ -168,14 +166,18 @@ struct ConstraintSolver /** * Generate constraints to unpack the types of srcTypes and assign each - * value to the corresponding LocalType in destTypes. + * value to the corresponding BlockedType in destTypes. * - * @param destTypes A finite TypePack comprised of LocalTypes. + * This function also overwrites the owners of each BlockedType. This is + * okay because this function is only used to decompose IterableConstraint + * into an UnpackConstraint. + * + * @param destTypes A vector of types comprised of BlockedTypes. * @param srcTypes A TypePack that represents rvalues to be assigned. * @returns The underlying UnpackConstraint. There's a bit of code in * iteration that needs to pass blocks on to this constraint. */ - NotNull unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull constraint); + NotNull unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint); void block(NotNull target, NotNull constraint); /** diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 47161886c..6105ede3d 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -86,24 +86,6 @@ struct FreeType TypeId upperBound = nullptr; }; -/** A type that tracks the domain of a local variable. - * - * We consider each local's domain to be the union of all types assigned to it. - * We accomplish this with LocalType. Each time we dispatch an assignment to a - * local, we accumulate this union and decrement blockCount. - * - * When blockCount reaches 0, we can consider the LocalType to be "fully baked" - * and replace it with the union we've built. - */ -struct LocalType -{ - TypeId domain; - int blockCount = 0; - - // Used for debugging - std::string name; -}; - struct GenericType { // By default, generics are global, with a synthetic name @@ -148,6 +130,7 @@ struct BlockedType Constraint* getOwner() const; void setOwner(Constraint* newOwner); + void replaceOwner(Constraint* newOwner); private: // The constraint that is intended to unblock this type. Other constraints @@ -471,6 +454,11 @@ struct TableType // Methods of this table that have an untyped self will use the same shared self type. std::optional selfTy; + + // We track the number of as-yet-unadded properties to unsealed tables. + // Some constraints will use this information to decide whether or not they + // are able to dispatch. + size_t remainingProps = 0; }; // Represents a metatable attached to a table type. Somewhat analogous to a bound type. @@ -669,9 +657,9 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct Type final { diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index 130c0c3c3..bbf3a63a9 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -69,7 +69,6 @@ struct Unifier2 */ bool unify(TypeId subTy, TypeId superTy); bool unifyFreeWithType(TypeId subTy, TypeId superTy); - bool unify(const LocalType* subTy, TypeId superFn); bool unify(TypeId subTy, const FunctionType* superFn); bool unify(const UnionType* subUnion, TypeId superTy); bool unify(TypeId subTy, const UnionType* superUnion); diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 40dccbd29..ff0656d6c 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -100,10 +100,6 @@ struct GenericTypeVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const LocalType& ftv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const GenericType& gtv) { return visit(ty); @@ -248,11 +244,6 @@ struct GenericTypeVisitor else visit(ty, *ftv); } - else if (auto lt = get(ty)) - { - if (visit(ty, *lt)) - traverse(lt->domain); - } else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index a96e58668..371ace2ea 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -271,11 +271,6 @@ class TypeCloner t->upperBound = shallowClone(t->upperBound); } - void cloneChildren(LocalType* t) - { - t->domain = shallowClone(t->domain); - } - void cloneChildren(GenericType* t) { // TOOD: clone upper bounds. diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index bd31beff7..7b3377cb8 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -81,7 +81,8 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const } else if (auto itc = get(*this)) { - rci.traverse(itc->variables); + for (TypeId ty : itc->variables) + rci.traverse(ty); // `IterableConstraints` should not mutate `iterator`. } else if (auto nc = get(*this)) @@ -106,11 +107,6 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const rci.traverse(hic->resultType); // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`. } - else if (auto ac = get(*this)) - { - rci.traverse(ac->lhsType); - rci.traverse(ac->rhsType); - } else if (auto apc = get(*this)) { rci.traverse(apc->lhsType); @@ -124,7 +120,8 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const } else if (auto uc = get(*this)) { - rci.traverse(uc->resultPack); + for (TypeId ty : uc->resultPack) + rci.traverse(ty); // `UnpackConstraint` should not mutate `sourcePack`. } else if (auto rpc = get(*this)) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 12648eb01..9d825408a 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -28,6 +28,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -246,6 +247,17 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) if (logger) logger->captureGenerationModule(module); + + for (const auto& [ty, domain] : localTypes) + { + // FIXME: This isn't the most efficient thing. + TypeId domainTy = builtinTypes->neverType; + for (TypeId d : domain) + domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + + LUAU_ASSERT(get(ty)); + asMutable(ty)->ty.emplace(domainTy); + } } TypeId ConstraintGenerator::freshType(const ScopePtr& scope) @@ -310,7 +322,8 @@ std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Locatio std::optional ty = lookup(scope, location, operand, /*prototype*/ false); if (!ty) { - ty = arena->addType(LocalType{builtinTypes->neverType}); + ty = arena->addType(BlockedType{}); + localTypes[*ty] = {}; rootScope->lvalueTypes[operand] = *ty; } @@ -703,7 +716,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat { const Location location = local->location; - TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->name.value}); + TypeId assignee = arena->addType(BlockedType{}); + localTypes[assignee] = {}; assignees.push_back(assignee); @@ -740,7 +754,12 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat if (hasAnnotation) { for (size_t i = 0; i < statLocal->vars.size; ++i) - addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], annotatedTypes[i]}); + { + LUAU_ASSERT(get(assignees[i])); + std::vector* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->push_back(annotatedTypes[i]); + } TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); @@ -750,15 +769,30 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat std::vector valueTypes; valueTypes.reserve(statLocal->vars.size); - for (size_t i = 0; i < statLocal->vars.size; ++i) - valueTypes.push_back(arena->addType(BlockedType{})); + auto [head, tail] = flatten(rvaluePack); + + if (head.size() >= statLocal->vars.size) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); - auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(valueTypes), rvaluePack}); + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + + for (TypeId t: valueTypes) + getMutable(t)->setOwner(uc); + } for (size_t i = 0; i < statLocal->vars.size; ++i) { - getMutable(valueTypes[i])->setOwner(uc); - addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], valueTypes[i]}); + LUAU_ASSERT(get(assignees[i])); + std::vector* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->push_back(valueTypes[i]); } } @@ -860,25 +894,34 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI for (AstLocal* var : forIn->vars) { - TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, var->name.value}); + TypeId assignee = arena->addType(BlockedType{}); variableTypes.push_back(assignee); + TypeId loopVar = arena->addType(BlockedType{}); + localTypes[loopVar].push_back(assignee); + if (var->annotation) { TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false); loopScope->bindings[var] = Binding{annotationTy, var->location}; - addConstraint(scope, var->location, SubtypeConstraint{assignee, annotationTy}); + addConstraint(scope, var->location, SubtypeConstraint{loopVar, annotationTy}); } else - loopScope->bindings[var] = Binding{assignee, var->location}; + loopScope->bindings[var] = Binding{loopVar, var->location}; DefId def = dfg->getDef(var); - loopScope->lvalueTypes[def] = assignee; + loopScope->lvalueTypes[def] = loopVar; } - TypePackId variablePack = arena->addTypePack(std::move(variableTypes)); auto iterable = addConstraint( - loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes}); + loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes}); + + for (TypeId var: variableTypes) + { + auto bt = getMutable(var); + LUAU_ASSERT(bt); + bt->setOwner(iterable); + } Checkpoint start = checkpoint(this); visit(loopScope, forIn->body); @@ -1105,14 +1148,31 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass std::vector valueTypes; valueTypes.reserve(assign->vars.size); - for (size_t i = 0; i < assign->vars.size; ++i) - valueTypes.push_back(arena->addType(BlockedType{})); + auto [head, tail] = flatten(resultPack); + if (head.size() >= assign->vars.size) + { + // If the resultPack is definitely long enough for each variable, we can + // skip the UnpackConstraint and use the result types directly. - auto uc = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(valueTypes), resultPack}); + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + // We're not sure how many types are produced by the right-side + // expressions. We'll use an UnpackConstraint to defer this until + // later. + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, assign->location, UnpackConstraint{valueTypes, resultPack}); + + for (TypeId t: valueTypes) + getMutable(t)->setOwner(uc); + } for (size_t i = 0; i < assign->vars.size; ++i) { - getMutable(valueTypes[i])->setOwner(uc); visitLValue(scope, assign->vars.data[i], valueTypes[i]); } @@ -1393,7 +1453,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); FunctionType* ftv = getMutable(fnType); - ftv->isCheckedFunction = global->checkedFunction; + ftv->isCheckedFunction = FFlag::LuauAttributeSyntax ? global->isCheckedFunction() : false; ftv->argNames.reserve(global->paramNames.size); for (const auto& el : global->paramNames) @@ -1599,9 +1659,8 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* mt = arena->addType(BlockedType{}); unpackedTypes.emplace_back(mt); - TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes)); - auto c = addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail}); + auto c = addConstraint(scope, call->location, UnpackConstraint{unpackedTypes, *argTail}); getMutable(mt)->setOwner(c); if (auto b = getMutable(target); b && b->getOwner() == nullptr) b->setOwner(c); @@ -1842,7 +1901,37 @@ Inference ConstraintGenerator::checkIndexName( const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) { TypeId obj = check(scope, indexee).ty; - TypeId result = arena->addType(BlockedType{}); + TypeId result = nullptr; + + // We optimize away the HasProp constraint in simple cases so that we can + // reason about updates to unsealed tables more accurately. + + const TableType* tt = getTableType(obj); + + // This is a little bit iffy but I *believe* it is okay because, if the + // local's domain is going to be extended at all, it will be someplace after + // the current lexical position within the script. + if (!tt) + { + if (auto localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) + tt = getTableType(localDomain->front()); + } + + if (tt) + { + auto it = tt->props.find(index); + if (it != tt->props.end() && it->second.readTy.has_value()) + result = *it->second.readTy; + } + + if (!result) + { + result = arena->addType(BlockedType{}); + + auto c = addConstraint( + scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); + getMutable(result)->setOwner(c); + } if (key) { @@ -1852,10 +1941,6 @@ Inference ConstraintGenerator::checkIndexName( scope->rvalueRefinements[key->def] = result; } - auto c = - addConstraint(scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); - getMutable(result)->setOwner(c); - if (key) return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; else @@ -2242,18 +2327,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (ty) { - if (auto lt = getMutable(*ty)) - ++lt->blockCount; - else if (auto ut = getMutable(*ty)) - { - for (TypeId optTy : ut->options) - if (auto lt = getMutable(optTy)) - ++lt->blockCount; - } + std::vector* localDomain = localTypes.find(*ty); + if (localDomain) + localDomain->push_back(rhsType); } else { - ty = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->local->name.value}); + ty = arena->addType(BlockedType{}); + localTypes[*ty].push_back(rhsType); if (annotatedTy) { @@ -2277,7 +2358,9 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (annotatedTy) addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); - addConstraint(scope, local->location, AssignConstraint{*ty, rhsType}); + + if (auto localDomain = localTypes.find(*ty)) + localDomain->push_back(rhsType); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) @@ -2289,7 +2372,6 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob rootScope->lvalueTypes[def] = rhsType; addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); - addConstraint(scope, global->location, AssignConstraint{*annotatedTy, rhsType}); } } @@ -2298,7 +2380,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e TypeId lhsTy = check(scope, expr->expr).ty; TypeId propTy = arena->addType(BlockedType{}); module->astTypes[expr] = propTy; - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy}); + + bool incremented = recordPropertyAssignment(lhsTy); + + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented}); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) @@ -2310,7 +2395,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e module->astTypes[expr] = propTy; module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist. std::string propName{constantString->value.data, constantString->value.size}; - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy}); + + bool incremented = recordPropertyAssignment(lhsTy); + + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented}); return; } @@ -2775,7 +2863,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; - ftv.isCheckedFunction = fn->checkedFunction; + ftv.isCheckedFunction = FFlag::LuauAttributeSyntax ? fn->isCheckedFunction() : false; // This replicates the behavior of the appropriate FunctionType // constructors. @@ -2977,8 +3065,7 @@ Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location locat return Inference{*f, refinement}; TypeId typeResult = arena->addType(BlockedType{}); - TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); - auto c = addConstraint(scope, location, UnpackConstraint{resultPack, tp}); + auto c = addConstraint(scope, location, UnpackConstraint{{typeResult}, tp}); getMutable(typeResult)->setOwner(c); return Inference{typeResult, refinement}; @@ -3075,6 +3162,46 @@ void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, As program->visit(&gp); } +bool ConstraintGenerator::recordPropertyAssignment(TypeId ty) +{ + DenseHashSet seen{nullptr}; + VecDeque queue; + + queue.push_back(ty); + + bool incremented = false; + + while (!queue.empty()) + { + const TypeId t = follow(queue.front()); + queue.pop_front(); + + if (seen.find(t)) + continue; + seen.insert(t); + + if (auto tt = getMutable(t); tt && tt->state == TableState::Unsealed) + { + tt->remainingProps += 1; + incremented = true; + } + else if (auto mt = get(t)) + queue.push_back(mt->table); + else if (auto localDomain = localTypes.find(t)) + { + for (TypeId domainTy : *localDomain) + queue.push_back(domainTy); + } + else if (auto ut = get(t)) + { + for (TypeId part : ut) + queue.push_back(part); + } + } + + return incremented; +} + void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty) { if (InferredBinding* ib = inferredBindings.find(local)) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index b0f27911c..07fc26fb4 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -532,8 +532,6 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*hpc, constraint); else if (auto spc = get(*constraint)) success = tryDispatch(*spc, constraint); - else if (auto uc = get(*constraint)) - success = tryDispatch(*uc, constraint); else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); else if (auto uc = get(*constraint)) @@ -686,7 +684,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullanyTypePack, c.variables); + for (TypeId ty : c.variables) + unify(constraint, builtinTypes->errorRecoveryType(), ty); return true; } @@ -696,21 +695,35 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); - TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); - getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; + TypeId tableTy = arena->addType(TableType{ + TableType::Props{}, + TableIndexer{keyTy, valueTy}, + TypeLevel{}, + constraint->scope, + TableState::Free + }); - pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{nextTy, tableTy}); + unify(constraint, nextTy, tableTy); auto it = begin(c.variables); auto endIt = end(c.variables); if (it != endIt) { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); + bindBlockedType(*it, keyTy, keyTy, constraint); ++it; } if (it != endIt) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); + { + bindBlockedType(*it, valueTy, valueTy, constraint); + ++it; + } + + while (it != endIt) + { + bindBlockedType(*it, builtinTypes->nilType, builtinTypes->nilType, constraint); + ++it; + } return true; } @@ -721,11 +734,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull= 2) tableTy = iterator.head[1]; - TypeId firstIndexTy = builtinTypes->nilType; - if (iterator.head.size() >= 3) - firstIndexTy = iterator.head[2]; - - return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); + return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force); } else @@ -1310,6 +1319,14 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(subjectType) || get(subjectType)) return block(subjectType, constraint); + if (const TableType* subjectTable = getTableType(subjectType)) + { + if (subjectTable->state == TableState::Unsealed && subjectTable->remainingProps > 0 && subjectTable->props.count(c.prop) == 0) + { + return block(subjectType, constraint); + } + } + auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification); if (!blocked.empty()) { @@ -1517,7 +1534,10 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull seen{nullptr}; - return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); + if (ok) + unblock(c.resultType, constraint->location); + return ok; } std::pair> ConstraintSolver::tryDispatchSetIndexer( @@ -1596,46 +1616,6 @@ std::pair> ConstraintSolver::tryDispatchSetIndexer( return {true, std::nullopt}; } -bool ConstraintSolver::tryDispatch(const AssignConstraint& c, NotNull constraint) -{ - const TypeId lhsTy = follow(c.lhsType); - const TypeId rhsTy = follow(c.rhsType); - - if (!get(lhsTy) && isBlocked(lhsTy)) - return block(lhsTy, constraint); - - auto tryExpand = [&](TypeId ty) { - LocalType* lt = getMutable(ty); - if (!lt) - return; - - lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, rhsTy).result; - LUAU_ASSERT(lt->blockCount > 0); - --lt->blockCount; - - if (0 == lt->blockCount) - { - shiftReferences(ty, lt->domain); - emplaceType(asMutable(ty), lt->domain); - } - }; - - if (auto ut = get(lhsTy)) - { - // FIXME: I suspect there's a bug here where lhsTy is a union that contains no LocalTypes. - for (TypeId t : ut) - tryExpand(t); - } - else if (get(lhsTy)) - tryExpand(lhsTy); - else - unify(constraint, rhsTy, lhsTy); - - unblock(lhsTy, constraint->location); - - return true; -} - bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) { TypeId lhsType = follow(c.lhsType); @@ -1753,6 +1733,14 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(asMutable(c.propType), rhsType); lhsTable->props[propName] = Property::rw(rhsType); + + if (lhsTable->state == TableState::Unsealed && c.decrementPropCount) + { + LUAU_ASSERT(lhsTable->remainingProps > 0); + lhsTable->remainingProps -= 1; + unblock(lhsType, constraint->location); + } + return true; } } @@ -1927,24 +1915,14 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) { TypePackId sourcePack = follow(c.sourcePack); - TypePackId resultPack = follow(c.resultPack); if (isBlocked(sourcePack)) return block(sourcePack, constraint); - if (isBlocked(resultPack)) - { - LUAU_ASSERT(canMutate(resultPack, constraint)); - LUAU_ASSERT(resultPack != sourcePack); - emplaceTypePack(asMutable(resultPack), sourcePack); - unblock(resultPack, constraint->location); - return true; - } - - TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); + TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, c.resultPack.size()); - auto resultIter = begin(resultPack); - auto resultEnd = end(resultPack); + auto resultIter = begin(c.resultPack); + auto resultEnd = end(c.resultPack); size_t i = 0; while (resultIter != resultEnd) @@ -2080,18 +2058,22 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl auto endIt = end(c.variables); if (it != endIt) { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); + bindBlockedType(*it, keyTy, keyTy, constraint); ++it; } if (it != endIt) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); + bindBlockedType(*it, valueTy, valueTy, constraint); return true; } auto unpack = [&](TypeId ty) { - for (TypeId varTy : c.variables) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, ty}); + for (TypeId varTy: c.variables) + { + LUAU_ASSERT(get(varTy)); + LUAU_ASSERT(varTy != ty); + bindBlockedType(varTy, ty, ty, constraint); + } }; if (get(iteratorTy)) @@ -2129,27 +2111,18 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (iteratorTable->indexer) { - TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); - unify(constraint, c.variables, expectedVariablePack); - - auto [variableTys, variablesTail] = flatten(c.variables); + std::vector expectedVariables{iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}; + while (c.variables.size() >= expectedVariables.size()) + expectedVariables.push_back(builtinTypes->errorRecoveryType()); - // the local types for the indexer _should_ be all set after unification - for (TypeId ty : variableTys) + for (size_t i = 0; i < c.variables.size(); ++i) { - if (auto lt = getMutable(ty)) - { - LUAU_ASSERT(lt->blockCount > 0); - --lt->blockCount; + LUAU_ASSERT(c.variables[i] != expectedVariables[i]); - LUAU_ASSERT(0 <= lt->blockCount); + unify(constraint, c.variables[i], expectedVariables[i]); - if (0 == lt->blockCount) - { - shiftReferences(ty, lt->domain); - emplaceType(asMutable(ty), lt->domain); - } - } + bindBlockedType(c.variables[i], expectedVariables[i], expectedVariables[i], constraint); + unblock(c.variables[i], constraint->location); } } else @@ -2213,26 +2186,16 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl else if (auto primitiveTy = get(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table) unpack(builtinTypes->unknownType); else + { unpack(builtinTypes->errorType); + } return true; } bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force) + TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force) { - // We need to know whether or not this type is nil or not. - // If we don't know, block and reschedule ourselves. - firstIndexTy = follow(firstIndexTy); - if (get(firstIndexTy)) - { - if (force) - LUAU_ASSERT(false); - else - block(firstIndexTy, constraint); - return false; - } - const FunctionType* nextFn = get(nextTy); // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. LUAU_ASSERT(nextFn); @@ -2267,27 +2230,18 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } -NotNull ConstraintSolver::unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull constraint) +NotNull ConstraintSolver::unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint) { - std::vector unpackedTys; - for (TypeId _ty : destTypes) - { - (void) _ty; - unpackedTys.push_back(arena->addType(BlockedType{})); - } + auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes}); - TypePackId unpackedTp = arena->addTypePack(TypePack{unpackedTys}); - auto unpackConstraint = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{unpackedTp, srcTypes}); - - size_t i = 0; - for (TypeId varTy : destTypes) + for (TypeId t: destTypes) { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, unpackedTys[i]}); - getMutable(unpackedTys[i])->setOwner(unpackConstraint); - ++i; + BlockedType* bt = getMutable(t); + LUAU_ASSERT(bt); + bt->replaceOwner(c); } - return unpackConstraint; + return c; } std::pair, std::optional> ConstraintSolver::lookupTableProp(NotNull constraint, TypeId subjectType, @@ -2808,9 +2762,6 @@ bool ConstraintSolver::isBlocked(TypeId ty) { ty = follow(ty); - if (auto lt = get(ty)) - return lt->blockCount > 0; - if (auto tfit = get(ty)) return uninhabitedTypeFamilies.contains(ty) == false; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 78b76a78e..91d8006a4 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,6 +2,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false); +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -319,9 +320,9 @@ declare os: { clock: () -> number, } -declare function @checked require(target: any): any +@checked declare function require(target: any): any -declare function @checked getfenv(target: any): { [string]: any } +@checked declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string @@ -363,7 +364,7 @@ declare function select(i: string | number, ...: A...): ...any -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) -declare function @checked newproxy(mt: boolean?): any +@checked declare function newproxy(mt: boolean?): any declare coroutine: { create: (f: (A...) -> R...) -> thread, @@ -451,7 +452,7 @@ std::string getBuiltinDefinitionSource() std::string result = kBuiltinDefinitionLuaSrc; // Annotates each non generic function as checked - if (FFlag::LuauCheckedEmbeddedDefinitions2) + if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauAttributeSyntax) result = kBuiltinDefinitionLuaSrcChecked; return result; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5261c2116..7823f3d4a 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1196,12 +1196,6 @@ struct InternalTypeFinder : TypeOnceVisitor return false; } - bool visit(TypeId, const LocalType&) override - { - LUAU_ASSERT(false); - return false; - } - bool visit(TypePackId, const BlockedTypePack&) override { LUAU_ASSERT(false); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 5b14fd5f2..7ce50284e 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1815,12 +1815,6 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t if (!isCacheable(there)) here.isCacheable = false; } - else if (auto lt = get(there)) - { - // FIXME? This is somewhat questionable. - // Maybe we should assert because this should never happen? - unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars); - } else if (get(there)) unionFunctionsWithFunction(here.functions, there); else if (get(there) || get(there)) @@ -3095,7 +3089,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type return NormalizationResult::True; } else if (get(there) || get(there) || get(there) || get(there) || - get(there) || get(there)) + get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -3104,10 +3098,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type here.isCacheable = false; return intersectNormals(here, thereNorm); } - else if (auto lt = get(there)) - { - return intersectNormalWithTy(here, lt->domain, seenSetTypes); - } NormalizedTyvars tyvars = std::move(here.tyvars); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index bc899798c..ea9c31781 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -24,8 +24,6 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a // We decline to copy them. if constexpr (std::is_same_v) return ty; - else if constexpr (std::is_same_v) - return ty; else if constexpr (std::is_same_v) { // This should never happen, but visit() cannot see it. diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 9093b38aa..17b595b18 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -262,14 +262,6 @@ void StateDot::visitChildren(TypeId ty, int index) visitChild(t.upperBound, index, "[upperBound]"); } } - else if constexpr (std::is_same_v) - { - formatAppend(result, "LocalType"); - finishNodeLabel(ty); - finishNode(); - - visitChild(t.domain, 1, "[domain]"); - } else if constexpr (std::is_same_v) { formatAppend(result, "AnyType %d", index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 4e81a870c..dca041a2e 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -100,16 +100,6 @@ struct FindCyclicTypes final : TypeVisitor return false; } - bool visit(TypeId ty, const LocalType& lt) override - { - if (!visited.insert(ty)) - return false; - - traverse(lt.domain); - - return false; - } - bool visit(TypeId ty, const TableType& ttv) override { if (!visited.insert(ty)) @@ -525,21 +515,6 @@ struct TypeStringifier } } - void operator()(TypeId ty, const LocalType& lt) - { - state.emit("l-"); - state.emit(lt.name); - if (FInt::DebugLuauVerboseTypeNames >= 1) - { - state.emit("["); - state.emit(lt.blockCount); - state.emit("]"); - } - state.emit("=["); - stringify(lt.domain); - state.emit("]"); - } - void operator()(TypeId, const BoundType& btv) { stringify(btv.boundTo); @@ -1724,6 +1699,18 @@ std::string generateName(size_t i) return n; } +std::string toStringVector(const std::vector& types, ToStringOptions& opts) +{ + std::string s; + for (TypeId ty : types) + { + if (!s.empty()) + s += ", "; + s += toString(ty, opts); + } + return s; +} + std::string toString(const Constraint& constraint, ToStringOptions& opts) { auto go = [&opts](auto&& c) -> std::string { @@ -1754,7 +1741,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) else if constexpr (std::is_same_v) { std::string iteratorStr = tos(c.iterator); - std::string variableStr = tos(c.variables); + std::string variableStr = toStringVector(c.variables, opts); return variableStr + " ~ iterate " + iteratorStr; } @@ -1791,14 +1778,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType); } - else if constexpr (std::is_same_v) - return "assign " + tos(c.lhsType) + " " + tos(c.rhsType); else if constexpr (std::is_same_v) return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType); else if constexpr (std::is_same_v) return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType); else if constexpr (std::is_same_v) - return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack); + return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack); else if constexpr (std::is_same_v) return "reduce " + tos(c.ty); else if constexpr (std::is_same_v) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 85b8849f7..d78bf1575 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1182,11 +1182,11 @@ std::string toString(AstNode* node) Printer printer(writer); printer.writeTypes = true; - if (auto statNode = dynamic_cast(node)) + if (auto statNode = node->asStat()) printer.visualize(*statNode); - else if (auto exprNode = dynamic_cast(node)) + else if (auto exprNode = node->asExpr()) printer.visualize(*exprNode); - else if (auto typeNode = dynamic_cast(node)) + else if (auto typeNode = node->asType()) printer.visualizeTypeAnnotation(*typeNode); return writer.str(); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index b7a54e3d9..71cac6fd9 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -561,6 +561,11 @@ void BlockedType::setOwner(Constraint* newOwner) owner = newOwner; } +void BlockedType::replaceOwner(Constraint* newOwner) +{ + owner = newOwner; +} + PendingExpansionType::PendingExpansionType( std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) : prefix(prefix) diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f1fe83eeb..c0294fc9c 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -338,10 +338,6 @@ class TypeRehydrationVisitor { return allocator->alloc(Location(), std::nullopt, AstName("free"), std::nullopt, Location()); } - AstType* operator()(const LocalType& lt) - { - return Luau::visit(*this, lt.domain->ty); - } AstType* operator()(const UnionType& uv) { AstArray unionTypes; diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 3a0483a6f..89de19126 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -447,7 +447,7 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati bool isPending(TypeId ty, ConstraintSolver* solver) { - return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); + return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); } template @@ -567,7 +567,7 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, const std::vector // check to see if the operand type is resolved enough, and wait to reduce if not // the use of `typeFromNormal` later necessitates blocking on local types. - if (isPending(operandTy, ctx->solver) || get(operandTy)) + if (isPending(operandTy, ctx->solver)) return {std::nullopt, false, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. @@ -1427,12 +1427,6 @@ struct FindRefinementBlockers : TypeOnceVisitor return false; } - bool visit(TypeId ty, const LocalType&) override - { - found.insert(ty); - return false; - } - bool visit(TypeId ty, const ClassType&) override { return false; diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index c8db5335a..f46c33729 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -158,12 +158,6 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) if (subFree || superFree) return true; - if (auto subLocal = getMutable(subTy)) - { - subLocal->domain = mkUnion(subLocal->domain, superTy); - expandedFreeTypes[subTy].push_back(superTy); - } - auto subFn = get(subTy); auto superFn = get(superTy); if (subFn && superFn) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 993116d60..e8479e09e 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -60,6 +60,8 @@ class AstStat; class AstStatBlock; class AstExpr; class AstTypePack; +class AstAttr; +class AstExprTable; struct AstLocal { @@ -172,6 +174,10 @@ class AstNode { return nullptr; } + virtual AstAttr* asAttr() + { + return nullptr; + } template bool is() const @@ -193,6 +199,28 @@ class AstNode Location location; }; +class AstAttr : public AstNode +{ +public: + LUAU_RTTI(AstAttr) + + enum Type + { + Checked, + }; + + AstAttr(const Location& location, Type type); + + AstAttr* asAttr() override + { + return this; + } + + void visit(AstVisitor* visitor) override; + + Type type; +}; + class AstExpr : public AstNode { public: @@ -384,13 +412,15 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, + AstExprFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, + const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname, + const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, const std::optional& argLocation = std::nullopt); void visit(AstVisitor* visitor) override; + AstArray attributes; AstArray generics; AstArray genericPacks; AstLocal* self; @@ -810,20 +840,22 @@ class AstStatDeclareFunction : public AstStat const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes, bool checkedFunction); + AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, + const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, + const AstArray& paramNames, const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstName name; AstArray generics; AstArray genericPacks; AstTypeList params; AstArray paramNames; AstTypeList retTypes; - bool checkedFunction; }; struct AstDeclaredClassProp @@ -936,17 +968,20 @@ class AstTypeFunction : public AstType AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes, bool checkedFunction); + AstTypeFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, + const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; - bool checkedFunction; }; class AstTypeTypeof : public AstType @@ -1105,6 +1140,11 @@ class AstVisitor return true; } + virtual bool visit(class AstAttr* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExpr* node) { return visit(static_cast(node)); diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index e111030df..f6ac28ad9 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -87,6 +87,8 @@ struct Lexeme Comment, BlockComment, + Attribute, + BrokenString, BrokenComment, BrokenUnicode, @@ -115,14 +117,20 @@ struct Lexeme ReservedTrue, ReservedUntil, ReservedWhile, - ReservedChecked, Reserved_END }; Type type; Location location; + + // Field declared here, before the union, to ensure that Lexeme size is 32 bytes. +private: + // length is used to extract a slice from the input buffer. + // This field is only valid for certain lexeme types which don't duplicate portions of input + // but instead store a pointer to a location in the input buffer and the length of lexeme. unsigned int length; +public: union { const char* data; // String, Number, Comment @@ -135,9 +143,13 @@ struct Lexeme Lexeme(const Location& location, Type type, const char* data, size_t size); Lexeme(const Location& location, Type type, const char* name); + unsigned int getLength() const; + std::string toString() const; }; +static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes."); + class AstNameTable { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e97df66b7..c1fd43ea6 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -82,8 +82,8 @@ class Parser // if exp then block {elseif exp then block} [else block] end | // for Name `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | - // function funcname funcbody | - // local function Name funcbody | + // [attributes] function funcname funcbody | + // [attributes] local function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* parseStat(); @@ -114,11 +114,25 @@ class Parser AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); // function funcname funcbody - AstStat* parseFunctionStat(); + LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray& attributes = {nullptr, 0}); + + std::pair validateAttribute(const char* attributeName, const TempVector& attributes); + + // attribute ::= '@' NAME + void parseAttribute(TempVector& attribute); + + // attributes ::= {attribute} + AstArray parseAttributes(); + + // attributes local function Name funcbody + // attributes function funcname funcbody + // attributes `declare function' Name`(' [parlist] `)' [`:` Type] + // declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' + AstStat* parseAttributeStat(); // local function Name funcbody | // local namelist [`=' explist] - AstStat* parseLocal(); + AstStat* parseLocal(const AstArray& attributes); // return [explist] AstStat* parseReturn(); @@ -130,7 +144,7 @@ class Parser // `declare global' Name: Type | // `declare function' Name`(' [parlist] `)' [`:` Type] - AstStat* parseDeclaration(const Location& start); + AstStat* parseDeclaration(const Location& start, const AstArray& attributes); // varlist `=' explist AstStat* parseAssignment(AstExpr* initial); @@ -143,7 +157,7 @@ class Parser // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray& attributes); // explist ::= {exp `,'} exp void parseExprList(TempVector& result); @@ -176,10 +190,10 @@ class Parser AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); - AstTypeOrPack parseFunctionType(bool allowPack, bool isCheckedFunction = false); - AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation, - bool isCheckedFunction = false); + AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); + AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray& attributes, AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, + AstTypePack* varargAnnotation); AstType* parseTableType(bool inDeclarationContext = false); AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false); @@ -393,6 +407,7 @@ class Parser std::vector matchRecoveryStopOnToken; + std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; std::vector scratchExpr; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index bb82e0be9..4c9563079 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -3,6 +3,7 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -16,6 +17,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) list.tailType->visit(visitor); } +AstAttr::AstAttr(const Location& location, Type type) + : AstNode(ClassIndex(), location) + , type(type) +{ +} + +void AstAttr::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + int gAstRttiIndex = 0; AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) @@ -161,11 +173,12 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, - const std::optional& argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, + AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional& returnAnnotation, + AstTypePack* varargAnnotation, const std::optional& argLocation) : AstExpr(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , self(self) @@ -696,27 +709,27 @@ AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const A const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) + , attributes() , name(name) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) , retTypes(retTypes) - , checkedFunction(false) { } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes, bool checkedFunction) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, + const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, + const AstArray& paramNames, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) + , attributes(attributes) , name(name) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) , retTypes(retTypes) - , checkedFunction(checkedFunction) { } @@ -729,6 +742,19 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor) } } +bool AstStatDeclareFunction::isCheckedFunction() const +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props, AstTableIndexer* indexer) : AstStat(ClassIndex(), location) @@ -820,25 +846,26 @@ void AstTypeTable::visit(AstVisitor* visitor) AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) + , attributes() , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) , argNames(argNames) , returnTypes(returnTypes) - , checkedFunction(false) { LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes, bool checkedFunction) +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, + const AstTypeList& returnTypes) : AstType(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) , argNames(argNames) , returnTypes(returnTypes) - , checkedFunction(checkedFunction) { LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); } @@ -852,6 +879,19 @@ void AstTypeFunction::visit(AstVisitor* visitor) } } +bool AstTypeFunction::isCheckedFunction() const +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) : AstType(ClassIndex(), location) , expr(expr) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 715774590..8e9b3be94 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntax, false) namespace Luau { @@ -102,11 +103,19 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) , length(0) , name(name) { - LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); + LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); +} + +unsigned int Lexeme::getLength() const +{ + LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); + + return length; } static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", - "repeat", "return", "then", "true", "until", "while", "@checked"}; + "repeat", "return", "then", "true", "until", "while"}; std::string Lexeme::toString() const { @@ -191,6 +200,10 @@ std::string Lexeme::toString() const case Comment: return "comment"; + case Attribute: + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + return name ? format("'%s'", name) : "attribute"; + case BrokenString: return "malformed string"; @@ -278,7 +291,7 @@ std::pair AstNameTable::getOrAddWithType(const char* name nameData[length] = 0; const_cast(entry).value = AstName(nameData); - const_cast(entry).type = Lexeme::Name; + const_cast(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name); return std::make_pair(entry.value, entry.type); } @@ -994,14 +1007,11 @@ Lexeme Lexer::readNext() } case '@': { - // We're trying to lex the token @checked - LUAU_ASSERT(peekch() == '@'); - - std::pair maybeChecked = readName(); - if (maybeChecked.second != Lexeme::ReservedChecked) - return Lexeme(Location(start, position()), Lexeme::Error); - - return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value); + if (FFlag::LuauAttributeSyntax) + { + std::pair attribute = readName(); + return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value); + } } default: if (isDigit(peekch())) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 5ca480e80..d80878d50 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -17,11 +17,20 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAG(LuauAttributeSyntax) LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false) namespace Luau { +struct AttributeEntry +{ + const char* name; + AstAttr::Type type; +}; + +AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {nullptr, AstAttr::Type::Checked}}; + ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -280,7 +289,9 @@ AstStatBlock* Parser::parseBlockNoScope() // for binding `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | // function funcname funcbody | +// attributes function funcname funcbody | // local function Name funcbody | +// local attributes function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* Parser::parseStat() @@ -299,13 +310,16 @@ AstStat* Parser::parseStat() case Lexeme::ReservedRepeat: return parseRepeat(); case Lexeme::ReservedFunction: - return parseFunctionStat(); + return parseFunctionStat(AstArray({nullptr, 0})); case Lexeme::ReservedLocal: - return parseLocal(); + return parseLocal(AstArray({nullptr, 0})); case Lexeme::ReservedReturn: return parseReturn(); case Lexeme::ReservedBreak: return parseBreak(); + case Lexeme::Attribute: + if (FFlag::LuauAttributeSyntax) + return parseAttributeStat(); default:; } @@ -343,7 +357,7 @@ AstStat* Parser::parseStat() if (options.allowDeclarationSyntax) { if (ident == "declare") - return parseDeclaration(expr->location); + return parseDeclaration(expr->location, AstArray({nullptr, 0})); } // skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop) @@ -652,7 +666,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug } // function funcname funcbody -AstStat* Parser::parseFunctionStat() +AstStat* Parser::parseFunctionStat(const AstArray& attributes) { Location start = lexer.current().location; @@ -665,16 +679,125 @@ AstStat* Parser::parseFunctionStat() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; return allocator.alloc(Location(start, body->location), expr, body); } + +std::pair Parser::validateAttribute(const char* attributeName, const TempVector& attributes) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + AstAttr::Type type; + + // check if the attribute name is valid + + bool found = false; + + for (int i = 0; kAttributeEntries[i].name; ++i) + { + found = !strcmp(attributeName, kAttributeEntries[i].name); + if (found) + { + type = kAttributeEntries[i].type; + break; + } + } + + if (!found) + { + if (strlen(attributeName) == 1) + report(lexer.current().location, "Attribute name is missing"); + else + report(lexer.current().location, "Invalid attribute '%s'", attributeName); + } + else + { + // check that attribute is not duplicated + for (const AstAttr* attr : attributes) + { + if (attr->type == type) + { + report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName); + } + } + } + + return {found, type}; +} + +// attribute ::= '@' NAME +void Parser::parseAttribute(TempVector& attributes) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute); + + Location loc = lexer.current().location; + + const char* name = lexer.current().name; + const auto [found, type] = validateAttribute(name, attributes); + + nextLexeme(); + + if (found) + attributes.push_back(allocator.alloc(loc, type)); +} + +// attributes ::= {attribute} +AstArray Parser::parseAttributes() +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + Lexeme::Type type = lexer.current().type; + + LUAU_ASSERT(type == Lexeme::Attribute); + + TempVector attributes(scratchAttr); + + while (lexer.current().type == Lexeme::Attribute) + parseAttribute(attributes); + + return copy(attributes); +} + +// attributes local function Name funcbody +// attributes function funcname funcbody +// attributes `declare function' Name`(' [parlist] `)' [`:` Type] +// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' +AstStat* Parser::parseAttributeStat() +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + AstArray attributes = Parser::parseAttributes(); + + Lexeme::Type type = lexer.current().type; + + switch (type) + { + case Lexeme::Type::ReservedFunction: + return parseFunctionStat(attributes); + case Lexeme::Type::ReservedLocal: + return parseLocal(attributes); + case Lexeme::Type::Name: + if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data)) + { + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); + return parseDeclaration(expr->location, attributes); + } + default: + return reportStatError(lexer.current().location, {}, {}, + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s intead", + lexer.current().toString().c_str()); + } +} + // local function Name funcbody | // local bindinglist [`=' explist] -AstStat* Parser::parseLocal() +AstStat* Parser::parseLocal(const AstArray& attributes) { Location start = lexer.current().location; @@ -694,7 +817,7 @@ AstStat* Parser::parseLocal() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes); matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -704,6 +827,12 @@ AstStat* Parser::parseLocal() } else { + if (FFlag::LuauAttributeSyntax && attributes.size != 0) + { + return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s intead", + lexer.current().toString().c_str()); + } + matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); @@ -831,18 +960,17 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() return AstDeclaredClassProp{fnName.name, fnType, true}; } -AstStat* Parser::parseDeclaration(const Location& start) +AstStat* Parser::parseDeclaration(const Location& start, const AstArray& attributes) { // `declare` token is already parsed at this point + + if (FFlag::LuauAttributeSyntax && (attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction)) + return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s intead", + lexer.current().toString().c_str()); + if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - bool checkedFunction = false; - if (lexer.current().type == Lexeme::ReservedChecked) - { - checkedFunction = true; - nextLexeme(); - } Name globalName = parseName("global function name"); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); @@ -880,8 +1008,8 @@ AstStat* Parser::parseDeclaration(const Location& start) if (vararg && !varargAnnotation) return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); - return allocator.alloc(Location(start, end), globalName.name, generics, genericPacks, - AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes, checkedFunction); + return allocator.alloc(Location(start, end), attributes, globalName.name, generics, genericPacks, + AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); } else if (AstName(lexer.current().name) == "class") { @@ -1035,7 +1163,7 @@ std::pair> Parser::prepareFunctionArguments(const // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' std::pair Parser::parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray& attributes) { Location start = matchFunction.location; @@ -1087,7 +1215,7 @@ std::pair Parser::parseFunctionBody( bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); body->hasEnd = hasEnd; - return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, + return {allocator.alloc(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body, functionStack.size(), debugname, typelist, varargAnnotation, argLocation), funLocal}; } @@ -1296,7 +1424,7 @@ std::pair Parser::parseReturnType() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); + AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation); return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } @@ -1435,7 +1563,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) // ReturnType ::= Type | `(' TypeList `)' // FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) +AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray& attributes) { incrementRecursionCounter("type annotation"); @@ -1483,11 +1611,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) AstArray> paramNames = copy(names); - return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation, isCheckedFunction), {}}; + return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation, bool isCheckedFunction) +AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray& attributes, AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, + AstTypePack* varargAnnotation) { incrementRecursionCounter("type annotation"); @@ -1512,7 +1641,7 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray( - Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList, isCheckedFunction); + Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList); } // Type ::= @@ -1666,7 +1795,21 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) Location start = lexer.current().location; - if (lexer.current().type == Lexeme::ReservedNil) + AstArray attributes{nullptr, 0}; + + if (lexer.current().type == Lexeme::Attribute) + { + if (!inDeclarationContext || !FFlag::LuauAttributeSyntax) + { + return {reportTypeError(start, {}, "attributes are not allowed in declaration context")}; + } + else + { + attributes = Parser::parseAttributes(); + return parseFunctionType(allowPack, attributes); + } + } + else if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); return {allocator.alloc(start, std::nullopt, nameNil, std::nullopt, start), {}}; @@ -1754,14 +1897,9 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) { return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; } - else if (inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked) - { - nextLexeme(); - return parseFunctionType(allowPack, /* isCheckedFunction */ true); - } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionType(allowPack); + return parseFunctionType(allowPack, AstArray({nullptr, 0})); } else if (lexer.current().type == Lexeme::ReservedFunction) { @@ -2259,7 +2397,7 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr, AstArray({nullptr, 0})).first; } else if (lexer.current().type == Lexeme::Number) { @@ -2689,7 +2827,7 @@ std::optional> Parser::parseCharArray() LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::InterpStringSimple); - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2734,7 +2872,7 @@ AstExpr* Parser::parseInterpString() endLocation = currentLexeme.location; - scratchData.assign(currentLexeme.data, currentLexeme.length); + scratchData.assign(currentLexeme.data, currentLexeme.getLength()); if (!Lexer::fixupQuotedString(scratchData)) { @@ -2807,7 +2945,7 @@ AstExpr* Parser::parseNumber() { Location start = lexer.current().location; - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -3162,11 +3300,11 @@ void Parser::nextLexeme() return; // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!') { const char* text = lexeme.data; - unsigned int end = lexeme.length; + unsigned int end = lexeme.getLength(); while (end > 0 && isSpace(text[end - 1])) --end; diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 171e9197a..7dd05660b 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -73,12 +73,39 @@ struct CompilationResult }; struct IrBuilder; +struct IrOp; using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); using HostVectorNamecallHandler = bool (*)( IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); +enum class HostMetamethod +{ + Add, + Sub, + Mul, + Div, + Idiv, + Mod, + Pow, + Minus, + Equal, + LessThan, + LessEqual, + Length, + Concat, +}; + +using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); +using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); +using HostUserdataAccessHandler = bool (*)( + IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostUserdataMetamethodHandler = bool (*)( + IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); +using HostUserdataNamecallHandler = bool (*)( + IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); + struct HostIrHooks { // Suggest result type of a vector field access @@ -97,6 +124,34 @@ struct HostIrHooks // All other arguments can be of any type // Guards should take a VM exit to 'pcpos' HostVectorNamecallHandler vectorNamecall = nullptr; + + // Suggest result type of a userdata field access + HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr; + + // Suggest result type of a metamethod call + HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr; + + // Suggest result type of a userdata namecall + HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr; + + // Handle userdata value field access + // 'sourceReg' is guaranteed to be a userdata, but tag has to be checked + // Write to 'resultReg' might invalidate 'sourceReg' + // Guards should take a VM exit to 'pcpos' + HostUserdataAccessHandler userdataAccess = nullptr; + + // Handle metamethod operation on a userdata value + // 'lhs' and 'rhs' operands can be VM registers of constants + // Operand types have to be checked and userdata operand tags have to be checked + // Write to 'resultReg' might invalidate source operands + // Guards should take a VM exit to 'pcpos' + HostUserdataMetamethodHandler userdataMetamethod = nullptr; + + // Handle namecall performed on a userdata value + // 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostUserdataNamecallHandler userdataNamecall = nullptr; }; struct CompilationOptions diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index b00fffab0..60af706f2 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -290,6 +290,11 @@ enum class IrCmd : uint8_t // C: block TRY_CALL_FASTGETTM, + // Create new tagged userdata + // A: int (size) + // B: int (tag) + NEW_USERDATA, + // Convert integer into a double number // A: int INT_TO_NUM, @@ -460,6 +465,13 @@ enum class IrCmd : uint8_t // When undef is specified instead of a block, execution is aborted on check failure CHECK_BUFFER_LEN, + // Guard against userdata tag mismatch + // A: pointer (userdata) + // B: int (tag) + // C: block/vmexit/undef + // When undef is specified instead of a block, execution is aborted on check failure + CHECK_USERDATA_TAG, + // Special operations // Check interrupt handler diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 55b868221..8486921e9 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -11,6 +11,7 @@ namespace CodeGen { struct IrBuilder; +enum class HostMetamethod; inline bool isJumpD(LuauOpcode op) { @@ -129,6 +130,7 @@ inline bool isNonTerminatingJump(IrCmd cmd) case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_BUFFER_LEN: + case IrCmd::CHECK_USERDATA_TAG: return true; default: break; @@ -182,6 +184,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::DUP_TABLE: case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::NEW_USERDATA: case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: case IrCmd::NUM_TO_INT: @@ -245,6 +248,8 @@ bool isGCO(uint8_t tag); bool isUserdataBytecodeType(uint8_t ty); bool isCustomUserdataBytecodeType(uint8_t ty); +HostMetamethod tmToHostMetamethod(int tm); + // Manually add or remove use of an operand void addUse(IrFunction& function, IrOp op); void removeUse(IrFunction& function, IrOp op); diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index aed8c7634..fc8eb9000 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo loa LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false) LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false) namespace Luau { @@ -546,6 +547,49 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types) } } +static HostMetamethod opcodeToHostMetamethod(LuauOpcode op) +{ + switch (op) + { + case LOP_ADD: + return HostMetamethod::Add; + case LOP_SUB: + return HostMetamethod::Sub; + case LOP_MUL: + return HostMetamethod::Mul; + case LOP_DIV: + return HostMetamethod::Div; + case LOP_IDIV: + return HostMetamethod::Idiv; + case LOP_MOD: + return HostMetamethod::Mod; + case LOP_POW: + return HostMetamethod::Pow; + case LOP_ADDK: + return HostMetamethod::Add; + case LOP_SUBK: + return HostMetamethod::Sub; + case LOP_MULK: + return HostMetamethod::Mul; + case LOP_DIVK: + return HostMetamethod::Div; + case LOP_IDIVK: + return HostMetamethod::Idiv; + case LOP_MODK: + return HostMetamethod::Mod; + case LOP_POWK: + return HostMetamethod::Pow; + case LOP_SUBRK: + return HostMetamethod::Sub; + case LOP_DIVRK: + return HostMetamethod::Div; + default: + CODEGEN_ASSERT(!"opcode is not assigned to a host metamethod"); + } + + return HostMetamethod::Add; +} + void buildBytecodeBlocks(IrFunction& function, const std::vector& jumpTargets) { Proto* proto = function.proto; @@ -760,22 +804,50 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_ANY; - if (bcType.a == LBC_TYPE_VECTOR) + if (FFlag::LuauCodegenUserdataOps) { TString* str = gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); - if (str->len == 1) + if (bcType.a == LBC_TYPE_VECTOR) { - // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" - char ch = field[0] | ' '; + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; + + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; + } - if (ch == 'x' || ch == 'y' || ch == 'z') - regTags[ra] = LBC_TYPE_NUMBER; + if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); } + else if (isCustomUserdataBytecodeType(bcType.a)) + { + if (regTags[ra] == LBC_TYPE_ANY && hostHooks.userdataAccessBytecodeType) + regTags[ra] = hostHooks.userdataAccessBytecodeType(bcType.a, field, str->len); + } + } + else + { + if (bcType.a == LBC_TYPE_VECTOR) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; - if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) - regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; + } + + if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + } } bcType.result = regTags[ra]; @@ -812,6 +884,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -841,6 +916,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -859,6 +939,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -879,6 +962,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -908,6 +994,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -926,6 +1017,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -945,6 +1039,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -972,6 +1069,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -1000,6 +1102,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && isCustomUserdataBytecodeType(bcType.a)) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, LBC_TYPE_ANY, HostMetamethod::Minus); bcType.result = regTags[ra]; break; @@ -1140,12 +1244,25 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) bcType.result = LBC_TYPE_FUNCTION; - if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + if (FFlag::LuauCodegenUserdataOps) { TString* str = gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); - knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len)); + } + else + { + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + } } } break; diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 05ac90137..06f64955c 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -258,39 +258,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde return locations; } -bool initHeaderFunctions(NativeState& data) -{ - AssemblyBuilderA64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); - - unwind.startInfo(UnwindBuilder::A64); - - EntryLocations entryLocations = buildEntryFunction(build, unwind); - - build.finalize(); - - unwind.finishInfo(); - - CODEGEN_ASSERT(build.data.empty()); - - uint8_t* codeStart = nullptr; - if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), - int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart)) - { - CODEGEN_ASSERT(!"Failed to create entry function"); - return false; - } - - // Set the offset at the begining so that functions in new blocks will not overlay the locations - // specified by the unwind information of the entry function - unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - - data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); - data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); - - return true; -} - bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) { AssemblyBuilderA64 build(/* logText= */ false); diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h index 24fedd9a4..2633f5ba1 100644 --- a/CodeGen/src/CodeGenA64.h +++ b/CodeGen/src/CodeGenA64.h @@ -7,7 +7,6 @@ namespace CodeGen { class BaseCodeGenContext; -struct NativeState; struct ModuleHelpers; namespace A64 @@ -15,7 +14,6 @@ namespace A64 class AssemblyBuilderA64; -bool initHeaderFunctions(NativeState& data); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index 7788d099d..ae9e41f1e 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -14,8 +14,8 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) -LUAU_FASTINT(LuauCodeGenBlockSize) -LUAU_FASTINT(LuauCodeGenMaxTotalSize) +LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) +LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) namespace Luau { diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 973829ca0..ad231e764 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -14,6 +14,7 @@ #include "lstate.h" #include "lstring.h" #include "ltable.h" +#include "ludata.h" #include @@ -219,6 +220,20 @@ void callEpilogC(lua_State* L, int nresults, int n) L->top = (nresults == LUA_MULTRET) ? res : cip->top; } +Udata* newUserdata(lua_State* L, size_t s, int tag) +{ + Udata* u = luaU_newudata(L, s, tag); + + if (Table* h = L->global->udatamt[tag]) + { + u->metatable = h; + + luaC_objbarrier(L, u, h); + } + + return u; +} + // Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) { diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 515a81f0a..15d4c95d2 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -17,6 +17,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); void callEpilogC(lua_State* L, int nresults, int n); +Udata* newUserdata(lua_State* L, size_t s, int tag); + #define CALL_FALLBACK_YIELD 1 Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 7f4a9e0c6..b8df37744 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -186,39 +186,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde return locations; } -bool initHeaderFunctions(NativeState& data) -{ - AssemblyBuilderX64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); - - unwind.startInfo(UnwindBuilder::X64); - - EntryLocations entryLocations = buildEntryFunction(build, unwind); - - build.finalize(); - - unwind.finishInfo(); - - CODEGEN_ASSERT(build.data.empty()); - - uint8_t* codeStart = nullptr; - if (!data.codeAllocator.allocate( - build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart)) - { - CODEGEN_ASSERT(!"Failed to create entry function"); - return false; - } - - // Set the offset at the begining so that functions in new blocks will not overlay the locations - // specified by the unwind information of the entry function - unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - - data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); - data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); - - return true; -} - bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) { AssemblyBuilderX64 build(/* logText= */ false); diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index eb6ab81c4..ce360b230 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -7,7 +7,6 @@ namespace CodeGen { class BaseCodeGenContext; -struct NativeState; struct ModuleHelpers; namespace X64 @@ -15,7 +14,6 @@ namespace X64 class AssemblyBuilderX64; -bool initHeaderFunctions(NativeState& data); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 894570d93..d61fd2a73 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -22,8 +22,6 @@ namespace Luau namespace CodeGen { -struct NativeState; - namespace A64 { diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index c29479e15..f88944e55 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -26,7 +26,6 @@ namespace CodeGen { enum class IrCondition : uint8_t; -struct NativeState; struct IrOp; namespace X64 diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index c47a0b8fe..a82ee894d 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -199,6 +199,8 @@ const char* getCmdName(IrCmd cmd) return "TRY_NUM_TO_INDEX"; case IrCmd::TRY_CALL_FASTGETTM: return "TRY_CALL_FASTGETTM"; + case IrCmd::NEW_USERDATA: + return "NEW_USERDATA"; case IrCmd::INT_TO_NUM: return "INT_TO_NUM"; case IrCmd::UINT_TO_NUM: @@ -257,6 +259,8 @@ const char* getCmdName(IrCmd cmd) return "CHECK_NODE_VALUE"; case IrCmd::CHECK_BUFFER_LEN: return "CHECK_BUFFER_LEN"; + case IrCmd::CHECK_USERDATA_TAG: + return "CHECK_USERDATA_TAG"; case IrCmd::INTERRUPT: return "INTERRUPT"; case IrCmd::CHECK_GC: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index c8cc07f4f..ea83bb993 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -13,6 +13,9 @@ LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenSplitDoarith) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false) namespace Luau { @@ -1083,6 +1086,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) inst.regA64 = regs.takeReg(x0, index); break; } + case IrCmd::NEW_USERDATA: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + regs.spill(build, index); + build.mov(x0, rState); + build.mov(x1, intOp(inst.a)); + build.mov(x2, intOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, newUserdata))); + build.blr(x3); + inst.regA64 = regs.takeReg(x0, index); + break; + } case IrCmd::INT_TO_NUM: { inst.regA64 = regs.allocReg(KindA64::d, index); @@ -1677,6 +1693,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) finalizeTargetLabel(inst.d, fresh); break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + Label fresh; // used when guard aborts execution or jumps to a VM exit + Label& fail = getTargetLabel(inst.c, fresh); + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldrb(temp, mem(regOp(inst.a), offsetof(Udata, tag))); + + if (FFlag::LuauCodegenUserdataOpsFixA64) + build.cmp(temp, intOp(inst.b)); + else + build.cmp(temp, tagOp(inst.b)); + + build.b(ConditionA64::NotEqual, fail); + finalizeTargetLabel(inst.c, fresh); + break; + } case IrCmd::INTERRUPT: { regs.spill(build, index); @@ -2308,7 +2342,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI8: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrsb(inst.regA64, addr); break; @@ -2317,7 +2351,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READU8: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrb(inst.regA64, addr); break; @@ -2326,7 +2360,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI8: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.strb(temp, addr); break; @@ -2335,7 +2369,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI16: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrsh(inst.regA64, addr); break; @@ -2344,7 +2378,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READU16: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrh(inst.regA64, addr); break; @@ -2353,7 +2387,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI16: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.strh(temp, addr); break; @@ -2362,7 +2396,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI32: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(inst.regA64, addr); break; @@ -2371,7 +2405,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI32: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.str(temp, addr); break; @@ -2381,7 +2415,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regA64 = regs.allocReg(KindA64::d, index); RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(temp, addr); build.fcvt(inst.regA64, temp); @@ -2392,7 +2426,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { RegisterA64 temp1 = tempDouble(inst.c); RegisterA64 temp2 = regs.allocTemp(KindA64::s); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.fcvt(temp2, temp1); build.str(temp2, addr); @@ -2402,7 +2436,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READF64: { inst.regA64 = regs.allocReg(KindA64::d, index); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(inst.regA64, addr); break; @@ -2411,7 +2445,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEF64: { RegisterA64 temp = tempDouble(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.str(temp, addr); break; @@ -2639,32 +2673,68 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset) } } -AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp) +AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag) { - if (indexOp.kind == IrOpKind::Inst) - { - RegisterA64 temp = regs.allocTemp(KindA64::x); - build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw - return mem(temp, offsetof(Buffer, data)); - } - else if (indexOp.kind == IrOpKind::Constant) + if (FFlag::LuauCodegenUserdataOps) { - // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled encoding - if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255) - return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data))); + CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER); + int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data); - // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset - if (intOp(indexOp) < 0) - return mem(regOp(bufferOp), offsetof(Buffer, data)); + if (indexOp.kind == IrOpKind::Inst) + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw + return mem(temp, dataOffset); + } + else if (indexOp.kind == IrOpKind::Constant) + { + // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled + // encoding + if (unsigned(intOp(indexOp)) + dataOffset <= 255) + return mem(regOp(bufferOp), int(intOp(indexOp) + dataOffset)); - RegisterA64 temp = regs.allocTemp(KindA64::x); - emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); - return mem(temp, offsetof(Buffer, data)); + // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset + if (intOp(indexOp) < 0) + return mem(regOp(bufferOp), dataOffset); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); + return mem(temp, dataOffset); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + return noreg; + } } else { - CODEGEN_ASSERT(!"Unsupported instruction form"); - return noreg; + if (indexOp.kind == IrOpKind::Inst) + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw + return mem(temp, offsetof(Buffer, data)); + } + else if (indexOp.kind == IrOpKind::Constant) + { + // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled + // encoding + if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255) + return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data))); + + // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset + if (intOp(indexOp) < 0) + return mem(regOp(bufferOp), offsetof(Buffer, data)); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); + return mem(temp, offsetof(Buffer, data)); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + return noreg; + } } } diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 5fb7f2b8a..5f13f58e4 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -44,7 +44,7 @@ struct IrLoweringA64 RegisterA64 tempInt(IrOp op); RegisterA64 tempUint(IrOp op); AddressA64 tempAddr(IrOp op, int offset); - AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp); + AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag); // May emit restore instructions RegisterA64 regOp(IrOp op); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 66609cb75..00768c70f 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -15,6 +15,9 @@ #include "lstate.h" #include "lgc.h" +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) + namespace Luau { namespace CodeGen @@ -905,6 +908,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) inst.regX64 = regs.takeReg(rax, index); break; } + case IrCmd::NEW_USERDATA: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, intOp(inst.a)); + callWrap.addArgument(SizeX64::dword, intOp(inst.b)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, newUserdata)]); + inst.regX64 = regs.takeReg(rax, index); + break; + } case IrCmd::INT_TO_NUM: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); @@ -1350,6 +1365,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + build.cmp(byte[regOp(inst.a) + offsetof(Udata, tag)], intOp(inst.b)); + jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next); + break; + } case IrCmd::INTERRUPT: { unsigned pcpos = uintOp(inst.a); @@ -1895,71 +1918,71 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI8: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]); + build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_READU8: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]); + build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI8: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c))); - build.mov(byte[bufferAddrOp(inst.a, inst.b)], value); + build.mov(byte[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READI16: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); + build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_READU16: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); + build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI16: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c))); - build.mov(word[bufferAddrOp(inst.a, inst.b)], value); + build.mov(word[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READI32: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); + build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI32: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c)); - build.mov(dword[bufferAddrOp(inst.a, inst.b)], value); + build.mov(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READF32: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); + build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEF32: - storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c); + storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], inst.c); break; case IrCmd::BUFFER_READF64: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]); + build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEF64: @@ -1967,11 +1990,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { ScopedRegX64 tmp{regs, SizeX64::xmmword}; build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c))); - build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg); + build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], tmp.reg); } else if (inst.c.kind == IrOpKind::Inst) { - build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c)); + build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], regOp(inst.c)); } else { @@ -2190,12 +2213,25 @@ RegisterX64 IrLoweringX64::regOp(IrOp op) return inst.regX64; } -OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp) +OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag) { - if (indexOp.kind == IrOpKind::Inst) - return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); - else if (indexOp.kind == IrOpKind::Constant) - return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); + if (FFlag::LuauCodegenUserdataOps) + { + CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER); + int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data); + + if (indexOp.kind == IrOpKind::Inst) + return regOp(bufferOp) + qwordReg(regOp(indexOp)) + dataOffset; + else if (indexOp.kind == IrOpKind::Constant) + return regOp(bufferOp) + intOp(indexOp) + dataOffset; + } + else + { + if (indexOp.kind == IrOpKind::Inst) + return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); + else if (indexOp.kind == IrOpKind::Constant) + return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); + } CODEGEN_ASSERT(!"Unsupported instruction form"); return noreg; diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index 5fb7b0fab..8fb311ea5 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -50,7 +50,7 @@ struct IrLoweringX64 OperandX64 memRegUintOp(IrOp op); OperandX64 memRegTagOp(IrOp op); RegisterX64 regOp(IrOp op); - OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp); + OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag); RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp); IrConst constOp(IrOp op) const; diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 93073a923..5798f3e95 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) +LUAU_FASTFLAG(LuauCodegenUserdataOps) namespace Luau { @@ -444,6 +445,17 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, return; } + if (FFlag::LuauCodegenUserdataOps && (isUserdataBytecodeType(bcTypes.a) || isUserdataBytecodeType(bcTypes.b))) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, opb, opc, tmToHostMetamethod(tm), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm)); + return; + } + IrOp fallback; // fast-path: number @@ -585,6 +597,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos) return; } + if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a)) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_UNM), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM)); + return; + } + IrOp fallback; IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); @@ -606,8 +629,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos) FallbackStreamScope scope(build, fallback, next); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst( - IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM)); + + if (FFlag::LuauCodegenUserdataOps) + { + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM)); + } + else + { + build.inst( + IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM)); + } + build.inst(IrCmd::JUMP, next); } } @@ -619,6 +651,17 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos) int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); + if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a)) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_LEN), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb)); + return; + } + IrOp fallback = build.block(IrBlockKind::Fallback); IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); @@ -638,7 +681,12 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos) FallbackStreamScope scope(build, fallback, next); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + + if (FFlag::LuauCodegenUserdataOps) + build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb)); + else + build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + build.inst(IrCmd::JUMP, next); } @@ -1229,10 +1277,19 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) return; } - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); + if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataAccess) + { + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.userdataAccess(build, bcTypes.a, field, str->len, ra, rb, pcpos)) + return; + } + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); return; } @@ -1267,7 +1324,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); @@ -1413,10 +1470,26 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) return false; } - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos)); + if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataNamecall) + { + Instruction call = pc[2]; + CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + int callra = LUAU_INSN_A(call); + int nparams = LUAU_INSN_B(call) - 1; + int nresults = LUAU_INSN_C(call) - 1; + + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.userdataNamecall(build, bcTypes.a, field, str->len, callra, rb, nparams, nresults, pcpos)) + return true; + } + build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); return false; } diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index afc6ba5ae..d1bfca45a 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -99,6 +99,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::TRY_NUM_TO_INDEX: return IrValueKind::Int; case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::NEW_USERDATA: return IrValueKind::Pointer; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: @@ -135,6 +136,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_BUFFER_LEN: + case IrCmd::CHECK_USERDATA_TAG: case IrCmd::INTERRUPT: case IrCmd::CHECK_GC: case IrCmd::BARRIER_OBJ: @@ -262,6 +264,44 @@ bool isCustomUserdataBytecodeType(uint8_t ty) return ty >= LBC_TYPE_TAGGED_USERDATA_BASE && ty < LBC_TYPE_TAGGED_USERDATA_END; } +HostMetamethod tmToHostMetamethod(int tm) +{ + switch (TMS(tm)) + { + case TM_ADD: + return HostMetamethod::Add; + case TM_SUB: + return HostMetamethod::Sub; + case TM_MUL: + return HostMetamethod::Mul; + case TM_DIV: + return HostMetamethod::Div; + case TM_IDIV: + return HostMetamethod::Idiv; + case TM_MOD: + return HostMetamethod::Mod; + case TM_POW: + return HostMetamethod::Pow; + case TM_UNM: + return HostMetamethod::Minus; + case TM_EQ: + return HostMetamethod::Equal; + case TM_LT: + return HostMetamethod::LessThan; + case TM_LE: + return HostMetamethod::LessEqual; + case TM_LEN: + return HostMetamethod::Length; + case TM_CONCAT: + return HostMetamethod::Concat; + default: + CODEGEN_ASSERT(!"invalid tag method for host"); + break; + } + + return HostMetamethod::Add; +} + void kill(IrFunction& function, IrInst& inst) { CODEGEN_ASSERT(inst.useCount == 0); diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index b3d07491a..248f0cd31 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -14,114 +14,13 @@ #include #include -LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) -LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) namespace Luau { namespace CodeGen { -NativeState::NativeState() - : NativeState(nullptr, nullptr) -{ -} - -NativeState::NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext) - : codeAllocator{size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext} -{ -} - -NativeState::~NativeState() = default; - -void initFunctions(NativeState& data) -{ - static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); - memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table)); - - data.context.luaV_lessthan = luaV_lessthan; - data.context.luaV_lessequal = luaV_lessequal; - data.context.luaV_equalval = luaV_equalval; - data.context.luaV_doarith = luaV_doarith; - - data.context.luaV_doarithadd = luaV_doarithimpl; - data.context.luaV_doarithsub = luaV_doarithimpl; - data.context.luaV_doarithmul = luaV_doarithimpl; - data.context.luaV_doarithdiv = luaV_doarithimpl; - data.context.luaV_doarithidiv = luaV_doarithimpl; - data.context.luaV_doarithmod = luaV_doarithimpl; - data.context.luaV_doarithpow = luaV_doarithimpl; - data.context.luaV_doarithunm = luaV_doarithimpl; - - data.context.luaV_dolen = luaV_dolen; - data.context.luaV_gettable = luaV_gettable; - data.context.luaV_settable = luaV_settable; - data.context.luaV_getimport = luaV_getimport; - data.context.luaV_concat = luaV_concat; - - data.context.luaH_getn = luaH_getn; - data.context.luaH_new = luaH_new; - data.context.luaH_clone = luaH_clone; - data.context.luaH_resizearray = luaH_resizearray; - data.context.luaH_setnum = luaH_setnum; - - data.context.luaC_barriertable = luaC_barriertable; - data.context.luaC_barrierf = luaC_barrierf; - data.context.luaC_barrierback = luaC_barrierback; - data.context.luaC_step = luaC_step; - - data.context.luaF_close = luaF_close; - data.context.luaF_findupval = luaF_findupval; - data.context.luaF_newLclosure = luaF_newLclosure; - - data.context.luaT_gettm = luaT_gettm; - data.context.luaT_objtypenamestr = luaT_objtypenamestr; - - data.context.libm_exp = exp; - data.context.libm_pow = pow; - data.context.libm_fmod = fmod; - data.context.libm_log = log; - data.context.libm_log2 = log2; - data.context.libm_log10 = log10; - data.context.libm_ldexp = ldexp; - data.context.libm_round = round; - data.context.libm_frexp = frexp; - data.context.libm_modf = modf; - - data.context.libm_asin = asin; - data.context.libm_sin = sin; - data.context.libm_sinh = sinh; - data.context.libm_acos = acos; - data.context.libm_cos = cos; - data.context.libm_cosh = cosh; - data.context.libm_atan = atan; - data.context.libm_atan2 = atan2; - data.context.libm_tan = tan; - data.context.libm_tanh = tanh; - - data.context.forgLoopTableIter = forgLoopTableIter; - data.context.forgLoopNodeIter = forgLoopNodeIter; - data.context.forgLoopNonTableFallback = forgLoopNonTableFallback; - data.context.forgPrepXnextFallback = forgPrepXnextFallback; - data.context.callProlog = callProlog; - data.context.callEpilogC = callEpilogC; - - data.context.callFallback = callFallback; - - data.context.executeGETGLOBAL = executeGETGLOBAL; - data.context.executeSETGLOBAL = executeSETGLOBAL; - data.context.executeGETTABLEKS = executeGETTABLEKS; - data.context.executeSETTABLEKS = executeSETTABLEKS; - - data.context.executeNAMECALL = executeNAMECALL; - data.context.executeFORGPREP = executeFORGPREP; - data.context.executeGETVARARGSMultRet = executeGETVARARGSMultRet; - data.context.executeGETVARARGSConst = executeGETVARARGSConst; - data.context.executeDUPCLOSURE = executeDUPCLOSURE; - data.context.executePREPVARARGS = executePREPVARARGS; - data.context.executeSETLIST = executeSETLIST; -} - void initFunctions(NativeContext& context) { static_assert(sizeof(context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); @@ -194,6 +93,9 @@ void initFunctions(NativeContext& context) context.callProlog = callProlog; context.callEpilogC = callEpilogC; + if (FFlag::LuauCodegenUserdataAlloc) + context.newUserdata = newUserdata; + context.callFallback = callFallback; context.executeGETGLOBAL = executeGETGLOBAL; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 2edfc2701..be73815d1 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -94,6 +94,7 @@ struct NativeContext void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; + Udata* (*newUserdata)(lua_State* L, size_t s, int tag) = nullptr; Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr; @@ -116,22 +117,6 @@ struct NativeContext using GateFn = int (*)(lua_State*, Proto*, uintptr_t, NativeContext*); -struct NativeState -{ - NativeState(); - NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext); - ~NativeState(); - - CodeAllocator codeAllocator; - std::unique_ptr unwindBuilder; - - uint8_t* gateData = nullptr; - size_t gateDataSize = 0; - - NativeContext context; -}; - -void initFunctions(NativeState& data); void initFunctions(NativeContext& context); } // namespace CodeGen diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 9135a9ede..4ff49570b 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -16,9 +16,12 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) +LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) namespace Luau { @@ -200,6 +203,11 @@ struct ConstPropState checkBufferLenCache.clear(); } + void invalidateUserdataData() + { + useradataTagCache.clear(); + } + void invalidateHeap() { for (int i = 0; i <= maxReg; ++i) @@ -417,6 +425,9 @@ struct ConstPropState invalidateValuePropagation(); invalidateHeapTableData(); invalidateHeapBufferData(); + + if (FFlag::LuauCodegenUserdataOps) + invalidateUserdataData(); } IrFunction& function; @@ -446,6 +457,9 @@ struct ConstPropState std::vector checkArraySizeCache; // Additionally, fallback block argument might be different std::vector checkBufferLenCache; // Additionally, fallback block argument might be different + + // Userdata tag cache can point to both NEW_USERDATA and CHECK_USERDATA_TAG instructions + std::vector useradataTagCache; // Additionally, fallback block argument might be different }; static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -1061,6 +1075,37 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.checkBufferLenCache.push_back(index); break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + for (uint32_t prevIdx : state.useradataTagCache) + { + IrInst& prev = function.instructions[prevIdx]; + + if (prev.cmd == IrCmd::CHECK_USERDATA_TAG) + { + if (prev.a != inst.a || prev.b != inst.b) + continue; + } + else if (FFlag::LuauCodegenUserdataAlloc && prev.cmd == IrCmd::NEW_USERDATA) + { + if (inst.a.kind != IrOpKind::Inst || prevIdx != inst.a.index || prev.b != inst.b) + continue; + } + + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.c, build.undef()); + else + kill(function, inst); + + return; // Break out from both the loop and the switch + } + + if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit) + state.useradataTagCache.push_back(index); + break; + } case IrCmd::BUFFER_READI8: case IrCmd::BUFFER_READU8: case IrCmd::BUFFER_WRITEI8: @@ -1228,6 +1273,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::TRY_CALL_FASTGETTM: break; + case IrCmd::NEW_USERDATA: + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit) + state.useradataTagCache.push_back(index); + break; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: state.substituteOrRecord(inst, index); @@ -1512,6 +1563,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite state.invalidateHeapTableData(); state.invalidateHeapBufferData(); + if (FFlag::LuauCodegenUserdataOps) + state.invalidateUserdataData(); + // Blocks in a chain are guaranteed to follow each other // We force that by giving all blocks the same sorting key, but consecutive chain keys block->sortkey = startSortkey; diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index 6c1d6affa..d18b75c5b 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -10,6 +10,7 @@ #include "lobject.h" LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false) +LUAU_FASTFLAG(LuauCodegenUserdataOps) // TODO: optimization can be improved by knowing which registers are live in at each VM exit @@ -595,6 +596,11 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, case IrCmd::CHECK_BUFFER_LEN: state.checkLiveIns(inst.d); break; + case IrCmd::CHECK_USERDATA_TAG: + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + state.checkLiveIns(inst.c); + break; case IrCmd::JUMP: // Ideally, we would be able to remove stores to registers that are not live out from a block diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 19526fa91..4842b9a1c 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -4219,7 +4219,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c for (AstExprFunction* expr : functions) compiler.compileFunction(expr, 0); - AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), + AstExprFunction main(root->location, /*attributes=*/AstArray({nullptr, 0}), /*generics= */ AstArray(), + /*genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main, mainFlags); diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index 693e0f870..5fba9fa30 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -195,7 +195,7 @@ static Error parseJson(const std::string& contents, Action action) } else if (lexer.current().type == Lexeme::QuotedString) { - std::string value(lexer.current().data, lexer.current().length); + std::string value(lexer.current().data, lexer.current().getLength()); next(lexer); if (Error err = action(keys, value)) @@ -232,7 +232,7 @@ static Error parseJson(const std::string& contents, Action action) } else if (lexer.current().type == Lexeme::QuotedString) { - std::string key(lexer.current().data, lexer.current().length); + std::string key(lexer.current().data, lexer.current().getLength()); next(lexer); keys.push_back(key); @@ -250,7 +250,7 @@ static Error parseJson(const std::string& contents, Action action) lexer.current().type == Lexeme::ReservedFalse) { std::string value = lexer.current().type == Lexeme::QuotedString - ? std::string(lexer.current().data, lexer.current().length) + ? std::string(lexer.current().data, lexer.current().getLength()) : (lexer.current().type == Lexeme::ReservedTrue ? "true" : "false"); next(lexer); diff --git a/VM/include/lua.h b/VM/include/lua.h index 4876b933f..4ee9306ed 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -324,6 +324,10 @@ typedef void (*lua_Destructor)(lua_State* L, void* userdata); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor); LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag); +// alternative access for metatables already registered with luaL_newmetatable +LUA_API void lua_setuserdatametatable(lua_State* L, int tag, int idx); +LUA_API void lua_getuserdatametatable(lua_State* L, int tag); + LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name); LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 58c767f16..87f85af85 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1427,6 +1427,33 @@ lua_Destructor lua_getuserdatadtor(lua_State* L, int tag) return L->global->udatagc[tag]; } +void lua_setuserdatametatable(lua_State* L, int tag, int idx) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + api_check(L, !L->global->udatamt[tag]); // reassignment not supported + StkId o = index2addr(L, idx); + api_check(L, ttistable(o)); + L->global->udatamt[tag] = hvalue(o); + L->top--; +} + +void lua_getuserdatametatable(lua_State* L, int tag) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + luaC_threadbarrier(L); + + if (Table* h = L->global->udatamt[tag]) + { + sethvalue(L, L->top, h); + } + else + { + setnilvalue(L->top); + } + + api_incr_top(L); +} + void lua_setlightuserdataname(lua_State* L, int tag, const char* name) { api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index dbc1dd10b..6b7a9aa0d 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -210,7 +210,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) for (i = 0; i < LUA_T_COUNT; i++) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) + { g->udatagc[i] = NULL; + g->udatamt[i] = NULL; + } for (i = 0; i < LUA_LUTAG_LIMIT; i++) g->lightuserdataname[i] = NULL; for (i = 0; i < LUA_MEMORY_CATEGORIES; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 35e66471a..f8caa69bf 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -217,6 +217,7 @@ typedef struct global_State lua_ExecutionCallbacks ecb; void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + Table* udatamt[LUA_LUTAG_LIMIT]; // metatables for tagged userdata TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 7ced52cfa..9c1fca9e1 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -342,6 +342,209 @@ void setupVectorHelpers(lua_State* L) lua_pop(L, 1); } +Vec2* lua_vec2_push(lua_State* L) +{ + Vec2* data = (Vec2*)lua_newuserdatatagged(L, sizeof(Vec2), kTagVec2); + + lua_getuserdatametatable(L, kTagVec2); + lua_setmetatable(L, -2); + + return data; +} + +Vec2* lua_vec2_get(lua_State* L, int idx) +{ + Vec2* a = (Vec2*)lua_touserdatatagged(L, idx, kTagVec2); + + if (a) + return a; + + luaL_typeerror(L, idx, "vec2"); +} + +static int lua_vec2(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double y = luaL_checknumber(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = float(x); + data->y = float(y); + + return 1; +} + +static int lua_vec2_dot(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + lua_pushnumber(L, a->x * b->x + a->y * b->y); + return 1; +} + +static int lua_vec2_min(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = a->x < b->x ? a->x : b->x; + data->y = a->y < b->y ? a->y : b->y; + + return 1; +} + +static int lua_vec2_index(lua_State* L) +{ + Vec2* v = lua_vec2_get(L, 1); + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "X") == 0) + { + lua_pushnumber(L, v->x); + return 1; + } + + if (strcmp(name, "Y") == 0) + { + lua_pushnumber(L, v->y); + return 1; + } + + if (strcmp(name, "Magnitude") == 0) + { + lua_pushnumber(L, sqrtf(v->x * v->x + v->y * v->y)); + return 1; + } + + if (strcmp(name, "Unit") == 0) + { + float invSqrt = 1.0f / sqrtf(v->x * v->x + v->y * v->y); + + Vec2* data = lua_vec2_push(L); + + data->x = v->x * invSqrt; + data->y = v->y * invSqrt; + return 1; + } + + luaL_error(L, "%s is not a valid member of vector", name); +} + +static int lua_vec2_namecall(lua_State* L) +{ + if (const char* str = lua_namecallatom(L, nullptr)) + { + if (strcmp(str, "Dot") == 0) + return lua_vec2_dot(L); + + if (strcmp(str, "Min") == 0) + return lua_vec2_min(L); + } + + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); +} + +void setupUserdataHelpers(lua_State* L) +{ + // create metatable with all the metamethods + luaL_newmetatable(L, "vec2"); + luaL_getmetatable(L, "vec2"); + lua_pushvalue(L, -1); + lua_setuserdatametatable(L, kTagVec2, -1); + + lua_pushcfunction(L, lua_vec2_index, nullptr); + lua_setfield(L, -2, "__index"); + + lua_pushcfunction(L, lua_vec2_namecall, nullptr); + lua_setfield(L, -2, "__namecall"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x + b->x; + data->y = a->y + b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__add"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x - b->x; + data->y = a->y - b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__sub"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x * b->x; + data->y = a->y * b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__mul"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x / b->x; + data->y = a->y / b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__div"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* data = lua_vec2_push(L); + + data->x = -a->x; + data->y = -a->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__unm"); + + lua_setreadonly(L, -1, true); + + // ctor + lua_pushcfunction(L, lua_vec2, "vec2"); + lua_setglobal(L, "vec2"); + + lua_pop(L, 1); +} + static void setupNativeHelpers(lua_State* L) { lua_pushcclosurek( @@ -1828,16 +2031,36 @@ TEST_CASE("UserdataApi") luaL_newmetatable(L, "udata2"); void* ud5 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata1"); + luaL_getmetatable(L, "udata1"); lua_setmetatable(L, -2); void* ud6 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata2"); + luaL_getmetatable(L, "udata2"); lua_setmetatable(L, -2); CHECK(luaL_checkudata(L, -2, "udata1") == ud5); CHECK(luaL_checkudata(L, -1, "udata2") == ud6); + // tagged user data with fast metatable access + luaL_newmetatable(L, "udata3"); + luaL_getmetatable(L, "udata3"); + lua_setuserdatametatable(L, 50, -1); + + luaL_newmetatable(L, "udata4"); + luaL_getmetatable(L, "udata4"); + lua_setuserdatametatable(L, 51, -1); + + void* ud7 = lua_newuserdatatagged(L, 16, 50); + lua_getuserdatametatable(L, 50); + lua_setmetatable(L, -2); + + void* ud8 = lua_newuserdatatagged(L, 16, 51); + lua_getuserdatametatable(L, 51); + lua_setmetatable(L, -2); + + CHECK(luaL_checkudata(L, -2, "udata3") == ud7); + CHECK(luaL_checkudata(L, -1, "udata4") == ud8); + globalState.reset(); CHECK(dtorhits == 42); @@ -1911,7 +2134,6 @@ TEST_CASE("Iter") } const int kInt64Tag = 1; -static int gInt64MT = -1; static int64_t getInt64(lua_State* L, int idx) { @@ -1928,7 +2150,7 @@ static void pushInt64(lua_State* L, int64_t value) { void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag); - lua_getref(L, gInt64MT); + luaL_getmetatable(L, "int64"); lua_setmetatable(L, -2); *static_cast(p) = value; @@ -1938,8 +2160,7 @@ TEST_CASE("Userdata") { runConformance("userdata.lua", [](lua_State* L) { // create metatable with all the metamethods - lua_newtable(L); - gInt64MT = lua_ref(L, -1); + luaL_newmetatable(L, "int64"); // __index lua_pushcfunction( @@ -2164,6 +2385,86 @@ TEST_CASE("NativeTypeAnnotations") }); } +TEST_CASE("NativeUserdata") +{ + lua_CompileOptions copts = defaultOptions(); + Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + SUBCASE("NoIrHooks") + { + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + SUBCASE("IrHooks") + { + nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + nativeOpts.hooks.vectorAccess = vectorAccess; + nativeOpts.hooks.vectorNamecall = vectorNamecall; + + nativeOpts.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + nativeOpts.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + nativeOpts.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + nativeOpts.hooks.userdataAccess = userdataAccess; + nativeOpts.hooks.userdataMetamethod = userdataMetamethod; + nativeOpts.hooks.userdataNamecall = userdataNamecall; + + nativeOpts.userdataTypes = kUserdataRunTypes; + + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + + runConformance( + "native_userdata.lua", + [](lua_State* L) { + Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, [](void* context, const char* str, size_t len) -> uint8_t { + const char** types = (const char**)context; + + uint8_t index = 0; + + std::string_view sv{str, len}; + + for (; *types; ++types) + { + if (sv == *types) + return index; + + index++; + } + + return 0xff; + }); + + setupVectorHelpers(L); + setupUserdataHelpers(L); + }, + nullptr, nullptr, &copts, false, &nativeOpts); +} + [[nodiscard]] static std::string makeHugeFunctionSource() { std::string source; diff --git a/tests/ConformanceIrHooks.h b/tests/ConformanceIrHooks.h index d40508633..ab5b86d45 100644 --- a/tests/ConformanceIrHooks.h +++ b/tests/ConformanceIrHooks.h @@ -5,14 +5,44 @@ static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr}; +constexpr uint8_t kUserdataExtra = 0; +constexpr uint8_t kUserdataColor = 1; +constexpr uint8_t kUserdataVec2 = 2; +constexpr uint8_t kUserdataMat3 = 3; + +// Userdata tags can be different from userdata bytecode type indices +constexpr uint8_t kTagVec2 = 12; + +struct Vec2 +{ + float x; + float y; +}; + +inline bool compareMemberName(const char* member, size_t memberLength, const char* str) +{ + return memberLength == strlen(str) && strcmp(member, str) == 0; +} + +inline uint8_t typeToUserdataIndex(uint8_t type) +{ + // Underflow will push the type into a value that is not comparable to any kUserdata* constants + return type - LBC_TYPE_TAGGED_USERDATA_BASE; +} + +inline uint8_t userdataIndexToType(uint8_t userdataIndex) +{ + return LBC_TYPE_TAGGED_USERDATA_BASE + userdataIndex; +} + inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength) { using namespace Luau::CodeGen; - if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) + if (compareMemberName(member, memberLength, "Magnitude")) return LBC_TYPE_NUMBER; - if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) + if (compareMemberName(member, memberLength, "Unit")) return LBC_TYPE_VECTOR; return LBC_TYPE_ANY; @@ -22,7 +52,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si { using namespace Luau::CodeGen; - if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) + if (compareMemberName(member, memberLength, "Magnitude")) { IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); @@ -42,7 +72,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si return true; } - if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) + if (compareMemberName(member, memberLength, "Unit")) { IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); @@ -72,10 +102,10 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength) { - if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0) + if (compareMemberName(member, memberLength, "Dot")) return LBC_TYPE_NUMBER; - if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0) + if (compareMemberName(member, memberLength, "Cross")) return LBC_TYPE_VECTOR; return LBC_TYPE_ANY; @@ -86,7 +116,7 @@ inline bool vectorNamecall( { using namespace Luau::CodeGen; - if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0 && params == 2 && results <= 1) + if (compareMemberName(member, memberLength, "Dot") && params == 2 && results <= 1) { build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); @@ -114,7 +144,7 @@ inline bool vectorNamecall( return true; } - if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0 && params == 2 && results <= 1) + if (compareMemberName(member, memberLength, "Cross") && params == 2 && results <= 1) { build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); @@ -151,3 +181,362 @@ inline bool vectorNamecall( return false; } + +inline uint8_t userdataAccessBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + if (compareMemberName(member, memberLength, "R")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "G")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "B")) + return LBC_TYPE_NUMBER; + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Y")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Magnitude")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Unit")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + if (compareMemberName(member, memberLength, "Row1")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row2")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row3")) + return LBC_TYPE_VECTOR; + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataAccess( + Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Y")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Magnitude")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Unit")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), xr, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), yr, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} + +inline uint8_t userdataMetamethodBytecodeType(uint8_t lhsTy, uint8_t rhsTy, Luau::CodeGen::HostMetamethod method) +{ + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + case Luau::CodeGen::HostMetamethod::Sub: + case Luau::CodeGen::HostMetamethod::Mul: + case Luau::CodeGen::HostMetamethod::Div: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 || typeToUserdataIndex(rhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + default: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataMetamethod(Luau::CodeGen::IrBuilder& build, uint8_t lhsTy, uint8_t rhsTy, int resultReg, Luau::CodeGen::IrOp lhs, + Luau::CodeGen::IrOp rhs, Luau::CodeGen::HostMetamethod method, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::ADD_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::ADD_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Mul: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MUL_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::UNM_NUM, x); + IrOp my = build.inst(IrCmd::UNM_NUM, y); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + default: + break; + } + + return false; +} + +inline uint8_t userdataNamecallBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Min")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataNamecall(Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, + int params, int results, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp sum = build.inst(IrCmd::ADD_NUM, xx, yy); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + if (compareMemberName(member, memberLength, "Min")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MIN_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MIN_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(argResReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TUSERDATA)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 5d7fedd8f..ecdb522c1 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -24,6 +24,8 @@ LUAU_FASTFLAG(LuauCompileTempTypeInfo) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCompileUserdataInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) { @@ -34,6 +36,13 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = options.compilationOptions.hooks.vectorAccess = vectorAccess; options.compilationOptions.hooks.vectorNamecall = vectorNamecall; + options.compilationOptions.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + options.compilationOptions.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + options.compilationOptions.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + options.compilationOptions.hooks.userdataAccess = userdataAccess; + options.compilationOptions.hooks.userdataMetamethod = userdataMetamethod; + options.compilationOptions.hooks.userdataNamecall = userdataNamecall; + // For IR, we don't care about assembly, but we want a stable target options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV; @@ -1690,4 +1699,352 @@ end )"); } +TEST_CASE("CustomUserdataPropertyAccess") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(v: vec2) + return v.X + v.Y +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %6, 12i, exit(0) + %8 = BUFFER_READF32 %6, 0i, tuserdata + %15 = BUFFER_READF32 %6, 4i, tuserdata + %24 = ADD_NUM %8, %15 + STORE_DOUBLE R1, %24 + STORE_TAG R1, tnumber + INTERRUPT 5u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataPropertyAccess2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3) + return a.Row1 * a.Row2 +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R2, R0, K0 + FALLBACK_GETTABLEKS 2u, R3, R0, K1 + CHECK_TAG R2, tvector, exit(4) + CHECK_TAG R3, tvector, exit(4) + %14 = LOAD_TVALUE R2 + %15 = LOAD_TVALUE R3 + %16 = MUL_VEC %14, %15 + %17 = TAG_VECTOR %16 + STORE_TVALUE R1, %17 + INTERRUPT 5u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataNamecall1") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, + {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2) + return a:Dot(b) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MUL_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MUL_NUM %19, %20 + %22 = ADD_NUM %18, %21 + STORE_DOUBLE R2, %22 + STORE_TAG R2, tnumber + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)"); +} + +TEST_CASE("CustomUserdataNamecall2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, + {FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2) + return a:Min(b) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MIN_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MIN_NUM %19, %20 + CHECK_GC + %23 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %23, 0i, %18, tuserdata + BUFFER_WRITEF32 %23, 4i, %21, tuserdata + STORE_POINTER R2, %23 + STORE_TAG R2, tuserdata + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3, b: mat3) + return a * b +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: mat3 [argument] +; R1: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_ARITH R2, R0, R1, 10i + INTERRUPT 1u + RETURN R2, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3) + return -a +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_ARITH R1, R0, R0, 15i + INTERRUPT 1u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow3") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: sequence) + return #a +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: userdata [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_LEN R1, R0 + INTERRUPT 1u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethod") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}, + {FFlag::LuauCodegenUserdataAlloc, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2, c: vec2) + return -c + a * b +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1, $arg2) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +; R2: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + CHECK_TAG R2, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %10 = LOAD_POINTER R2 + CHECK_USERDATA_TAG %10, 12i, exit(0) + %12 = BUFFER_READF32 %10, 0i, tuserdata + %13 = BUFFER_READF32 %10, 4i, tuserdata + %14 = UNM_NUM %12 + %15 = UNM_NUM %13 + CHECK_GC + %17 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %17, 0i, %14, tuserdata + BUFFER_WRITEF32 %17, 4i, %15, tuserdata + STORE_POINTER R4, %17 + STORE_TAG R4, tuserdata + %26 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %26, 12i, exit(1) + %28 = LOAD_POINTER R1 + CHECK_USERDATA_TAG %28, 12i, exit(1) + %30 = BUFFER_READF32 %26, 0i, tuserdata + %31 = BUFFER_READF32 %28, 0i, tuserdata + %32 = MUL_NUM %30, %31 + %33 = BUFFER_READF32 %26, 4i, tuserdata + %34 = BUFFER_READF32 %28, 4i, tuserdata + %35 = MUL_NUM %33, %34 + %37 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %37, 0i, %32, tuserdata + BUFFER_WRITEF32 %37, 4i, %35, tuserdata + STORE_POINTER R5, %37 + STORE_TAG R5, tuserdata + %50 = BUFFER_READF32 %17, 0i, tuserdata + %51 = BUFFER_READF32 %37, 0i, tuserdata + %52 = ADD_NUM %50, %51 + %53 = BUFFER_READF32 %17, 4i, tuserdata + %54 = BUFFER_READF32 %37, 4i, tuserdata + %55 = ADD_NUM %53, %54 + %57 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %57, 0i, %52, tuserdata + BUFFER_WRITEF32 %57, 4i, %55, tuserdata + STORE_POINTER R3, %57 + STORE_TAG R3, tuserdata + INTERRUPT 3u + RETURN R3, 1i +)"); +} + TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 78d1389a6..e0716e4c5 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -192,13 +192,13 @@ TEST_CASE("string_interpolation_double_brace") auto brokenInterpBegin = lexer.next(); CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace); - CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo")); + CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.getLength()), std::string("foo")); CHECK_EQ(lexer.next().type, Lexeme::Name); auto interpEnd = lexer.next(); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); - CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar")); + CHECK_EQ(std::string(interpEnd.data, interpEnd.getLength()), std::string("}bar")); } TEST_CASE("string_interpolation_double_but_unmatched_brace") diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index e51fb0dfc..81a84722b 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -15,6 +15,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauAttributeSyntax); + #define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \ do \ { \ @@ -68,6 +70,7 @@ struct NonStrictTypeCheckerFixture : Fixture { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; LoadDefinitionFileResult res = loadDefinition(definitions); LUAU_ASSERT(res.success); @@ -78,6 +81,7 @@ struct NonStrictTypeCheckerFixture : Fixture { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; LoadDefinitionFileResult res = loadDefinition(definitions); LUAU_ASSERT(res.success); @@ -85,21 +89,21 @@ struct NonStrictTypeCheckerFixture : Fixture } std::string definitions = R"BUILTIN_SRC( -declare function @checked abs(n: number): number -declare function @checked lower(s: string): string +@checked declare function abs(n: number): number +@checked declare function lower(s: string): string declare function cond() : boolean -declare function @checked contrived(n : Not) : number +@checked declare function contrived(n : Not) : number -- interesting types of things that we would like to mark as checked -declare function @checked onlyNums(...: number) : number -declare function @checked mixedArgs(x: string, ...: number) : number -declare function @checked optionalArg(x: string?) : number +@checked declare function onlyNums(...: number) : number +@checked declare function mixedArgs(x: string, ...: number) : number +@checked declare function optionalArg(x: string?) : number declare foo: { bar: @checked (number) -> number, } -declare function @checked optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number -declare function @checked optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number +@checked declare function optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number +@checked declare function optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number type DateTypeArg = { year: number, @@ -115,7 +119,7 @@ declare os : { time: @checked (time: DateTypeArg?) -> number } -declare function @checked require(target : any) : any +@checked declare function require(target : any) : any )BUILTIN_SRC"; }; @@ -558,6 +562,10 @@ local E = require(script.Parent.A) TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use") { + ScopedFastFlag flags[] = { + {FFlag::LuauAttributeSyntax, true}, + }; + loadDefinition(R"( declare buffer: { create: @checked (size: number) -> buffer, diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 6b4bcf22b..8b2cc6bab 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauAttributeSyntax); LUAU_FASTFLAG(LuauLeadingBarAndAmpersand); namespace @@ -3051,9 +3052,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_top_level_checked_fn") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; std::string src = R"BUILTIN_SRC( -declare function @checked abs(n: number): number +@checked declare function abs(n: number): number )BUILTIN_SRC"; ParseResult pr = tryParse(src, opts); @@ -3063,13 +3065,14 @@ declare function @checked abs(n: number): number AstStat* root = *(pr.root->body.data); auto func = root->as(); LUAU_ASSERT(func); - LUAU_ASSERT(func->checkedFunction); + LUAU_ASSERT(func->isCheckedFunction()); } TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; const std::string src = R"BUILTIN_SRC( declare math : { @@ -3090,13 +3093,14 @@ TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") auto prop = *tbl->props.data; auto func = prop.type->as(); LUAU_ASSERT(func); - LUAU_ASSERT(func->checkedFunction); + LUAU_ASSERT(func->isCheckedFunction()); } TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; ParseResult pr = tryParse(R"( local @checked = 3 @@ -3110,10 +3114,11 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_in_and_out_of_decl_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( local @checked = 3 - declare function @checked abs(n: number): number + @checked declare function abs(n: number): number )", opts); LUAU_ASSERT(pr.errors.size() == 2); @@ -3125,9 +3130,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_as_function_name_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( - function @checked(x: number) : number + @checked function(x: number) : number end )", opts); @@ -3138,6 +3144,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( local @blah = 3 @@ -3190,4 +3197,300 @@ TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } +void checkAttribute(const AstAttr* attr, const AstAttr::Type type, const Location& location) +{ + CHECK_EQ(attr->type, type); + CHECK_EQ(attr->location, location); +} + +void checkFirstErrorForAttributes(const std::vector& errors, const size_t minSize, const Location& location, const std::string& message) +{ + LUAU_ASSERT(minSize >= 1); + + CHECK_GE(errors.size(), minSize); + CHECK_EQ(errors[0].getLocation(), location); + CHECK_EQ(errors[0].getMessage(), message); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + AstStatBlock* stat = parse(R"( +@checked +function hello(x, y) + return x + y +end)"); + + LUAU_ASSERT(stat != nullptr); + + AstStatFunction* statFun = stat->body.data[0]->as(); + LUAU_ASSERT(statFun != nullptr); + + AstArray attributes = statFun->func->attributes; + + CHECK_EQ(attributes.size, 1); + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + AstStatBlock* stat = parse(R"( + @checked +local function hello(x, y) + return x + y +end)"); + + LUAU_ASSERT(stat != nullptr); + + AstStatLocalFunction* statFun = stat->body.data[0]->as(); + LUAU_ASSERT(statFun != nullptr); + + AstArray attributes = statFun->func->attributes; + + CHECK_EQ(attributes.size, 1); + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 4), Position(1, 12))); +} + +TEST_CASE_FIXTURE(Fixture, "empty_attribute_name_is_not_allowed") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@ +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(1, 0), Position(1, 1)), "Attribute name is missing"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult pr1 = tryParse(R"( +@checked +if a<0 then a = 0 end)"); + checkFirstErrorForAttributes(pr1.errors, 1, Location(Position(2, 0), Position(2, 2)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'if' intead"); + + ParseResult pr2 = tryParse(R"( +local i = 1 +@checked +while a[i] do + print(a[i]) + i = i + 1 +end)"); + checkFirstErrorForAttributes(pr2.errors, 1, Location(Position(3, 0), Position(3, 5)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'while' intead"); + + ParseResult pr3 = tryParse(R"( +@checked +do + local a2 = 2*a + local d = sqrt(b^2 - 4*a*c) + x1 = (-b + d)/a2 + x2 = (-b - d)/a2 +end)"); + checkFirstErrorForAttributes(pr3.errors, 1, Location(Position(2, 0), Position(2, 2)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'do' intead"); + + ParseResult pr4 = tryParse(R"( +@checked +for i=1,10 do print(i) end +)"); + checkFirstErrorForAttributes(pr4.errors, 1, Location(Position(2, 0), Position(2, 3)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'for' intead"); + + ParseResult pr5 = tryParse(R"( +@checked +repeat + line = io.read() +until line ~= "" +)"); + checkFirstErrorForAttributes(pr5.errors, 1, Location(Position(2, 0), Position(2, 6)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'repeat' intead"); + + + ParseResult pr6 = tryParse(R"( +@checked +local x = 10 +)"); + checkFirstErrorForAttributes( + pr6.errors, 1, Location(Position(2, 6), Position(2, 7)), "Expected 'function' after local declaration with attribute, but got 'x' intead"); + + ParseResult pr7 = tryParse(R"( +local i = 1 +while a[i] do + if a[i] == v then @checked break end + i = i + 1 +end +)"); + checkFirstErrorForAttributes(pr7.errors, 1, Location(Position(3, 31), Position(3, 36)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'break' intead"); + + + ParseResult pr8 = tryParse(R"( +function foo1 () @checked return 'a' end +)"); + checkFirstErrorForAttributes(pr8.errors, 1, Location(Position(1, 26), Position(1, 32)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' intead"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"( +@checked declare function abs(n: number): number +)"; + + ParseResult pr = tryParse(src, opts); + CHECK_EQ(pr.errors.size(), 0); + + LUAU_ASSERT(pr.root->body.size == 1); + + AstStat* root = *(pr.root->body.data); + + auto func = root->as(); + LUAU_ASSERT(func != nullptr); + + CHECK(func->isCheckedFunction()); + + AstArray attributes = func->attributes; + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attributes_on_function_type_declaration_in_table") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"( +declare bit32: { + band: @checked (...number) -> number +})"; + + ParseResult pr = tryParse(src, opts); + CHECK_EQ(pr.errors.size(), 0); + + LUAU_ASSERT(pr.root->body.size == 1); + + AstStat* root = *(pr.root->body.data); + + AstStatDeclareGlobal* glob = root->as(); + LUAU_ASSERT(glob); + + auto tbl = glob->type->as(); + LUAU_ASSERT(tbl); + + LUAU_ASSERT(tbl->props.size == 1); + AstTableProp prop = tbl->props.data[0]; + + AstTypeFunction* func = prop.type->as(); + LUAU_ASSERT(func); + + AstArray attributes = func->attributes; + + CHECK_EQ(attributes.size, 1); + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(2, 10), Position(2, 18))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_type_declarations") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + ParseResult pr1 = tryParse(R"( +@checked declare foo: number + )", + opts); + + checkFirstErrorForAttributes( + pr1.errors, 1, Location(Position(1, 17), Position(1, 20)), "Expected a function type declaration after attribute, but got 'foo' intead"); + + ParseResult pr2 = tryParse(R"( +@checked declare class Foo + prop: number + function method(self, foo: number): string +end)", + opts); + + checkFirstErrorForAttributes( + pr2.errors, 1, Location(Position(1, 17), Position(1, 22)), "Expected a function type declaration after attribute, but got 'class' intead"); + + ParseResult pr3 = tryParse(R"( +declare bit32: { + band: @checked number +})", + opts); + + checkFirstErrorForAttributes( + pr3.errors, 1, Location(Position(2, 19), Position(2, 25)), "Expected '(' when parsing function parameters, got 'number'"); +} + +TEST_CASE_FIXTURE(Fixture, "attributes_cannot_be_duplicated") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@checked + @checked +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 12)), "Cannot duplicate attribute '@checked'"); +} + +TEST_CASE_FIXTURE(Fixture, "unsupported_attributes_are_not_allowed") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@checked + @cool_attribute +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 19)), "Invalid attribute '@cool_attribute'"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + parse(R"(type A = | "Hello" | "World")"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + parse(R"(type A = & { string } & { number })"); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); +} + + TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e2b3f9b79..d7cb225a0 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -13,6 +13,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauSharedSelf); +LUAU_FASTFLAG(LuauAttributeSyntax); TEST_SUITE_BEGIN("ToString"); @@ -1010,10 +1011,11 @@ TEST_CASE_FIXTURE(Fixture, "checked_fn_toString") { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; auto _result = loadDefinition(R"( -declare function @checked abs(n: number) : number +@checked declare function abs(n: number) : number )"); auto result = check(Mode::Nonstrict, R"( diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 1a7ef973c..ce6988aab 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -701,7 +701,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") REQUIRE(ut); REQUIRE(ut->options.size() == 2); - CHECK_EQ(builtinTypes->nilType, ut->options[0]); + CHECK_EQ(builtinTypes->nilType, follow(ut->options[0])); CHECK_EQ(*builtinTypes->numberType, *ut->options[1]); } else @@ -1179,4 +1179,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression") CHECK("any" == toString(requireTypeAtPosition({3, 25}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "tryDispatchIterableFunction_under_constrained_loop_should_not_assert") +{ + CheckResult result = check(R"( +local function foo(Instance) + for _, Child in next, Instance:GetChildren() do + end +end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2c6136a49..516a761b1 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3153,7 +3153,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Value of type '{ x: number? }?' could be nil", toString(result.errors[0])); + CHECK_EQ("Type 'nil' does not have key 'x'", toString(result.errors[0])); else CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); CHECK_EQ("boolean", toString(requireType("u"))); @@ -4439,7 +4439,13 @@ TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer") diff --git a/tests/conformance/native_userdata.lua b/tests/conformance/native_userdata.lua new file mode 100644 index 000000000..b1b2a1033 --- /dev/null +++ b/tests/conformance/native_userdata.lua @@ -0,0 +1,42 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing userdata') + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + +local function realmad(a: vec2, b: vec2, c: vec2): vec2 + return -c + a * b; +end + +local function dm(s: vec2, t: vec2, u: vec2) + local x = s:Dot(t) + assert(x == 13) + + local t = u:Min(s) + assert(t.X == 5) + assert(t.Y == 4) +end + +local s: vec2 = vec2(5, 4) +local t: vec2 = vec2(1, 2) +local u: vec2 = vec2(10, 20) + +local x: vec2 = realmad(s, t, u) + +assert(x.X == -5) +assert(x.Y == -12) + +dm(s, t, u) + +local function mu(v: vec2) + assert(v.Magnitude == 2) + assert(v.Unit.X == 0) + assert(v.Unit.Y == 1) +end + +mu(vec2(0, 2)) + +return 'OK' diff --git a/tools/faillist.txt b/tools/faillist.txt index 7a214a32d..b2677bf4d 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -4,7 +4,6 @@ AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg AutocompleteTest.autocomplete_string_singletons AutocompleteTest.do_wrong_compatible_nonself_calls AutocompleteTest.string_singleton_as_table_key -AutocompleteTest.string_singleton_in_if_statement2 AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_table @@ -33,6 +32,15 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument2 BuiltinTests.table_freeze_is_generic BuiltinTests.tonumber_returns_optional_number_type +ControlFlowAnalysis.for_record_do_if_not_x_break +ControlFlowAnalysis.for_record_do_if_not_x_continue +ControlFlowAnalysis.if_not_x_break_elif_not_y_break +ControlFlowAnalysis.if_not_x_break_elif_not_y_continue +ControlFlowAnalysis.if_not_x_break_elif_rand_break_elif_not_y_break +ControlFlowAnalysis.if_not_x_continue_elif_not_y_continue +ControlFlowAnalysis.if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough +ControlFlowAnalysis.if_not_x_continue_elif_rand_continue_elif_not_y_continue +ControlFlowAnalysis.if_not_x_return_elif_not_y_break DefinitionTests.class_definition_overload_metamethods Differ.metatable_metamissing_left Differ.metatable_metamissing_right @@ -46,7 +54,6 @@ FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.better_mismatch_error_messages GenericsTests.bound_tables_do_not_clone_original_fields -GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_infer_generic_functions GenericsTests.dont_substitute_bound_types @@ -135,6 +142,7 @@ RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x RefinementTest.globals_can_be_narrowed_too RefinementTest.isa_type_refinement_must_be_known_ahead_of_time +RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage RefinementTest.refine_a_property_of_some_global @@ -278,7 +286,9 @@ TypeInferAnyError.can_subscript_any TypeInferAnyError.for_in_loop_iterator_is_any TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferAnyError.for_in_loop_iterator_is_any_pack +TypeInferAnyError.for_in_loop_iterator_returns_any TypeInferAnyError.for_in_loop_iterator_returns_any2 +TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferClasses.callable_classes TypeInferClasses.cannot_unify_class_instance_with_primitive TypeInferClasses.class_type_mismatch_with_name_conflict @@ -337,6 +347,7 @@ TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function +TypeInferFunctions.unifier_should_not_bind_free_types TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.dcr_xpath_candidates @@ -363,7 +374,6 @@ TypeInferModules.require TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory -TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.promise_type_error_too_complex TypeInferOperators.add_type_family_works TypeInferOperators.cli_38355_recursive_union