Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation for @unknown default #14382

Merged
merged 11 commits into from
Apr 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2944,6 +2944,15 @@ ERROR(type_mismatch_multiple_pattern_list,none,
ERROR(type_mismatch_fallthrough_pattern_list,none,
"pattern variable bound to type %0, fallthrough case bound to type %1", (Type, Type))

ERROR(unknown_case_must_be_catchall,none,
"'@unknown' is only supported for catch-all cases (\"case _\")", ())
ERROR(unknown_case_where_clause,none,
"'where' cannot be used with '@unknown'", ())
ERROR(unknown_case_multiple_patterns,none,
"'@unknown' cannot be applied to multiple patterns", ())
ERROR(unknown_case_must_be_last,none,
"'@unknown' can only be applied to the last case in a switch", ())

WARNING(where_on_one_item, none,
"'where' only applies to the second pattern match in this case", ())

Expand Down Expand Up @@ -3892,6 +3901,8 @@ NOTE(missing_several_cases,none,
"do you want to add "
"%select{missing cases|a default clause}0"
"?", (bool))
NOTE(missing_unknown_case,none,
"handle unknown values using \"@unknown default\"", ())

NOTE(missing_particular_case,none,
"add missing case: '%0'", (StringRef))
Expand All @@ -3903,9 +3914,7 @@ WARNING(redundant_particular_literal_case,none,
NOTE(redundant_particular_literal_case_here,none,
"first occurrence of identical literal pattern is here", ())

// HACK: Downgrades the above to warnings if any of the cases is marked
// @_downgrade_exhaustivity_check.
WARNING(non_exhaustive_switch_warn_swift3,none, "switch must be exhaustive", ())
WARNING(non_exhaustive_switch_warn,none, "switch must be exhaustive", ())

#ifndef DIAG_NO_UNDEF
# if defined(DIAG)
Expand Down
66 changes: 53 additions & 13 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "llvm/Support/TrailingObjects.h"

namespace swift {
class AnyPattern;
class ASTContext;
class ASTWalker;
class Decl;
Expand Down Expand Up @@ -814,17 +815,40 @@ class ForEachStmt : public LabeledStmt {

/// A pattern and an optional guard expression used in a 'case' statement.
class CaseLabelItem {
enum class Kind {
/// A normal pattern
Normal = 0,
/// `default`
Default,
};

Pattern *CasePattern;
SourceLoc WhereLoc;
llvm::PointerIntPair<Expr *, 1, bool> GuardExprAndIsDefault;
llvm::PointerIntPair<Expr *, 1, Kind> GuardExprAndKind;

CaseLabelItem(Kind kind, Pattern *casePattern, SourceLoc whereLoc,
Expr *guardExpr)
: CasePattern(casePattern), WhereLoc(whereLoc),
GuardExprAndKind(guardExpr, kind) {}

public:
CaseLabelItem(const CaseLabelItem &) = default;

CaseLabelItem(bool IsDefault, Pattern *CasePattern, SourceLoc WhereLoc,
Expr *GuardExpr)
: CasePattern(CasePattern), WhereLoc(WhereLoc),
GuardExprAndIsDefault(GuardExpr, IsDefault) {}
CaseLabelItem(Pattern *casePattern, SourceLoc whereLoc, Expr *guardExpr)
: CaseLabelItem(Kind::Normal, casePattern, whereLoc, guardExpr) {}
explicit CaseLabelItem(Pattern *casePattern)
: CaseLabelItem(casePattern, SourceLoc(), nullptr) {}

static CaseLabelItem getDefault(AnyPattern *pattern,
SourceLoc whereLoc,
Expr *guardExpr) {
assert(pattern);
return CaseLabelItem(Kind::Default, reinterpret_cast<Pattern *>(pattern),
whereLoc, guardExpr);
}
static CaseLabelItem getDefault(AnyPattern *pattern) {
return getDefault(pattern, SourceLoc(), nullptr);
}

SourceLoc getWhereLoc() const { return WhereLoc; }

Expand All @@ -838,14 +862,16 @@ class CaseLabelItem {

/// Return the guard expression if present, or null if the case label has
/// no guard.
Expr *getGuardExpr() { return GuardExprAndIsDefault.getPointer(); }
Expr *getGuardExpr() { return GuardExprAndKind.getPointer(); }
const Expr *getGuardExpr() const {
return GuardExprAndIsDefault.getPointer();
return GuardExprAndKind.getPointer();
}
void setGuardExpr(Expr *e) { GuardExprAndIsDefault.setPointer(e); }
void setGuardExpr(Expr *e) { GuardExprAndKind.setPointer(e); }

/// Returns true if this is syntactically a 'default' label.
bool isDefault() const { return GuardExprAndIsDefault.getInt(); }
bool isDefault() const {
return GuardExprAndKind.getInt() == Kind::Default;
}
};

/// A 'case' or 'default' block of a switch statement. Only valid as the
Expand All @@ -865,19 +891,21 @@ class CaseStmt final : public Stmt,
private llvm::TrailingObjects<CaseStmt, CaseLabelItem> {
friend TrailingObjects;

SourceLoc UnknownAttrLoc;
SourceLoc CaseLoc;
SourceLoc ColonLoc;

llvm::PointerIntPair<Stmt *, 1, bool> BodyAndHasBoundDecls;

CaseStmt(SourceLoc CaseLoc, ArrayRef<CaseLabelItem> CaseLabelItems,
bool HasBoundDecls, SourceLoc ColonLoc, Stmt *Body,
Optional<bool> Implicit);
bool HasBoundDecls, SourceLoc UnknownAttrLoc, SourceLoc ColonLoc,
Stmt *Body, Optional<bool> Implicit);

public:
static CaseStmt *create(ASTContext &C, SourceLoc CaseLoc,
ArrayRef<CaseLabelItem> CaseLabelItems,
bool HasBoundDecls, SourceLoc ColonLoc, Stmt *Body,
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
SourceLoc ColonLoc, Stmt *Body,
Optional<bool> Implicit = None);

ArrayRef<CaseLabelItem> getCaseLabelItems() const {
Expand All @@ -896,14 +924,26 @@ class CaseStmt final : public Stmt,
/// Get the source location of the 'case' or 'default' of the first label.
SourceLoc getLoc() const { return CaseLoc; }

SourceLoc getStartLoc() const { return getLoc(); }
SourceLoc getStartLoc() const {
if (UnknownAttrLoc.isValid())
return UnknownAttrLoc;
return getLoc();
}
SourceLoc getEndLoc() const { return getBody()->getEndLoc(); }
SourceRange getLabelItemsRange() const {
return ColonLoc.isValid() ? SourceRange(getLoc(), ColonLoc) : getSourceRange();
}

bool isDefault() { return getCaseLabelItems()[0].isDefault(); }

bool hasUnknownAttr() const {
// Note: This representation doesn't allow for synthesized @unknown cases.
// However, that's probably sensible; the purpose of @unknown is for
// diagnosing otherwise-non-exhaustive switches, and the user can't edit
// a synthesized case.
return UnknownAttrLoc.isValid();
}

static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Case; }
};

Expand Down
1 change: 1 addition & 0 deletions include/swift/Migrator/FixitFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ struct FixitFilter {
Info.ID == diag::objc_inference_swift3_objc_derived.ID ||
Info.ID == diag::missing_several_cases.ID ||
Info.ID == diag::missing_particular_case.ID ||
Info.ID == diag::missing_unknown_case.ID ||
Info.ID == diag::paren_void_probably_void.ID ||
Info.ID == diag::make_decl_objc.ID ||
Info.ID == diag::optional_req_nonobjc_near_match_add_objc.ID)
Expand Down
4 changes: 4 additions & 0 deletions lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1556,11 +1556,15 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
}
void visitCaseStmt(CaseStmt *S) {
printCommon(S, "case_stmt");
if (S->hasUnknownAttr())
OS << " @unknown";
for (const auto &LabelItem : S->getCaseLabelItems()) {
OS << '\n';
OS.indent(Indent + 2);
PrintWithColorRAII(OS, ParenthesisColor) << '(';
PrintWithColorRAII(OS, StmtColor) << "case_label_item";
if (LabelItem.isDefault())
OS << " default";
if (auto *CasePattern = LabelItem.getPattern()) {
OS << '\n';
printRec(CasePattern);
Expand Down
3 changes: 3 additions & 0 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2965,6 +2965,9 @@ void PrintAST::visitSwitchStmt(SwitchStmt *stmt) {
}

void PrintAST::visitCaseStmt(CaseStmt *CS) {
if (CS->hasUnknownAttr())
Printer << "@unknown ";

if (CS->isDefault()) {
Printer << tok::kw_default;
} else {
Expand Down
13 changes: 7 additions & 6 deletions lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,10 @@ SourceLoc CaseLabelItem::getEndLoc() const {
}

CaseStmt::CaseStmt(SourceLoc CaseLoc, ArrayRef<CaseLabelItem> CaseLabelItems,
bool HasBoundDecls, SourceLoc ColonLoc, Stmt *Body,
Optional<bool> Implicit)
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
SourceLoc ColonLoc, Stmt *Body, Optional<bool> Implicit)
: Stmt(StmtKind::Case, getDefaultImplicitFlag(Implicit, CaseLoc)),
CaseLoc(CaseLoc), ColonLoc(ColonLoc),
UnknownAttrLoc(UnknownAttrLoc), CaseLoc(CaseLoc), ColonLoc(ColonLoc),
BodyAndHasBoundDecls(Body, HasBoundDecls) {
Bits.CaseStmt.NumPatterns = CaseLabelItems.size();
assert(Bits.CaseStmt.NumPatterns > 0 &&
Expand All @@ -387,12 +387,13 @@ CaseStmt::CaseStmt(SourceLoc CaseLoc, ArrayRef<CaseLabelItem> CaseLabelItems,

CaseStmt *CaseStmt::create(ASTContext &C, SourceLoc CaseLoc,
ArrayRef<CaseLabelItem> CaseLabelItems,
bool HasBoundDecls, SourceLoc ColonLoc, Stmt *Body,
bool HasBoundDecls, SourceLoc UnknownAttrLoc,
SourceLoc ColonLoc, Stmt *Body,
Optional<bool> Implicit) {
void *Mem = C.Allocate(totalSizeToAlloc<CaseLabelItem>(CaseLabelItems.size()),
alignof(CaseStmt));
return ::new (Mem) CaseStmt(CaseLoc, CaseLabelItems, HasBoundDecls, ColonLoc,
Body, Implicit);
return ::new (Mem) CaseStmt(CaseLoc, CaseLabelItems, HasBoundDecls,
UnknownAttrLoc, ColonLoc, Body, Implicit);
}

SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
Expand Down
8 changes: 7 additions & 1 deletion lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2228,7 +2228,13 @@ bool Parser::isStartOfDecl() {
if (Tok.is(tok::kw_init)) {
return !isa<ConstructorDecl>(CurDeclContext);
}


// Similarly, when 'case' appears inside a function, it's probably a switch
// case, not an enum case declaration.
if (Tok.is(tok::kw_case)) {
return !isa<AbstractFunctionDecl>(CurDeclContext);
}

// The protocol keyword needs more checking to reject "protocol<Int>".
if (Tok.is(tok::kw_protocol)) {
const Token &Tok2 = peekToken();
Expand Down
77 changes: 63 additions & 14 deletions lib/Parse/ParseStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,51 @@ ParserStatus Parser::parseExprOrStmt(ASTNode &Result) {
return ResultExpr;
}

/// Returns whether the parser's current position is the start of a switch case,
/// given that we're in the middle of a switch already.
static bool isAtStartOfSwitchCase(Parser &parser,
bool needsToBacktrack = true) {
Optional<Parser::BacktrackingScope> backtrack;

// Check for and consume attributes. The only valid attribute is `@unknown`
// but that's a semantic restriction.
while (parser.Tok.is(tok::at_sign)) {
if (!parser.peekToken().is(tok::identifier))
return false;

if (needsToBacktrack && !backtrack)
backtrack.emplace(parser);

parser.consumeToken(tok::at_sign);
parser.consumeIdentifier();
if (parser.Tok.is(tok::l_paren))
parser.skipSingle();
}

return parser.Tok.isAny(tok::kw_case, tok::kw_default);
}

bool Parser::isTerminatorForBraceItemListKind(BraceItemListKind Kind,
ArrayRef<ASTNode> ParsedDecls) {
switch (Kind) {
case BraceItemListKind::Brace:
return false;
case BraceItemListKind::Case:
case BraceItemListKind::Case: {
if (Tok.is(tok::pound_if)) {
// Backtracking scopes are expensive, so avoid setting one up if possible.
Parser::BacktrackingScope Backtrack(*this);
// '#if' here could be to guard 'case:' or statements in cases.
// If the next non-directive line starts with 'case' or 'default', it is
// for 'case's.
Parser::BacktrackingScope Backtrack(*this);
do {
consumeToken();
while (!Tok.isAtStartOfLine() && Tok.isNot(tok::eof))
skipSingle();
} while (Tok.isAny(tok::pound_if, tok::pound_elseif, tok::pound_else));
return Tok.isAny(tok::kw_case, tok::kw_default);
return isAtStartOfSwitchCase(*this, /*needsToBacktrack*/false);
}
return Tok.isAny(tok::kw_case, tok::kw_default);
return isAtStartOfSwitchCase(*this);
}
case BraceItemListKind::TopLevelCode:
// When parsing the top level executable code for a module, if we parsed
// some executable code, then we're done. We want to process (name bind,
Expand Down Expand Up @@ -2025,7 +2051,7 @@ Parser::parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive) {
ParserStatus Status;
while (Tok.isNot(tok::r_brace, tok::eof,
tok::pound_endif, tok::pound_elseif, tok::pound_else)) {
if (Tok.isAny(tok::kw_case, tok::kw_default)) {
if (isAtStartOfSwitchCase(*this)) {
ParserResult<CaseStmt> Case = parseStmtCase(IsActive);
Status |= Case;
if (Case.isNonNull())
Expand Down Expand Up @@ -2094,12 +2120,11 @@ static ParserStatus parseStmtCase(Parser &P, SourceLoc &CaseLoc,
parseGuardedPattern(P, PatternResult, Status, BoundDecls,
GuardedPatternContext::Case, isFirst);
LabelItems.push_back(
CaseLabelItem(/*IsDefault=*/false, PatternResult.ThePattern,
PatternResult.WhereLoc, PatternResult.Guard));
CaseLabelItem(PatternResult.ThePattern, PatternResult.WhereLoc,
PatternResult.Guard));
isFirst = false;
if (P.consumeIf(tok::comma))
continue;
break;
if (!P.consumeIf(tok::comma))
break;
}
}

Expand Down Expand Up @@ -2144,7 +2169,7 @@ parseStmtCaseDefault(Parser &P, SourceLoc &CaseLoc,
// Create an implicit AnyPattern to represent the default match.
auto Any = new (P.Context) AnyPattern(CaseLoc);
LabelItems.push_back(
CaseLabelItem(/*IsDefault=*/true, Any, WhereLoc, Guard.getPtrOrNull()));
CaseLabelItem::getDefault(Any, WhereLoc, Guard.getPtrOrNull()));

return Status;
}
Expand All @@ -2159,6 +2184,30 @@ ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
SmallVector<CaseLabelItem, 2> CaseLabelItems;
SmallVector<VarDecl *, 4> BoundDecls;

SourceLoc UnknownAttrLoc;
while (Tok.is(tok::at_sign)) {
SyntaxParsingContext AttrCtx(SyntaxContext, SyntaxKind::Attribute);
UnknownAttrLoc = consumeToken(tok::at_sign);

if (Tok.isContextualKeyword("unknown")) {
consumeIdentifier();

// Form an empty TokenList for the arguments of the 'Attribute' Syntax
// node.
SyntaxParsingContext(SyntaxContext, SyntaxKind::TokenList);
} else {
UnknownAttrLoc = SourceLoc();

diagnose(Tok, diag::unknown_attribute, Tok.getText());
consumeIdentifier();

if (Tok.is(tok::l_paren)) {
SyntaxParsingContext Args(SyntaxContext, SyntaxKind::TokenList);
skipSingle();
}
}
}

SourceLoc CaseLoc;
SourceLoc ColonLoc;
if (Tok.is(tok::kw_case)) {
Expand All @@ -2173,8 +2222,7 @@ ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
SmallVector<ASTNode, 8> BodyItems;

SourceLoc StartOfBody = Tok.getLoc();
if (Tok.isNot(tok::kw_case) && Tok.isNot(tok::kw_default) &&
Tok.isNot(tok::r_brace)) {
if (Tok.isNot(tok::r_brace) && !isAtStartOfSwitchCase(*this)) {
Status |= parseBraceItems(BodyItems, BraceItemListKind::Case);
} else if (Status.isSuccess()) {
diagnose(CaseLoc, diag::case_stmt_without_body,
Expand All @@ -2193,5 +2241,6 @@ ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {

return makeParserResult(
Status, CaseStmt::create(Context, CaseLoc, CaseLabelItems,
!BoundDecls.empty(), ColonLoc, Body));
!BoundDecls.empty(), UnknownAttrLoc, ColonLoc,
Body));
}
Loading