Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use parent to determine whether an expression is target of an invocation #3340

Merged
merged 2 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runtime/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func (r *REPL) Accept(code []byte, eval bool) (inputIsComplete bool, err error)
var expressionType sema.Type
expressionStatement, isExpression := statement.(*ast.ExpressionStatement)
if isExpression {
expressionType = r.checker.VisitExpression(expressionStatement.Expression, nil)
expressionType = r.checker.VisitExpression(expressionStatement.Expression, expressionStatement, nil)
if !eval && expressionType != sema.InvalidType {
r.onExpressionType(expressionType)
}
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_array_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (checker *Checker) VisitArrayExpression(expression *ast.ArrayExpression) Ty
argumentTypes = make([]Type, elementCount)

for i, value := range expression.Values {
valueType := checker.VisitExpression(value, elementType)
valueType := checker.VisitExpression(value, expression, elementType)
turbolent marked this conversation as resolved.
Show resolved Hide resolved

argumentTypes[i] = valueType

Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_assignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (checker *Checker) checkAssignment(
if checker.accessedSelfMember(target) == nil {
checkValue = checker.VisitExpressionWithReferenceCheck
}
valueType = checkValue(value, targetType)
valueType = checkValue(value, assignment, targetType)

// NOTE: Visiting the `value` checks the compatibility between value and target types.
// Check for the *target* type, so that assignment using non-resource typed value (e.g. `nil`)
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_attach_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (checker *Checker) VisitAttachExpression(expression *ast.AttachExpression)
attachment := expression.Attachment
baseExpression := expression.Base

baseType := checker.VisitExpression(baseExpression, checker.expectedType)
baseType := checker.VisitExpression(baseExpression, expression, checker.expectedType)
attachmentType := checker.checkInvocationExpression(attachment)

if attachmentType.IsInvalidType() || baseType.IsInvalidType() {
Expand Down
21 changes: 18 additions & 3 deletions runtime/sema/check_binary_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ func (checker *Checker) VisitBinaryExpression(expression *ast.BinaryExpression)
// Visit the expression, with contextually expected type. Use the expected type
// only for inferring wherever possible, but do not check for compatibility.
// Compatibility is checked separately for each operand kind.
leftType = checker.VisitExpressionWithForceType(expression.Left, expectedType, false)
leftType = checker.VisitExpressionWithForceType(
expression.Left,
expression,
expectedType,
false,
)

leftIsInvalid := leftType.IsInvalidType()

Expand Down Expand Up @@ -123,7 +128,12 @@ func (checker *Checker) VisitBinaryExpression(expression *ast.BinaryExpression)
expectedType = leftType
}

rightType = checker.VisitExpressionWithForceType(expression.Right, expectedType, false)
rightType = checker.VisitExpressionWithForceType(
expression.Right,
expression,
expectedType,
false,
)

rightIsInvalid := rightType.IsInvalidType()

Expand Down Expand Up @@ -174,7 +184,12 @@ func (checker *Checker) VisitBinaryExpression(expression *ast.BinaryExpression)
expectedType = optionalLeftType.Type
}
}
return checker.VisitExpressionWithForceType(expression.Right, expectedType, false)
return checker.VisitExpressionWithForceType(
expression.Right,
expression,
expectedType,
false,
)
})

