Skip to content

Commit

Permalink
Add support for existential type (any) to AutoMockable.stencil (#1169)
Browse files Browse the repository at this point in the history
* Add support for existential type (any)  to AutoMockable.stencil

* Add tests

---------

Co-authored-by: Paul Bancarel <paul.bancarel.ext@adevinta.com>
  • Loading branch information
paul1893 and paulbancarelextadevinta authored Jul 27, 2023
1 parent 83c323f commit 6db4d41
Show file tree
Hide file tree
Showing 4 changed files with 353 additions and 21 deletions.
34 changes: 19 additions & 15 deletions Templates/Templates/AutoMockable.stencil
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ import {{ import }}
{% macro closureReturnTypeName method %}{% if method.isOptionalReturnType %}{{ method.unwrappedReturnTypeName }}?{% else %}{{ method.returnTypeName }}{% endif %}{% endmacro %}

{% macro methodClosureDeclaration method %}
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call methodClosureName method %}: (({% for param in method.parameters %}{{ param.typeName }}{% if not forloop.last %}, {% endif %}{% endfor %}) {% if method.isAsync %}async {% endif %}{% if method.throws %}throws {% endif %}-> {% if method.isInitializer %}Void{% else %}{% call closureReturnTypeName method %}{% endif %})?
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call methodClosureName method %}: (({% for param in method.parameters %}{% call existentialClosureVariableTypeName param.typeName %}{% if not forloop.last %}, {% endif %}{% endfor %}) {% if method.isAsync %}async {% endif %}{% if method.throws %}throws {% endif %}-> {% if method.isInitializer %}Void{% else %}{% call closureReturnTypeName method %}{% endif %})?
{% endmacro %}

