Skip to content

Commit

Permalink
feat: PythonCall annotation (#684)
Browse files Browse the repository at this point in the history
Closes #617 

### Summary of Changes

- added `PythonCall` annotation to builtins
- added code generation for direct function calls and member access
function calls
- added tests

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
3 people authored Oct 23, 2023
1 parent f23fa29 commit 15114df
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 44 deletions.
104 changes: 61 additions & 43 deletions src/cli/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
isSdsEnumVariant,
isSdsExpressionLambda,
isSdsExpressionStatement,
isSdsFunction,
isSdsIndexedAccess,
isSdsInfixOperation,
isSdsList,
Expand Down Expand Up @@ -59,12 +60,12 @@ import { NodeFileSystem } from 'langium/node';
import {
getAbstractResults,
getAssignees,
streamBlockLambdaResults,
getImportedDeclarations,
getImports,
isRequiredParameter,
getModuleMembers,
getStatements,
isRequiredParameter,
streamBlockLambdaResults,
} from '../language/helpers/nodeProperties.js';
import { IdManager } from '../language/helpers/idManager.js';
import { isInStubFile } from '../language/helpers/fileExtensions.js';
Expand Down Expand Up @@ -209,7 +210,7 @@ const generateParameter = function (
frame: GenerationInfoFrame,
defaultValue: boolean = true,
): string {
return expandToString`${getPythonNameOrDefault(frame.getServices(), parameter)}${
return expandToString`${getPythonNameOrDefault(frame.services, parameter)}${
defaultValue && parameter.defaultValue !== undefined
? '=' + generateExpression(parameter.defaultValue, frame)
: ''
Expand Down Expand Up @@ -291,7 +292,7 @@ const generateStatement = function (statement: SdsStatement, frame: GenerationIn

const generateAssignment = function (assignment: SdsAssignment, frame: GenerationInfoFrame): string {
const requiredAssignees = isSdsCall(assignment.expression)
? getAbstractResults(frame.getServices().helpers.NodeMapper.callToCallable(assignment.expression)).length
? getAbstractResults(frame.services.helpers.NodeMapper.callToCallable(assignment.expression)).length
: /* c8 ignore next */
1;
const assignees = getAssignees(assignment);
Expand Down Expand Up @@ -347,7 +348,7 @@ const generateExpression = function (expression: SdsExpression, frame: Generatio
}
}

const partiallyEvaluatedNode = frame.getServices().evaluation.PartialEvaluator.evaluate(expression);
const partiallyEvaluatedNode = frame.services.evaluation.PartialEvaluator.evaluate(expression);
if (partiallyEvaluatedNode instanceof BooleanConstant) {
return partiallyEvaluatedNode.value ? 'True' : 'False';
} else if (partiallyEvaluatedNode instanceof IntConstant) {
Expand All @@ -360,39 +361,44 @@ const generateExpression = function (expression: SdsExpression, frame: Generatio
} else if (partiallyEvaluatedNode instanceof StringConstant) {
return `'${formatStringSingleLine(partiallyEvaluatedNode.value)}'`;
}
// Handled after constant expressions: EnumVariant, List, Map

if (isSdsTemplateString(expression)) {
// Handled after constant expressions: EnumVariant, List, Map
else if (isSdsTemplateString(expression)) {
return `f'${expression.expressions.map((expr) => generateExpression(expr, frame)).join('')}'`;
}

if (isSdsMap(expression)) {
} else if (isSdsMap(expression)) {
const mapContent = expression.entries.map(
(entry) => `${generateExpression(entry.key, frame)}: ${generateExpression(entry.value, frame)}`,
);
return `{${mapContent.join(', ')}}`;
}
if (isSdsList(expression)) {
} else if (isSdsList(expression)) {
const listContent = expression.elements.map((value) => generateExpression(value, frame));
return `[${listContent.join(', ')}]`;
}

if (isSdsBlockLambda(expression)) {
} else if (isSdsBlockLambda(expression)) {
return frame.getUniqueLambdaBlockName(expression);
}
if (isSdsCall(expression)) {
const sortedArgs = sortArguments(frame.getServices(), expression.argumentList.arguments);
} else if (isSdsCall(expression)) {
const callable = frame.services.helpers.NodeMapper.callToCallable(expression);
if (isSdsFunction(callable)) {
const pythonCall = frame.services.builtins.Annotations.getPythonCall(callable);
if (pythonCall) {
let thisParam: string | undefined = undefined;
if (isSdsMemberAccess(expression.receiver)) {
thisParam = generateExpression(expression.receiver.receiver, frame);
}
const argumentsMap = getArgumentsMap(expression.argumentList.arguments, frame);
return generatePythonCall(pythonCall, argumentsMap, thisParam);
}
}

const sortedArgs = sortArguments(frame.services, expression.argumentList.arguments);
return expandToString`${generateExpression(expression.receiver, frame)}(${sortedArgs
.map((arg) => generateArgument(arg, frame))
.join(', ')})`;
}
if (isSdsExpressionLambda(expression)) {
} else if (isSdsExpressionLambda(expression)) {
return `lambda ${generateParameters(expression.parameterList, frame)}: ${generateExpression(
expression.result,
frame,
)}`;
}
if (isSdsInfixOperation(expression)) {
} else if (isSdsInfixOperation(expression)) {
const leftOperand = generateExpression(expression.leftOperand, frame);
const rightOperand = generateExpression(expression.rightOperand, frame);
switch (expression.operator) {
Expand All @@ -412,14 +418,12 @@ const generateExpression = function (expression: SdsExpression, frame: Generatio
default:
return `(${leftOperand}) ${expression.operator} (${rightOperand})`;
}
}
if (isSdsIndexedAccess(expression)) {
} else if (isSdsIndexedAccess(expression)) {
return expandToString`${generateExpression(expression.receiver, frame)}[${generateExpression(
expression.index,
frame,
)}]`;
}
if (isSdsMemberAccess(expression)) {
} else if (isSdsMemberAccess(expression)) {
const member = expression.member?.target.ref!;
const receiver = generateExpression(expression.receiver, frame);
if (isSdsEnumVariant(member)) {
Expand All @@ -442,31 +446,49 @@ const generateExpression = function (expression: SdsExpression, frame: Generatio
return `${receiver}.${memberExpression}`;
}
}
}
if (isSdsParenthesizedExpression(expression)) {
} else if (isSdsParenthesizedExpression(expression)) {
return expandToString`${generateExpression(expression.expression, frame)}`;
}
if (isSdsPrefixOperation(expression)) {
} else if (isSdsPrefixOperation(expression)) {
const operand = generateExpression(expression.operand, frame);
switch (expression.operator) {
case 'not':
return expandToString`not (${operand})`;
case '-':
return expandToString`-(${operand})`;
}
}
if (isSdsReference(expression)) {
} else if (isSdsReference(expression)) {
const declaration = expression.target.ref!;
const referenceImport =
getExternalReferenceNeededImport(frame.getServices(), expression, declaration) ||
getInternalReferenceNeededImport(frame.getServices(), expression, declaration);
getExternalReferenceNeededImport(frame.services, expression, declaration) ||
getInternalReferenceNeededImport(frame.services, expression, declaration);
frame.addImport(referenceImport);
return referenceImport?.alias || getPythonNameOrDefault(frame.getServices(), declaration);
return referenceImport?.alias || getPythonNameOrDefault(frame.services, declaration);
}
/* c8 ignore next 2 */
throw new Error(`Unknown expression type: ${expression.$type}`);
};

const generatePythonCall = function (
pythonCall: string,
argumentsMap: Map<string, string>,
thisParam: string | undefined = undefined,
): string {
if (thisParam) {
argumentsMap.set('this', thisParam);
}

return pythonCall.replace(/\$[_a-zA-Z][_a-zA-Z0-9]*/gu, (value) => argumentsMap.get(value.substring(1))!);
};

const getArgumentsMap = function (argumentList: SdsArgument[], frame: GenerationInfoFrame): Map<string, string> {
const argumentsMap = new Map<string, string>();
argumentList.reduce((map, value) => {
map.set(frame.services.helpers.NodeMapper.argumentToParameter(value)?.name!, generateArgument(value, frame));
return map;
}, argumentsMap);
return argumentsMap;
};

const sortArguments = function (services: SafeDsServices, argumentList: SdsArgument[]): SdsArgument[] {
// $containerIndex contains the index of the parameter in the receivers parameter list
const parameters = argumentList.map((argument) => {
Expand All @@ -482,7 +504,7 @@ const sortArguments = function (services: SafeDsServices, argumentList: SdsArgum
};

const generateArgument = function (argument: SdsArgument, frame: GenerationInfoFrame) {
const parameter = frame.getServices().helpers.NodeMapper.argumentToParameter(argument);
const parameter = frame.services.helpers.NodeMapper.argumentToParameter(argument);
return expandToString`${
parameter !== undefined && !isRequiredParameter(parameter)
? generateParameter(parameter, frame, false) + '='
Expand Down Expand Up @@ -567,9 +589,9 @@ interface ImportData {
}

class GenerationInfoFrame {
services: SafeDsServices;
blockLambdaManager: IdManager<SdsBlockLambda>;
importSet: Map<String, ImportData>;
readonly services: SafeDsServices;
private readonly blockLambdaManager: IdManager<SdsBlockLambda>;
private readonly importSet: Map<String, ImportData>;

constructor(services: SafeDsServices, importSet: Map<String, ImportData> = new Map<String, ImportData>()) {
this.services = services;
Expand All @@ -589,10 +611,6 @@ class GenerationInfoFrame {
getUniqueLambdaBlockName(lambda: SdsBlockLambda): string {
return `${BLOCK_LAMBDA_PREFIX}${this.blockLambdaManager.assignId(lambda)}`;
}

getServices(): SafeDsServices {
return this.services;
}
}

export interface GenerateOptions {
Expand Down
14 changes: 14 additions & 0 deletions src/language/builtins/safe-ds-annotations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
SdsAnnotatedObject,
SdsAnnotation,
SdsEnumVariant,
SdsFunction,
SdsModule,
SdsParameter,
} from '../generated/ast.js';
Expand Down Expand Up @@ -65,6 +66,19 @@ export class SafeDsAnnotations extends SafeDsModuleMembers<SdsAnnotation> {
return this.getAnnotation(IDE_INTEGRATION_URI, 'Expert');
}

getPythonCall(node: SdsFunction | undefined): string | undefined {
const value = this.getArgumentValue(node, this.PythonCall, 'callSpecification');
if (value instanceof StringConstant) {
return value.value;
} else {
return undefined;
}
}

get PythonCall(): SdsAnnotation | undefined {
return this.getAnnotation(CODE_GENERATION_URI, 'PythonCall');
}

getPythonModule(node: SdsModule | undefined): string | undefined {
const value = this.getArgumentValue(node, this.PythonModule, 'qualifiedName');
if (value instanceof StringConstant) {
Expand Down
2 changes: 1 addition & 1 deletion src/language/validation/names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ export const moduleMemberMustHaveNameThatIsUniqueInPackage = (services: SafeDsSe
let declarationsInPackage: AstNodeDescription[];
let kind: string;
if (packageName.startsWith(BUILTINS_ROOT_PACKAGE)) {
// For a builtin package the simple names of declarations must be unique
// For a builtin package, the simple names of declarations must be unique
declarationsInPackage = packageManager.getDeclarationsInPackageOrSubpackage(BUILTINS_ROOT_PACKAGE);
kind = 'builtin declarations';
} else {
Expand Down
14 changes: 14 additions & 0 deletions src/resources/builtins/safeds/lang/codeGeneration.sdsstub
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
package safeds.lang

/**
* The specification of a corresponding function call in Python. By default, the function is called as specified in the
* stub.
*
* @param callSpecification
* The specification of corresponding Python call. The specification can contain template expression, which are
* replaced by the corresponding arguments of the function call. `$this` is replaced by the receiver of the call.
* `$param` is replaced by the value of the parameter called `param`. Otherwise, the string is used as-is.
*/
@Target([AnnotationTarget.Function])
annotation PythonCall(
callSpecification: String
)

/**
* The qualified name of the corresponding Python module. By default, this is the qualified name of the package.
*/
Expand Down
12 changes: 12 additions & 0 deletions tests/resources/generation/expressions/call/input.sdstest
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,21 @@ fun h(
@PythonName("param_2") param2: Int = 0
) -> result: Boolean

@PythonCall("$param.i()")
fun i(param: Any?)

@PythonCall("$param.j($param2)")
fun j(param: Any?, param2: Any?)

@PythonCall("k($param2, $param)")
fun k(param: Any?, param2: Any?)

pipeline test {
f((g(1, 2)));
f((g(param2 = 1, param1 = 2)));
f((h(1, 2)));
f((h(param2 = 1, param1 = 2)));
i("abc");
j("abc", 123);
k(1.23, 456);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ def test():
f(g(2, param2=1))
f(h(1, param_2=2))
f(h(2, param_2=1))
'abc'.i()
'abc'.j(123)
k(456, 1.23)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ fun h() -> (result1: Boolean, result2: Boolean)
class C() {
attr a: Int
@PythonName("c") attr b: Int

@PythonCall("$param.i($this)") fun i(param: Any?)
}

fun factory() -> instance: C?
Expand All @@ -21,4 +23,5 @@ pipeline test {
f(C().b);
f(factory()?.a);
f(factory()?.b);
f(C().i(1));
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ def test():
f(C().c)
f(safeds_runner.codegen.safe_access(factory(), 'a'))
f(safeds_runner.codegen.safe_access(factory(), 'c'))
f(1.i(C()))

0 comments on commit 15114df

Please sign in to comment.