rightIsInvalid := rightType.IsInvalidType()
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_casting_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (checker *Checker) VisitCastingExpression(expression *ast.CastingExpression

beforeErrors := len(checker.errors)

leftHandType, exprActualType := checker.visitExpression(leftHandExpression, expectedType)
leftHandType, exprActualType := checker.visitExpression(leftHandExpression, expression, expectedType)

hasErrors := len(checker.errors) > beforeErrors

Expand Down
12 changes: 10 additions & 2 deletions runtime/sema/check_composite_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2087,6 +2087,7 @@ func (checker *Checker) checkDefaultDestroyParamExpressionKind(

func (checker *Checker) checkDefaultDestroyEventParam(
param Parameter,
eventDeclaration ast.CompositeLikeDeclaration,
astParam *ast.Parameter,
containerType EntitlementSupportingType,
containerDeclaration ast.Declaration,
Expand All @@ -2113,7 +2114,8 @@ func (checker *Checker) checkDefaultDestroyEventParam(
compositeContainer,
compositeContainer.baseTypeDocString)
}
param.DefaultArgument = checker.VisitExpression(paramDefaultArgument, paramType)

param.DefaultArgument = checker.VisitExpression(paramDefaultArgument, eventDeclaration, paramType)

// default events must have default arguments for all their parameters; this is enforced in the parser
// we want to check that these arguments are all either literals or field accesses, and have primitive types
Expand Down Expand Up @@ -2143,7 +2145,13 @@ func (checker *Checker) checkDefaultDestroyEvent(
defer checker.leaveValueScope(eventDeclaration.EndPosition, true)

for index, param := range eventType.ConstructorParameters {
checker.checkDefaultDestroyEventParam(param, constructorFunctionParameters[index], containerType, containerDeclaration)
checker.checkDefaultDestroyEventParam(
param,
eventDeclaration,
constructorFunctionParameters[index],
containerType,
containerDeclaration,
)
}
}

Expand Down
8 changes: 4 additions & 4 deletions runtime/sema/check_conditional.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (checker *Checker) VisitIfStatement(statement *ast.IfStatement) (_ struct{}

switch test := statement.Test.(type) {
case ast.Expression:
checker.VisitExpression(test, BoolType)
checker.VisitExpression(test, statement, BoolType)

checker.checkConditionalBranches(
func() Type {
Expand Down Expand Up @@ -90,14 +90,14 @@ func (checker *Checker) VisitConditionalExpression(expression *ast.ConditionalEx

expectedType := checker.expectedType

checker.VisitExpression(expression.Test, BoolType)
checker.VisitExpression(expression.Test, expression, BoolType)

thenType, elseType := checker.checkConditionalBranches(
func() Type {
return checker.VisitExpression(expression.Then, expectedType)
return checker.VisitExpression(expression.Then, expression, expectedType)
},
func() Type {
return checker.VisitExpression(expression.Else, expectedType)
return checker.VisitExpression(expression.Else, expression, expectedType)
},
)

Expand Down
4 changes: 2 additions & 2 deletions runtime/sema/check_conditions.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ func (checker *Checker) checkCondition(condition ast.Condition) {
case *ast.TestCondition:

// check test expression is boolean
checker.VisitExpression(condition.Test, BoolType)
checker.VisitExpression(condition.Test, condition, BoolType)

// check message expression results in a string
if condition.Message != nil {
checker.VisitExpression(condition.Message, StringType)
checker.VisitExpression(condition.Message, condition, StringType)
}

case *ast.EmitCondition:
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_create_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (checker *Checker) VisitCreateExpression(expression *ast.CreateExpression)

invocation := expression.InvocationExpression

ty := checker.VisitExpression(invocation, nil)
ty := checker.VisitExpression(invocation, expression, nil)

if ty.IsInvalidType() {
return ty
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_destroy_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
func (checker *Checker) VisitDestroyExpression(expression *ast.DestroyExpression) (resultType Type) {
resultType = VoidType

valueType := checker.VisitExpression(expression.Expression, nil)
valueType := checker.VisitExpression(expression.Expression, expression, nil)

checker.ObserveImpureOperation(expression)
checker.recordResourceInvalidation(
Expand Down
4 changes: 2 additions & 2 deletions runtime/sema/check_dictionary_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ func (checker *Checker) VisitDictionaryExpression(expression *ast.DictionaryExpr
// NOTE: important to check move after each type check,
// not combined after both type checks!

entryKeyType := checker.VisitExpression(entry.Key, keyType)
entryKeyType := checker.VisitExpression(entry.Key, expression, keyType)
checker.checkVariableMove(entry.Key)
checker.checkResourceMoveOperation(entry.Key, entryKeyType)

entryValueType := checker.VisitExpression(entry.Value, valueType)
entryValueType := checker.VisitExpression(entry.Value, expression, valueType)
checker.checkVariableMove(entry.Value)
checker.checkResourceMoveOperation(entry.Value, entryValueType)

Expand Down
5 changes: 3 additions & 2 deletions runtime/sema/check_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (checker *Checker) checkResourceVariableCapturingInFunction(variable *Varia
func (checker *Checker) VisitExpressionStatement(statement *ast.ExpressionStatement) (_ struct{}) {
expression := statement.Expression

ty := checker.VisitExpression(expression, nil)
ty := checker.VisitExpression(expression, statement, nil)

if ty.IsResourceType() {
checker.report(
Expand Down Expand Up @@ -270,7 +270,7 @@ func (checker *Checker) visitIndexExpression(
) Type {

targetExpression := indexExpression.TargetExpression
targetType := checker.VisitExpression(targetExpression, nil)
targetType := checker.VisitExpression(targetExpression, indexExpression, nil)

// NOTE: check indexed type first for UX reasons

Expand Down Expand Up @@ -309,6 +309,7 @@ func (checker *Checker) visitIndexExpression(
}
indexingType := checker.VisitExpression(
indexExpression.IndexingExpression,
indexExpression,
valueIndexedType.IndexingType(),
)

Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_for.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct
}
}

valueType := checker.VisitExpression(valueExpression, expectedType)
valueType := checker.VisitExpression(valueExpression, statement, expectedType)

// Only get the element type if the array is not a resource array.
// Otherwise, in addition to the `UnsupportedResourceForLoopError`,
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_force_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (checker *Checker) VisitForceExpression(expression *ast.ForceExpression) Ty
// i.e: if `x!` is `String`, then `x` is expected to be `String?`.
expectedType := wrapWithOptionalIfNotNil(checker.expectedType)

valueType := checker.VisitExpression(expression.Expression, expectedType)
valueType := checker.VisitExpression(expression.Expression, expression, expectedType)

if valueType.IsInvalidType() {
return valueType
Expand Down
16 changes: 8 additions & 8 deletions runtime/sema/check_invocation_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (checker *Checker) checkInvocationExpression(invocationExpression *ast.Invo
// check the invoked expression can be invoked

invokedExpression := invocationExpression.InvokedExpression
expressionType := checker.VisitExpression(invokedExpression, nil)
expressionType := checker.VisitExpression(invokedExpression, invocationExpression, nil)

// `inInvocation` should be reset before visiting arguments
checker.inInvocation = false
Expand Down Expand Up @@ -131,7 +131,7 @@ func (checker *Checker) checkInvocationExpression(invocationExpression *ast.Invo
argumentTypes = make([]Type, 0, argumentCount)

for _, argument := range invocationExpression.Arguments {
argumentType := checker.VisitExpression(argument.Expression, nil)
argumentType := checker.VisitExpression(argument.Expression, invocationExpression, nil)
argumentTypes = append(argumentTypes, argumentType)
}

Expand Down Expand Up @@ -469,7 +469,7 @@ func (checker *Checker) checkInvocation(

parameterTypes[argumentIndex] =
checker.checkInvocationRequiredArgument(
invocationExpression.Arguments,
invocationExpression,
argumentIndex,
functionType,
argumentTypes,
Expand All @@ -482,7 +482,7 @@ func (checker *Checker) checkInvocation(
for i := minCount; i < argumentCount; i++ {
argument := invocationExpression.Arguments[i]
// TODO: pass the expected type to support type inferring for parameters
argumentTypes[i] = checker.VisitExpression(argument.Expression, nil)
argumentTypes[i] = checker.VisitExpression(argument.Expression, invocationExpression, nil)
}
}

Expand Down Expand Up @@ -571,15 +571,15 @@ func (checker *Checker) checkTypeParameterInference(
}

func (checker *Checker) checkInvocationRequiredArgument(
arguments ast.Arguments,
invocationExpression *ast.InvocationExpression,
argumentIndex int,
functionType *FunctionType,
argumentTypes []Type,
typeParameters *TypeParameterTypeOrderedMap,
) (
parameterType Type,
) {
argument := arguments[argumentIndex]
argument := invocationExpression.Arguments[argumentIndex]

parameter := functionType.Parameters[argumentIndex]
parameterType = parameter.TypeAnnotation.Type
Expand Down Expand Up @@ -637,7 +637,7 @@ func (checker *Checker) checkInvocationRequiredArgument(
expectedType = nil
}

argumentType = checker.VisitExpression(argument.Expression, expectedType)
argumentType = checker.VisitExpression(argument.Expression, invocationExpression, expectedType)

// If we did not pass an expected type,
// we must manually check that the argument type and the parameter type are compatible.
Expand All @@ -659,7 +659,7 @@ func (checker *Checker) checkInvocationRequiredArgument(
// We will then have to manually check that the argument type is compatible
// with the parameter type (see below).

argumentType = checker.VisitExpression(argument.Expression, nil)
argumentType = checker.VisitExpression(argument.Expression, invocationExpression, nil)

// Try to unify the parameter type with the argument type.
// If unification fails, fall back to the parameter type for now.
Expand Down
21 changes: 13 additions & 8 deletions runtime/sema/check_member_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression, isAssignme
// is an assignment, but the evaluation of the accessed exprssion itself (i.e. `a.b`)
// is not, so we temporarily clear the `inAssignment` status here before restoring it later.
accessedType = checker.withAssignment(false, func() Type {
return checker.VisitExpression(accessedExpression, nil)
return checker.VisitExpression(accessedExpression, expression, nil)
})
}()

Expand Down Expand Up @@ -345,16 +345,21 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression, isAssignme
//
// This would result in a bound method for a resource, which is invalid.

if !checker.inInvocation &&
turbolent marked this conversation as resolved.
Show resolved Hide resolved
member.DeclarationKind == common.DeclarationKindFunction &&
if member.DeclarationKind == common.DeclarationKindFunction &&
!accessedType.IsInvalidType() &&
accessedType.IsResourceType() {

checker.report(
&ResourceMethodBindingError{
Range: ast.NewRangeFromPositioned(checker.memoryGauge, expression),
},
)
parent := checker.parent
parentInvocationExpr, parentIsInvocation := parent.(*ast.InvocationExpression)

if !parentIsInvocation ||
expression != parentInvocationExpr.InvokedExpression {
checker.report(
&ResourceMethodBindingError{
Range: ast.NewRangeFromPositioned(checker.memoryGauge, expression),
},
)
}
}

// If the member,
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_reference_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere

beforeErrors := len(checker.errors)

referencedType, actualType := checker.visitExpression(referencedExpression, expectedLeftType)
referencedType, actualType := checker.visitExpression(referencedExpression, referenceExpression, expectedLeftType)

// check that the type of the referenced value is not itself a reference
var requireNoReferenceNesting func(actualType Type)
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_remove_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (checker *Checker) VisitRemoveStatement(statement *ast.RemoveStatement) (_
}

nominalType := checker.convertNominalType(statement.Attachment)
base := checker.VisitExpression(statement.Value, nil)
base := checker.VisitExpression(statement.Value, statement, nil)
checker.checkUnusedExpressionResourceLoss(base, statement.Value)

if nominalType == InvalidType {
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_return_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (checker *Checker) VisitReturnStatement(statement *ast.ReturnStatement) (_
// If the return statement has a return value,
// check that the value's type matches the enclosing function's return type

valueType := checker.VisitExpression(statement.Expression, returnType)
valueType := checker.VisitExpression(statement.Expression, statement, returnType)

checker.Elaboration.SetReturnStatementTypes(
statement,
Expand Down
4 changes: 2 additions & 2 deletions runtime/sema/check_swap.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func (checker *Checker) VisitSwapStatement(swap *ast.SwapStatement) (_ struct{})

// Then re-visit the same expressions, this time treat them as the value-expr of the assignment.
// The 'expected type' of the two expression would be the types obtained from the previous visit, swapped.
leftValueType := checker.VisitExpression(swap.Left, rightTargetType)
rightValueType := checker.VisitExpression(swap.Right, leftTargetType)
leftValueType := checker.VisitExpression(swap.Left, swap, rightTargetType)
rightValueType := checker.VisitExpression(swap.Right, swap, leftTargetType)

checker.Elaboration.SetSwapStatementTypes(
swap,
Expand Down
Loading
Loading