{% macro methodClosureCallParameters method %}{% for param in method.parameters %}{{ param.name }}{% if not forloop.last %}, {% endif %}{% endfor %}{% endmacro %}
Expand All @@ -77,11 +77,11 @@ import {{ import }}
{% endfor -%}
{% endset %}
{% if method.parameters.count == 1 and not hasNonEscapingClosures %}
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}Received{% for param in method.parameters %}{{ param.name|upperFirstLetter }}: {{ '(' if param.isClosure }}{{ param.typeName.unwrappedTypeName }}{{ ')' if param.isClosure }}?{% endfor %}
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}ReceivedInvocations{% for param in method.parameters %}: [{{ '(' if param.isClosure }}{{ param.typeName.unwrappedTypeName }}{{ ')' if param.isClosure }}{%if param.typeName.isOptional%}?{%endif%}]{% endfor %} = []
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}Received{% for param in method.parameters %}{{ param.name|upperFirstLetter }}: {{ '(' if param.isClosure }}({% call existentialClosureVariableTypeName param.typeName.unwrappedTypeName %}{{ ')' if param.isClosure }})?{% endfor %}
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}ReceivedInvocations{% for param in method.parameters %}: [{{ '(' if param.isClosure }}({% call existentialClosureVariableTypeName param.typeName.unwrappedTypeName %}){{ ')' if param.isClosure }}{%if param.typeName.isOptional%}?{%endif%}]{% endfor %} = []
{% elif not method.parameters.count == 0 and not hasNonEscapingClosures %}
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}ReceivedArguments: ({% for param in method.parameters %}{{ param.name }}: {{ param.unwrappedTypeName if param.typeAttributes.escaping else param.typeName }}{{ ', ' if not forloop.last }}{% endfor %})?
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}ReceivedInvocations: [({% for param in method.parameters %}{{ param.name }}: {{ param.unwrappedTypeName if param.typeAttributes.escaping else param.typeName }}{{ ', ' if not forloop.last }}{% endfor %})] = []
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}ReceivedArguments: ({% for param in method.parameters %}{{ param.name }}: {% if param.typeAttributes.escaping %}{% call existentialClosureVariableTypeName param.typeName.unwrappedTypeName %}{% else %}{% call existentialClosureVariableTypeName param.typeName %}{% endif %}{{ ', ' if not forloop.last }}{% endfor %})?
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}ReceivedInvocations: [({% for param in method.parameters %}{{ param.name }}: {% if param.typeAttributes.escaping %}{% call existentialClosureVariableTypeName param.typeName.unwrappedTypeName %}{% else %}{% call existentialClosureVariableTypeName param.typeName %}{% endif %}{{ ', ' if not forloop.last }}{% endfor %})] = []
{% endif %}
{% if not method.returnTypeName.isVoid and not method.isInitializer %}
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}var {% call swiftifyMethodName method.selectorName %}ReturnValue: {{ '(' if method.returnTypeName.isClosure and not method.isOptionalReturnType }}{{ method.returnTypeName }}{{ ')' if method.returnTypeName.isClosure and not method.isOptionalReturnType }}{{ '!' if not method.isOptionalReturnType }}
Expand All @@ -99,7 +99,7 @@ import {{ import }}
{{ value }}
{% endfor %}
{% endfor %}
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}func {{ method.name }}{{ ' async' if method.isAsync }}{{ ' throws' if method.throws }}{% if not method.returnTypeName.isVoid %} -> {{ method.returnTypeName }}{% endif %} {
{% call accessLevel method.accessLevel %}{% call staticSpecifier method %}{% call methodName method %}{{ ' async' if method.isAsync }}{{ ' throws' if method.throws }}{% if not method.returnTypeName.isVoid %} -> {{ method.returnTypeName }}{% endif %} {
{% if method.throws %}
{% call methodThrowableErrorUsage method %}
{% endif %}
Expand Down Expand Up @@ -136,26 +136,26 @@ import {{ import }}
{% if method.throws %}
{% call swiftifyMethodName method.selectorName %}ThrowableError = nil
{% endif %}

{% endif %}

{% endmacro %}

{% macro mockOptionalVariable variable %}
{% call accessLevel variable.readAccess %}var {% call mockedVariableName variable %}: {{ variable.typeName }}
{% call accessLevel variable.readAccess %}var {% call mockedVariableName variable %}: {% call existentialVariableTypeName variable.typeName %}
{% endmacro %}

{% macro mockNonOptionalArrayOrDictionaryVariable variable %}
{% call accessLevel variable.readAccess %}var {% call mockedVariableName variable %}: {{ variable.typeName }} = {% if variable.isArray %}[]{% elif variable.isDictionary %}[:]{% endif %}
{% call accessLevel variable.readAccess %}var {% call mockedVariableName variable %}: {% call existentialVariableTypeName variable.typeName %} = {% if variable.isArray %}[]{% elif variable.isDictionary %}[:]{% endif %}
{% endmacro %}

{% macro mockNonOptionalVariable variable %}
{% call accessLevel variable.readAccess %}var {% call mockedVariableName variable %}: {{ variable.typeName }} {
{% call accessLevel variable.readAccess %}var {% call mockedVariableName variable %}: {% call existentialVariableTypeName variable.typeName %} {
get { return {% call underlyingMockedVariableName variable %} }
set(value) { {% call underlyingMockedVariableName variable %} = value }
}
{% set wrappedTypeName %}{% if variable.typeName.isProtocolComposition %}({{ variable.typeName }}){% else %}{{ variable.typeName }}{% endif %}{% endset %}
{% call accessLevel variable.readAccess %}var {% call underlyingMockedVariableName variable %}: {{ wrappedTypeName }}!
{% set wrappedTypeName %}{% if variable.typeName.isProtocolComposition %}({% call existentialVariableTypeName variable.typeName %}){% else %}{% call existentialVariableTypeName variable.typeName %}{% endif %}{% endset %}
{% call accessLevel variable.readAccess %}var {% call underlyingMockedVariableName variable %}: ({% call existentialVariableTypeName wrappedTypeName %})!
{% endmacro %}

{% macro variableThrowableErrorDeclaration variable %}
Expand All @@ -169,7 +169,7 @@ import {{ import }}
{% endmacro %}

{% macro variableClosureDeclaration variable %}
{% call accessLevel variable.readAccess %}var {% call variableClosureName variable %}: (() {% if variable.isAsync %}async {% endif %}{% if variable.throws %}throws {% endif %}-> {{ variable.typeName }})?
{% call accessLevel variable.readAccess %}var {% call variableClosureName variable %}: (() {% if variable.isAsync %}async {% endif %}{% if variable.throws %}throws {% endif %}-> {% call existentialVariableTypeName variable.typeName %})?
{% endmacro %}

{% macro variableClosureName variable %}{% call mockedVariableName variable %}Closure{% endmacro %}
Expand All @@ -180,7 +180,7 @@ import {{ import }}
return {% call mockedVariableName variable %}CallsCount > 0
}

{% call accessLevel variable.readAccess %}var {% call mockedVariableName variable %}: {{ variable.typeName }} {
{% call accessLevel variable.readAccess %}var {% call mockedVariableName variable %}: {% call existentialVariableTypeName variable.typeName %} {
get {% if variable.isAsync %}async {% endif %}{% if variable.throws %}throws {% endif %}{
{% if variable.throws %}
{% call variableThrowableErrorUsage variable %}
Expand All @@ -193,7 +193,7 @@ import {{ import }}
}
}
}
{% call accessLevel variable.readAccess %}var {% call underlyingMockedVariableName variable %}: {{ variable.typeName }}{{ '!' if not variable.isOptional }}
{% call accessLevel variable.readAccess %}var {% call underlyingMockedVariableName variable %}: {% call existentialVariableTypeName variable.typeName %}{{ '!' if not variable.isOptional }}
{% if variable.throws %}
{% call variableThrowableErrorDeclaration variable %}
{% endif %}
Expand All @@ -202,6 +202,10 @@ import {{ import }}

{% macro underlyingMockedVariableName variable %}underlying{{ variable.name|upperFirstLetter }}{% endmacro %}
{% macro mockedVariableName variable %}{{ variable.name }}{% endmacro %}
{% macro existentialVariableTypeName typeName %}{% if typeName|contains:"any" and typeName|contains:"!" %}{{ typeName | replace:"any","(any" | replace:"!",")!" }}{% elif typeName|contains:"any" and typeName.isOptional %}{{ typeName | replace:"any","(any" | replace:"?",")?" }}{% elif typeName|contains:"any" and typeName.isClosure %}({{ typeName | replace:"any","(any" | replace:"?",")?" }}){%else%}{{ typeName }}{%endif%}{% endmacro %}
{% macro existentialClosureVariableTypeName typeName %}{% if typeName|contains:"any" and typeName|contains:"!" %}{{ typeName | replace:"any","(any" | replace:"!",")?" }}{% elif typeName|contains:"any" and typeName.isClosure and typeName|contains:"?" %}{{ typeName | replace:"any","(any" | replace:"?",")?" }}{% elif typeName|contains:"any" and typeName|contains:"?" %}{{ typeName | replace:"any","(any" | replace:"?",")?" }}{%else%}{{ typeName }}{%endif%}{% endmacro %}
{% macro existentialParameterTypeName typeName %}{% if typeName|contains:"any" and typeName|contains:"!" %}{{ typeName | replace:"any","(any" | replace:"!",")!" }}{% elif typeName|contains:"any" and typeName.isClosure and typeName|contains:"?" %}{{ typeName | replace:"any","(any" | replace:"?",")?" }}{% elif typeName|contains:"any" and typeName.isOptional %}{{ typeName | replace:"any","(any" | replace:"?",")?" }}{%else%}{{ typeName }}{%endif%}{% endmacro %}
{% macro methodName method %}func {{ method.shortName}}({%- for param in method.parameters %}{% if param.argumentLabel == nil %}_ {{ param.name }}{%elif param.argumentLabel == param.name%}{{ param.name }}{%else%}{{ param.argumentLabel }} {{ param.name }}{% endif %}: {% call existentialParameterTypeName param.typeName %}{% if not forloop.last %}, {% endif %}{% endfor -%}){% endmacro %}

{% for type in types.protocols where type.based.AutoMockable or type|annotated:"AutoMockable" %}{% if type.name != "AutoMockable" %}
{% call accessLevel type.accessLevel %}class {{ type.name }}Mock: {{ type.name }} {
Expand Down
31 changes: 27 additions & 4 deletions Templates/Tests/Context/AutoMockable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,36 @@ protocol FunctionWithAttributes: AutoMockable {
func callRepeatedAttributes() -> Bool
}

public protocol AccessLevelProtocol: AutoMockable {
protocol AccessLevelProtocol: AutoMockable {
var company: String? { get set }
var name: String { get }

func loadConfiguration() -> String?
}

protocol StaticMethodProtocol:AutoMockable {
static func staticFunction(String) -> String
}
protocol StaticMethodProtocol: AutoMockable {
static func staticFunction(_: String) -> String
}

protocol StubProtocol {}
protocol StubWithAnyNameProtocol {}

protocol AnyProtocol: AutoMockable {
var a: any StubProtocol { get }
var b: (any StubProtocol)? { get }
var c: (any StubProtocol)! { get }
var d: (((any StubProtocol)?) -> Void) { get }
var e: [(any StubProtocol)?] { get }
func f(_ x: (any StubProtocol)?, y: (any StubProtocol)!, z: any StubProtocol)
var g: any StubProtocol { get }
var h: (any StubProtocol)? { get }
var i: (any StubProtocol)! { get }
func j(x: (any StubProtocol)?, y: (any StubProtocol)!, z: any StubProtocol) async -> String
func k(x: ((any StubProtocol)?) -> Void, y: (any StubProtocol) -> Void)
func l(x: (((any StubProtocol)?) -> Void), y: ((any StubProtocol) -> Void))
var anyConfusingPropertyName: any StubProtocol { get }
func m(anyConfusingArgumentName: any StubProtocol)
func n(x: @escaping ((any StubProtocol)?) -> Void)
var o: any StubWithAnyNameProtocol { get }
func p(_ x: (any StubWithAnyNameProtocol)?)
}
156 changes: 154 additions & 2 deletions Templates/Tests/Expected/AutoMockable.expected
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,7 @@ class StaticMethodProtocolMock: StaticMethodProtocol {

static func staticFunction(_: String) -> String {
staticFunctionCallsCount += 1
staticFunctionReceived =
staticFunctionReceivedInvocations.append()
staticFunctionReceived = staticFunctionReceivedInvocations.append()
if let staticFunctionClosure = staticFunctionClosure {
return staticFunctionClosure()
} else {
Expand Down Expand Up @@ -792,3 +791,156 @@ class VariablesProtocolMock: VariablesProtocol {
var universityMarks: [String: Int] = [:]

}

class AnyProtocolMock: AnyProtocol {


var a: any StubProtocol {
get { return underlyingA }
set(value) { underlyingA = value }
}
var underlyingA: (any StubProtocol)!
var b: (any StubProtocol)?
var c: (any StubProtocol)!
var d: (((any StubProtocol)?) -> Void) {
get { return underlyingD }
set(value) { underlyingD = value }
}
var underlyingD: ((((any StubProtocol)?) -> Void))!
var e: [(any StubProtocol)?] = []
var g: any StubProtocol {
get { return underlyingG }
set(value) { underlyingG = value }
}
var underlyingG: (any StubProtocol)!
var h: (any StubProtocol)?
var i: (any StubProtocol)!
var anyConfusingPropertyName: any StubProtocol {
get { return underlyingAnyConfusingPropertyName }
set(value) { underlyingAnyConfusingPropertyName = value }
}
var underlyingAnyConfusingPropertyName: (any StubProtocol)!
var o: any StubWithAnyNameProtocol {
get { return underlyingO }
set(value) { underlyingO = value }
}
var underlyingO: (any StubWithAnyNameProtocol)!


//MARK: - f

var fyzCallsCount = 0
var fyzCalled: Bool {
return fyzCallsCount > 0
}
var fyzReceivedArguments: (x: (any StubProtocol)?, y: (any StubProtocol)?, z: any StubProtocol)?
var fyzReceivedInvocations: [(x: (any StubProtocol)?, y: (any StubProtocol)?, z: any StubProtocol)] = []
var fyzClosure: (((any StubProtocol)?, (any StubProtocol)?, any StubProtocol) -> Void)?

func f(_ x: (any StubProtocol)?, y: (any StubProtocol)!, z: any StubProtocol) {
fyzCallsCount += 1
fyzReceivedArguments = (x: x, y: y, z: z)
fyzReceivedInvocations.append((x: x, y: y, z: z))
fyzClosure?(x, y, z)
}

//MARK: - j

var jxyzCallsCount = 0
var jxyzCalled: Bool {
return jxyzCallsCount > 0
}
var jxyzReceivedArguments: (x: (any StubProtocol)?, y: (any StubProtocol)?, z: any StubProtocol)?
var jxyzReceivedInvocations: [(x: (any StubProtocol)?, y: (any StubProtocol)?, z: any StubProtocol)] = []
var jxyzReturnValue: String!
var jxyzClosure: (((any StubProtocol)?, (any StubProtocol)?, any StubProtocol) async -> String)?

func j(x: (any StubProtocol)?, y: (any StubProtocol)!, z: any StubProtocol) async -> String {
jxyzCallsCount += 1
jxyzReceivedArguments = (x: x, y: y, z: z)
jxyzReceivedInvocations.append((x: x, y: y, z: z))
if let jxyzClosure = jxyzClosure {
return await jxyzClosure(x, y, z)
} else {
return jxyzReturnValue
}
}

//MARK: - k

var kxyCallsCount = 0
var kxyCalled: Bool {
return kxyCallsCount > 0
}
var kxyClosure: ((((any StubProtocol)?) -> Void, (any StubProtocol) -> Void) -> Void)?

func k(x: ((any StubProtocol)?) -> Void, y: (any StubProtocol) -> Void) {
kxyCallsCount += 1
kxyClosure?(x, y)
}

//MARK: - l

var lxyCallsCount = 0
var lxyCalled: Bool {
return lxyCallsCount > 0
}
var lxyClosure: ((((any StubProtocol)?) -> Void, (any StubProtocol) -> Void) -> Void)?

func l(x: ((any StubProtocol)?) -> Void, y: (any StubProtocol) -> Void) {
lxyCallsCount += 1
lxyClosure?(x, y)
}

//MARK: - m

var mAnyConfusingArgumentNameCallsCount = 0
var mAnyConfusingArgumentNameCalled: Bool {
return mAnyConfusingArgumentNameCallsCount > 0
}
var mAnyConfusingArgumentNameReceivedAnyConfusingArgumentName: (any StubProtocol)?
var mAnyConfusingArgumentNameReceivedInvocations: [(any StubProtocol)] = []
var mAnyConfusingArgumentNameClosure: ((any StubProtocol) -> Void)?

func m(anyConfusingArgumentName: any StubProtocol) {
mAnyConfusingArgumentNameCallsCount += 1
mAnyConfusingArgumentNameReceivedAnyConfusingArgumentName = anyConfusingArgumentName
mAnyConfusingArgumentNameReceivedInvocations.append(anyConfusingArgumentName)
mAnyConfusingArgumentNameClosure?(anyConfusingArgumentName)
}

//MARK: - n

var nxCallsCount = 0
var nxCalled: Bool {
return nxCallsCount > 0
}
var nxReceivedX: ((((any StubProtocol)?) -> Void))?
var nxReceivedInvocations: [((((any StubProtocol)?) -> Void))] = []
var nxClosure: ((@escaping ((any StubProtocol)?) -> Void) -> Void)?

func n(x: @escaping ((any StubProtocol)?) -> Void) {
nxCallsCount += 1
nxReceivedX = x
nxReceivedInvocations.append(x)
nxClosure?(x)
}

//MARK: - p

var pCallsCount = 0
var pCalled: Bool {
return pCallsCount > 0
}
var pReceivedX: (any StubWithAnyNameProtocol)?
var pReceivedInvocations: [(any StubWithAnyNameProtocol)?] = []
var pClosure: (((any StubWithAnyNameProtocol)?) -> Void)?

func p(_ x: (any StubWithAnyNameProtocol)?) {
pCallsCount += 1
pReceivedX = x
pReceivedInvocations.append(x)
pClosure?(x)
}

}
Loading

0 comments on commit 6db4d41

Please sign in to comment.