Skip to content

Commit

Permalink
[clang] NFCI: use TemplateArgumentLoc for type-param DefaultArgument (#…
Browse files Browse the repository at this point in the history
…92854)

This is an enabler for a future patch.

This allows an type-parameter default argument to be set as an arbitrary
TemplateArgument, not just a type.
This allows template parameter packs to have default arguments in the
AST, even though the language proper doesn't support the syntax for it.

This will be used in a later patch which synthesizes template parameter
lists with arbitrary default arguments taken from template
specializations.

There are a few places we used SubsType, because we only had a type, now
we use SubstTemplateArgument.
SubstTemplateArgument was missing arguments for setting Instantiation
location and entity names.
Adding those is needed so we don't regress in diagnostics.
  • Loading branch information
mizvekov authored May 21, 2024
1 parent b908614 commit e42b799
Show file tree
Hide file tree
Showing 27 changed files with 144 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ AST_MATCHER(QualType, isEnableIf) {
AST_MATCHER_P(TemplateTypeParmDecl, hasDefaultArgument,
clang::ast_matchers::internal::Matcher<QualType>, TypeMatcher) {
return Node.hasDefaultArgument() &&
TypeMatcher.matches(Node.getDefaultArgument(), Finder, Builder);
TypeMatcher.matches(
Node.getDefaultArgument().getArgument().getAsType(), Finder,
Builder);
}
AST_MATCHER(TemplateDecl, hasAssociatedConstraints) {
return Node.hasAssociatedConstraints();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ namespace {
AST_MATCHER_P(TemplateTypeParmDecl, hasUnnamedDefaultArgument,
ast_matchers::internal::Matcher<TypeLoc>, InnerMatcher) {
if (Node.getIdentifier() != nullptr || !Node.hasDefaultArgument() ||
Node.getDefaultArgumentInfo() == nullptr)
Node.getDefaultArgument().getArgument().isNull())
return false;

TypeLoc DefaultArgTypeLoc = Node.getDefaultArgumentInfo()->getTypeLoc();
TypeLoc DefaultArgTypeLoc =
Node.getDefaultArgument().getTypeSourceInfo()->getTypeLoc();
return InnerMatcher.matches(DefaultArgTypeLoc, Finder, Builder);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
dyn_cast<TemplateTypeParmDecl>(LastParam)) {
if (LastTemplateParam->hasDefaultArgument() &&
LastTemplateParam->getIdentifier() == nullptr) {
return {matchEnableIfSpecialization(
LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()),
LastTemplateParam};
return {
matchEnableIfSpecialization(LastTemplateParam->getDefaultArgument()
.getTypeSourceInfo()
->getTypeLoc()),
LastTemplateParam};
}
}
return {};
Expand Down
8 changes: 6 additions & 2 deletions clang-tools-extra/clangd/Hover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,12 @@ fetchTemplateParameters(const TemplateParameterList *Params,
if (!TTP->getName().empty())
P.Name = TTP->getNameAsString();

if (TTP->hasDefaultArgument())
P.Default = TTP->getDefaultArgument().getAsString(PP);
if (TTP->hasDefaultArgument()) {
P.Default.emplace();
llvm::raw_string_ostream Out(*P.Default);
TTP->getDefaultArgument().getArgument().print(PP, Out,
/*IncludeType=*/false);
}
} else if (const auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(Param)) {
P.Type = printType(NTTP, PP);

Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/ASTNodeTraverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ class ASTNodeTraverser
if (const auto *TC = D->getTypeConstraint())
Visit(TC->getImmediatelyDeclaredConstraint());
if (D->hasDefaultArgument())
Visit(D->getDefaultArgument(), SourceRange(),
Visit(D->getDefaultArgument().getArgument(), SourceRange(),
D->getDefaultArgStorage().getInheritedFrom(),
D->defaultArgumentWasInherited() ? "inherited from" : "previous");
}
Expand Down
17 changes: 6 additions & 11 deletions clang/include/clang/AST/DeclTemplate.h
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ class TemplateTypeParmDecl final : public TypeDecl,

/// The default template argument, if any.
using DefArgStorage =
DefaultArgStorage<TemplateTypeParmDecl, TypeSourceInfo *>;
DefaultArgStorage<TemplateTypeParmDecl, TemplateArgumentLoc *>;
DefArgStorage DefaultArgument;

TemplateTypeParmDecl(DeclContext *DC, SourceLocation KeyLoc,
Expand Down Expand Up @@ -1225,13 +1225,9 @@ class TemplateTypeParmDecl final : public TypeDecl,
bool hasDefaultArgument() const { return DefaultArgument.isSet(); }

/// Retrieve the default argument, if any.
QualType getDefaultArgument() const {
return DefaultArgument.get()->getType();
}

/// Retrieves the default argument's source information, if any.
TypeSourceInfo *getDefaultArgumentInfo() const {
return DefaultArgument.get();
const TemplateArgumentLoc &getDefaultArgument() const {
static const TemplateArgumentLoc NoneLoc;
return DefaultArgument.isSet() ? *DefaultArgument.get() : NoneLoc;
}

/// Retrieves the location of the default argument declaration.
Expand All @@ -1244,9 +1240,8 @@ class TemplateTypeParmDecl final : public TypeDecl,
}

/// Set the default argument for this template parameter.
void setDefaultArgument(TypeSourceInfo *DefArg) {
DefaultArgument.set(DefArg);
}
void setDefaultArgument(const ASTContext &C,
const TemplateArgumentLoc &DefArg);

/// Set that this default argument was inherited from another
/// parameter.
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1960,7 +1960,7 @@ DEF_TRAVERSE_DECL(TemplateTypeParmDecl, {
TRY_TO(TraverseType(QualType(D->getTypeForDecl(), 0)));
TRY_TO(TraverseTemplateTypeParamDeclConstraints(D));
if (D->hasDefaultArgument() && !D->defaultArgumentWasInherited())
TRY_TO(TraverseTypeLoc(D->getDefaultArgumentInfo()->getTypeLoc()));
TRY_TO(TraverseTemplateArgumentLoc(D->getDefaultArgument()));
})

DEF_TRAVERSE_DECL(TypedefDecl, {
Expand Down
4 changes: 3 additions & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -10067,7 +10067,9 @@ class Sema final : public SemaBase {

bool SubstTemplateArgument(const TemplateArgumentLoc &Input,
const MultiLevelTemplateArgumentList &TemplateArgs,
TemplateArgumentLoc &Output);
TemplateArgumentLoc &Output,
SourceLocation Loc = {},
const DeclarationName &Entity = {});
bool
SubstTemplateArguments(ArrayRef<TemplateArgumentLoc> Args,
const MultiLevelTemplateArgumentList &TemplateArgs,
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6494,7 +6494,8 @@ bool ASTContext::isSameDefaultTemplateArgument(const NamedDecl *X,
if (!TTPX->hasDefaultArgument() || !TTPY->hasDefaultArgument())
return false;

return hasSameType(TTPX->getDefaultArgument(), TTPY->getDefaultArgument());
return hasSameType(TTPX->getDefaultArgument().getArgument().getAsType(),
TTPY->getDefaultArgument().getArgument().getAsType());
}

if (auto *NTTPX = dyn_cast<NonTypeTemplateParmDecl>(X)) {
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/AST/ASTImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5917,11 +5917,11 @@ ASTNodeImporter::VisitTemplateTypeParmDecl(TemplateTypeParmDecl *D) {
}

if (D->hasDefaultArgument()) {
Expected<TypeSourceInfo *> ToDefaultArgOrErr =
import(D->getDefaultArgumentInfo());
Expected<TemplateArgumentLoc> ToDefaultArgOrErr =
import(D->getDefaultArgument());
if (!ToDefaultArgOrErr)
return ToDefaultArgOrErr.takeError();
ToD->setDefaultArgument(*ToDefaultArgOrErr);
ToD->setDefaultArgument(ToD->getASTContext(), *ToDefaultArgOrErr);
}

return ToD;
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/AST/DeclPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1883,7 +1883,8 @@ void DeclPrinter::VisitTemplateTypeParmDecl(const TemplateTypeParmDecl *TTP) {

if (TTP->hasDefaultArgument()) {
Out << " = ";
Out << TTP->getDefaultArgument().getAsString(Policy);
TTP->getDefaultArgument().getArgument().print(Policy, Out,
/*IncludeType=*/false);
}
}

Expand Down
17 changes: 12 additions & 5 deletions clang/lib/AST/DeclTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,23 +669,30 @@ TemplateTypeParmDecl::CreateDeserialized(const ASTContext &C, GlobalDeclID ID,
}

SourceLocation TemplateTypeParmDecl::getDefaultArgumentLoc() const {
return hasDefaultArgument()
? getDefaultArgumentInfo()->getTypeLoc().getBeginLoc()
: SourceLocation();
return hasDefaultArgument() ? getDefaultArgument().getLocation()
: SourceLocation();
}

SourceRange TemplateTypeParmDecl::getSourceRange() const {
if (hasDefaultArgument() && !defaultArgumentWasInherited())
return SourceRange(getBeginLoc(),
getDefaultArgumentInfo()->getTypeLoc().getEndLoc());
getDefaultArgument().getSourceRange().getEnd());
// TypeDecl::getSourceRange returns a range containing name location, which is
// wrong for unnamed template parameters. e.g:
// it will return <[[typename>]] instead of <[[typename]]>
else if (getDeclName().isEmpty())
if (getDeclName().isEmpty())
return SourceRange(getBeginLoc());
return TypeDecl::getSourceRange();
}

void TemplateTypeParmDecl::setDefaultArgument(
const ASTContext &C, const TemplateArgumentLoc &DefArg) {
if (DefArg.getArgument().isNull())
DefaultArgument.set(nullptr);
else
DefaultArgument.set(new (C) TemplateArgumentLoc(DefArg));
}

unsigned TemplateTypeParmDecl::getDepth() const {
return getTypeForDecl()->castAs<TemplateTypeParmType>()->getDepth();
}
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/AST/JSONNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ void JSONNodeDumper::VisitTemplateTypeParmDecl(const TemplateTypeParmDecl *D) {

if (D->hasDefaultArgument())
JOS.attributeObject("defaultArg", [=] {
Visit(D->getDefaultArgument(), SourceRange(),
Visit(D->getDefaultArgument().getArgument(), SourceRange(),
D->getDefaultArgStorage().getInheritedFrom(),
D->defaultArgumentWasInherited() ? "inherited from" : "previous");
});
Expand Down
12 changes: 7 additions & 5 deletions clang/lib/AST/ODRDiagsEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1409,13 +1409,15 @@ bool ODRDiagsEmitter::diagnoseMismatch(
}

if (HasFirstDefaultArgument && HasSecondDefaultArgument) {
QualType FirstType = FirstTTPD->getDefaultArgument();
QualType SecondType = SecondTTPD->getDefaultArgument();
if (computeODRHash(FirstType) != computeODRHash(SecondType)) {
TemplateArgument FirstTA =
FirstTTPD->getDefaultArgument().getArgument();
TemplateArgument SecondTA =
SecondTTPD->getDefaultArgument().getArgument();
if (computeODRHash(FirstTA) != computeODRHash(SecondTA)) {
DiagTemplateError(FunctionTemplateParameterDifferentDefaultArgument)
<< (i + 1) << FirstType;
<< (i + 1) << FirstTA;
DiagTemplateNote(FunctionTemplateParameterDifferentDefaultArgument)
<< (i + 1) << SecondType;
<< (i + 1) << SecondTA;
return true;
}
}
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/AST/ODRHash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ class ODRDeclVisitor : public ConstDeclVisitor<ODRDeclVisitor> {
D->hasDefaultArgument() && !D->defaultArgumentWasInherited();
Hash.AddBoolean(hasDefaultArgument);
if (hasDefaultArgument) {
AddTemplateArgument(D->getDefaultArgument());
AddTemplateArgument(D->getDefaultArgument().getArgument());
}
Hash.AddBoolean(D->isParameterPack());

Expand Down
4 changes: 2 additions & 2 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2273,8 +2273,8 @@ bool clang::isSubstitutedDefaultArgument(ASTContext &Ctx, TemplateArgument Arg,

if (auto *TTPD = dyn_cast<TemplateTypeParmDecl>(Param)) {
return TTPD->hasDefaultArgument() &&
isSubstitutedTemplateArgument(Ctx, Arg, TTPD->getDefaultArgument(),
Args, Depth);
isSubstitutedTemplateArgument(
Ctx, Arg, TTPD->getDefaultArgument().getArgument(), Args, Depth);
} else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(Param)) {
return TTPD->hasDefaultArgument() &&
isSubstitutedTemplateArgument(
Expand Down
8 changes: 4 additions & 4 deletions clang/lib/ExtractAPI/DeclarationFragments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -999,11 +999,11 @@ DeclarationFragmentsBuilder::getFragmentsForTemplateParameters(
DeclarationFragments::FragmentKind::GenericParameter);

if (TemplateParam->hasDefaultArgument()) {
DeclarationFragments After;
const auto Default = TemplateParam->getDefaultArgument();
Fragments.append(" = ", DeclarationFragments::FragmentKind::Text)
.append(getFragmentsForType(TemplateParam->getDefaultArgument(),
TemplateParam->getASTContext(), After));
Fragments.append(std::move(After));
.append(getFragmentsForTemplateArguments(
{Default.getArgument()}, TemplateParam->getASTContext(),
{Default}));
}
} else if (const auto *NTP =
dyn_cast<NonTypeTemplateParmDecl>(ParameterArray[i])) {
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Index/IndexDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,8 @@ class IndexingDeclVisitor : public ConstDeclVisitor<IndexingDeclVisitor, bool> {
IndexCtx.handleDecl(TP);
if (const auto *TTP = dyn_cast<TemplateTypeParmDecl>(TP)) {
if (TTP->hasDefaultArgument())
IndexCtx.indexTypeSourceInfo(TTP->getDefaultArgumentInfo(), Parent);
handleTemplateArgumentLoc(TTP->getDefaultArgument(), Parent,
TP->getLexicalDeclContext());
if (auto *C = TTP->getTypeConstraint())
IndexCtx.handleReference(C->getNamedConcept(), C->getConceptNameLoc(),
Parent, TTP->getLexicalDeclContext());
Expand Down
48 changes: 28 additions & 20 deletions clang/lib/Sema/HLSLExternalSemaSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,17 +308,18 @@ struct BuiltinTypeDeclBuilder {
return *this;
}

TemplateParameterListBuilder addTemplateArgumentList();
BuiltinTypeDeclBuilder &addSimpleTemplateParams(ArrayRef<StringRef> Names);
TemplateParameterListBuilder addTemplateArgumentList(Sema &S);
BuiltinTypeDeclBuilder &addSimpleTemplateParams(Sema &S,
ArrayRef<StringRef> Names);
};

struct TemplateParameterListBuilder {
BuiltinTypeDeclBuilder &Builder;
ASTContext &AST;
Sema &S;
llvm::SmallVector<NamedDecl *> Params;

TemplateParameterListBuilder(BuiltinTypeDeclBuilder &RB)
: Builder(RB), AST(RB.Record->getASTContext()) {}
TemplateParameterListBuilder(Sema &S, BuiltinTypeDeclBuilder &RB)
: Builder(RB), S(S) {}

~TemplateParameterListBuilder() { finalizeTemplateArgs(); }

Expand All @@ -328,12 +329,15 @@ struct TemplateParameterListBuilder {
return *this;
unsigned Position = static_cast<unsigned>(Params.size());
auto *Decl = TemplateTypeParmDecl::Create(
AST, Builder.Record->getDeclContext(), SourceLocation(),
S.Context, Builder.Record->getDeclContext(), SourceLocation(),
SourceLocation(), /* TemplateDepth */ 0, Position,
&AST.Idents.get(Name, tok::TokenKind::identifier), /* Typename */ false,
&S.Context.Idents.get(Name, tok::TokenKind::identifier),
/* Typename */ false,
/* ParameterPack */ false);
if (!DefaultValue.isNull())
Decl->setDefaultArgument(AST.getTrivialTypeSourceInfo(DefaultValue));
Decl->setDefaultArgument(
S.Context, S.getTrivialTemplateArgumentLoc(DefaultValue, QualType(),
SourceLocation()));

Params.emplace_back(Decl);
return *this;
Expand All @@ -342,11 +346,11 @@ struct TemplateParameterListBuilder {
BuiltinTypeDeclBuilder &finalizeTemplateArgs() {
if (Params.empty())
return Builder;
auto *ParamList =
TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
Params, SourceLocation(), nullptr);
auto *ParamList = TemplateParameterList::Create(S.Context, SourceLocation(),
SourceLocation(), Params,
SourceLocation(), nullptr);
Builder.Template = ClassTemplateDecl::Create(
AST, Builder.Record->getDeclContext(), SourceLocation(),
S.Context, Builder.Record->getDeclContext(), SourceLocation(),
DeclarationName(Builder.Record->getIdentifier()), ParamList,
Builder.Record);
Builder.Record->setDescribedClassTemplate(Builder.Template);
Expand All @@ -359,20 +363,22 @@ struct TemplateParameterListBuilder {
Params.clear();

QualType T = Builder.Template->getInjectedClassNameSpecialization();
T = AST.getInjectedClassNameType(Builder.Record, T);
T = S.Context.getInjectedClassNameType(Builder.Record, T);

return Builder;
}
};
} // namespace

TemplateParameterListBuilder BuiltinTypeDeclBuilder::addTemplateArgumentList() {
return TemplateParameterListBuilder(*this);
TemplateParameterListBuilder
BuiltinTypeDeclBuilder::addTemplateArgumentList(Sema &S) {
return TemplateParameterListBuilder(S, *this);
}

BuiltinTypeDeclBuilder &
BuiltinTypeDeclBuilder::addSimpleTemplateParams(ArrayRef<StringRef> Names) {
TemplateParameterListBuilder Builder = this->addTemplateArgumentList();
BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
ArrayRef<StringRef> Names) {
TemplateParameterListBuilder Builder = this->addTemplateArgumentList(S);
for (StringRef Name : Names)
Builder.addTypeParameter(Name);
return Builder.finalizeTemplateArgs();
Expand Down Expand Up @@ -426,7 +432,9 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
auto *TypeParam = TemplateTypeParmDecl::Create(
AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
&AST.Idents.get("element", tok::TokenKind::identifier), false, false);
TypeParam->setDefaultArgument(AST.getTrivialTypeSourceInfo(AST.FloatTy));
TypeParam->setDefaultArgument(
AST, SemaPtr->getTrivialTemplateArgumentLoc(
TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));

TemplateParams.emplace_back(TypeParam);

Expand Down Expand Up @@ -492,7 +500,7 @@ static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
CXXRecordDecl *Decl;
Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
.addSimpleTemplateParams({"element_type"})
.addSimpleTemplateParams(*SemaPtr, {"element_type"})
.Record;
onCompletion(Decl, [this](CXXRecordDecl *Decl) {
setupBufferType(Decl, *SemaPtr, ResourceClass::UAV,
Expand All @@ -503,7 +511,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {

Decl =
BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RasterizerOrderedBuffer")
.addSimpleTemplateParams({"element_type"})
.addSimpleTemplateParams(*SemaPtr, {"element_type"})
.Record;
onCompletion(Decl, [this](CXXRecordDecl *Decl) {
setupBufferType(Decl, *SemaPtr, ResourceClass::UAV,
Expand Down
Loading

0 comments on commit e42b799

Please sign in to comment.