Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: PythonCall annotation #684

Merged
merged 9 commits into from
Oct 23, 2023
104 changes: 61 additions & 43 deletions src/cli/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
isSdsEnumVariant,
isSdsExpressionLambda,
isSdsExpressionStatement,
isSdsFunction,
isSdsIndexedAccess,
isSdsInfixOperation,
isSdsList,
Expand Down Expand Up @@ -59,12 +60,12 @@ import { NodeFileSystem } from 'langium/node';
import {
getAbstractResults,
getAssignees,
streamBlockLambdaResults,
getImportedDeclarations,
getImports,
isRequiredParameter,
getModuleMembers,
getStatements,
isRequiredParameter,
streamBlockLambdaResults,
} from '../language/helpers/nodeProperties.js';
import { IdManager } from '../language/helpers/idManager.js';
import { isInStubFile } from '../language/helpers/fileExtensions.js';
Expand Down Expand Up @@ -209,7 +210,7 @@ const generateParameter = function (
frame: GenerationInfoFrame,
defaultValue: boolean = true,
): string {
return expandToString`${getPythonNameOrDefault(frame.getServices(), parameter)}${
return expandToString`${getPythonNameOrDefault(frame.services, parameter)}${
defaultValue && parameter.defaultValue !== undefined
? '=' + generateExpression(parameter.defaultValue, frame)
: ''
Expand Down Expand Up @@ -291,7 +292,7 @@ const generateStatement = function (statement: SdsStatement, frame: GenerationIn

const generateAssignment = function (assignment: SdsAssignment, frame: GenerationInfoFrame): string {
const requiredAssignees = isSdsCall(assignment.expression)
? getAbstractResults(frame.getServices().helpers.NodeMapper.callToCallable(assignment.expression)).length
? getAbstractResults(frame.services.helpers.NodeMapper.callToCallable(assignment.expression)).length
: /* c8 ignore next */
1;
const assignees = getAssignees(assignment);
Expand Down Expand Up @@ -347,7 +348,7 @@ const generateExpression = function (expression: SdsExpression, frame: Generatio
}
}

const partiallyEvaluatedNode = frame.getServices().evaluation.PartialEvaluator.evaluate(expression);
const partiallyEvaluatedNode = frame.services.evaluation.PartialEvaluator.evaluate(expression);
if (partiallyEvaluatedNode instanceof BooleanConstant) {
return partiallyEvaluatedNode.value ? 'True' : 'False';
} else if (partiallyEvaluatedNode instanceof IntConstant) {
Expand All @@ -360,39 +361,44 @@ const generateExpression = function (expression: SdsExpression, frame: Generatio
} else if (partiallyEvaluatedNode instanceof StringConstant) {
return `'${formatStringSingleLine(partiallyEvaluatedNode.value)}'`;
}
// Handled after constant expressions: EnumVariant, List, Map

if (isSdsTemplateString(expression)) {
// Handled after constant expressions: EnumVariant, List, Map
else if (isSdsTemplateString(expression)) {
return `f'${expression.expressions.map((expr) => generateExpression(expr, frame)).join('')}'`;
}

if (isSdsMap(expression)) {
} else if (isSdsMap(expression)) {
const mapContent = expression.entries.map(
(entry) => `${generateExpression(entry.key, frame)}: ${generateExpression(entry.value, frame)}`,
);
return `{${mapContent.join(', ')}}`;
}
if (isSdsList(expression)) {
} else if (isSdsList(expression)) {
const listContent = expression.elements.map((value) => generateExpression(value, frame));
return `[${listContent.join(', ')}]`;
}

if (isSdsBlockLambda(expression)) {
} else if (isSdsBlockLambda(expression)) {
return frame.getUniqueLambdaBlockName(expression);
}
if (isSdsCall(expression)) {
const sortedArgs = sortArguments(frame.getServices(), expression.argumentList.arguments);
} else if (isSdsCall(expression)) {
const callable = frame.services.helpers.NodeMapper.callToCallable(expression);
if (isSdsFunction(callable)) {
const pythonCall = frame.services.builtins.Annotations.getPythonCall(callable);
if (pythonCall) {
let thisParam: string | undefined = undefined;
if (isSdsMemberAccess(expression.receiver)) {
thisParam = generateExpression(expression.receiver.receiver, frame);
}
const argumentsMap = getArgumentsMap(expression.argumentList.arguments, frame);
return generatePythonCall(pythonCall, argumentsMap, thisParam);
}
}

const sortedArgs = sortArguments(frame.services, expression.argumentList.arguments);
return expandToString`${generateExpression(expression.receiver, frame)}(${sortedArgs
.map((arg) => generateArgument(arg, frame))
.join(', ')})`;
}
if (isSdsExpressionLambda(expression)) {
} else if (isSdsExpressionLambda(expression)) {
return `lambda ${generateParameters(expression.parameterList, frame)}: ${generateExpression(
expression.result,
frame,
)}`;
}
if (isSdsInfixOperation(expression)) {
} else if (isSdsInfixOperation(expression)) {
const leftOperand = generateExpression(expression.leftOperand, frame);
const rightOperand = generateExpression(expression.rightOperand, frame);
switch (expression.operator) {
Expand All @@ -412,14 +418,12 @@ const generateExpression = function (expression: SdsExpression, frame: Generatio
default:
return `(${leftOperand}) ${expression.operator} (${rightOperand})`;
}
}
if (isSdsIndexedAccess(expression)) {
} else if (isSdsIndexedAccess(expression)) {
return expandToString`${generateExpression(expression.receiver, frame)}[${generateExpression(
expression.index,
frame,
)}]`;
}
if (isSdsMemberAccess(expression)) {
} else if (isSdsMemberAccess(expression)) {
const member = expression.member?.target.ref!;
const receiver = generateExpression(expression.receiver, frame);
if (isSdsEnumVariant(member)) {
Expand All @@ -442,31 +446,49 @@ const generateExpression = function (expression: SdsExpression, frame: Generatio
return `${receiver}.${memberExpression}`;
}
}
}
if (isSdsParenthesizedExpression(expression)) {
} else if (isSdsParenthesizedExpression(expression)) {
return expandToString`${generateExpression(expression.expression, frame)}`;
}
if (isSdsPrefixOperation(expression)) {
} else if (isSdsPrefixOperation(expression)) {
const operand = generateExpression(expression.operand, frame);
switch (expression.operator) {
case 'not':
return expandToString`not (${operand})`;
case '-':
return expandToString`-(${operand})`;
}
}
if (isSdsReference(expression)) {
} else if (isSdsReference(expression)) {
const declaration = expression.target.ref!;
const referenceImport =
getExternalReferenceNeededImport(frame.getServices(), expression, declaration) ||
getInternalReferenceNeededImport(frame.getServices(), expression, declaration);
getExternalReferenceNeededImport(frame.services, expression, declaration) ||
getInternalReferenceNeededImport(frame.services, expression, declaration);
frame.addImport(referenceImport);
return referenceImport?.alias || getPythonNameOrDefault(frame.getServices(), declaration);
return referenceImport?.alias || getPythonNameOrDefault(frame.services, declaration);
}
/* c8 ignore next 2 */
throw new Error(`Unknown expression type: ${expression.$type}`);
};

const generatePythonCall = function (
pythonCall: string,
argumentsMap: Map<string, string>,
thisParam: string | undefined = undefined,
): string {
if (thisParam) {
argumentsMap.set('this', thisParam);
}

return pythonCall.replace(/\$[_a-zA-Z][_a-zA-Z0-9]*/gu, (value) => argumentsMap.get(value.substring(1))!);
};

const getArgumentsMap = function (argumentList: SdsArgument[], frame: GenerationInfoFrame): Map<string, string> {
const argumentsMap = new Map<string, string>();
argumentList.reduce((map, value) => {
map.set(frame.services.helpers.NodeMapper.argumentToParameter(value)?.name!, generateArgument(value, frame));
return map;
}, argumentsMap);
return argumentsMap;
};

const sortArguments = function (services: SafeDsServices, argumentList: SdsArgument[]): SdsArgument[] {
// $containerIndex contains the index of the parameter in the receivers parameter list
const parameters = argumentList.map((argument) => {
Expand All @@ -482,7 +504,7 @@ const sortArguments = function (services: SafeDsServices, argumentList: SdsArgum
};

const generateArgument = function (argument: SdsArgument, frame: GenerationInfoFrame) {
const parameter = frame.getServices().helpers.NodeMapper.argumentToParameter(argument);
const parameter = frame.services.helpers.NodeMapper.argumentToParameter(argument);
return expandToString`${
parameter !== undefined && !isRequiredParameter(parameter)
? generateParameter(parameter, frame, false) + '='
Expand Down Expand Up @@ -567,9 +589,9 @@ interface ImportData {
}

class GenerationInfoFrame {
services: SafeDsServices;
blockLambdaManager: IdManager<SdsBlockLambda>;
importSet: Map<String, ImportData>;
readonly services: SafeDsServices;
private readonly blockLambdaManager: IdManager<SdsBlockLambda>;
private readonly importSet: Map<String, ImportData>;

constructor(services: SafeDsServices, importSet: Map<String, ImportData> = new Map<String, ImportData>()) {
this.services = services;
Expand All @@ -589,10 +611,6 @@ class GenerationInfoFrame {
getUniqueLambdaBlockName(lambda: SdsBlockLambda): string {
return `${BLOCK_LAMBDA_PREFIX}${this.blockLambdaManager.assignId(lambda)}`;
}

getServices(): SafeDsServices {
return this.services;
}
}

export interface GenerateOptions {
Expand Down
14 changes: 14 additions & 0 deletions src/language/builtins/safe-ds-annotations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
SdsAnnotatedObject,
SdsAnnotation,
SdsEnumVariant,
SdsFunction,
SdsModule,
SdsParameter,
} from '../generated/ast.js';
Expand Down Expand Up @@ -65,6 +66,19 @@ export class SafeDsAnnotations extends SafeDsModuleMembers<SdsAnnotation> {
return this.getAnnotation(IDE_INTEGRATION_URI, 'Expert');
}

getPythonCall(node: SdsFunction | undefined): string | undefined {
const value = this.getArgumentValue(node, this.PythonCall, 'callSpecification');
if (value instanceof StringConstant) {
return value.value;
} else {
return undefined;
}
}

get PythonCall(): SdsAnnotation | undefined {
return this.getAnnotation(CODE_GENERATION_URI, 'PythonCall');
}

getPythonModule(node: SdsModule | undefined): string | undefined {
const value = this.getArgumentValue(node, this.PythonModule, 'qualifiedName');
if (value instanceof StringConstant) {
Expand Down
2 changes: 1 addition & 1 deletion src/language/validation/names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ export const moduleMemberMustHaveNameThatIsUniqueInPackage = (services: SafeDsSe
let declarationsInPackage: AstNodeDescription[];
let kind: string;
if (packageName.startsWith(BUILTINS_ROOT_PACKAGE)) {
// For a builtin package the simple names of declarations must be unique
// For a builtin package, the simple names of declarations must be unique
declarationsInPackage = packageManager.getDeclarationsInPackageOrSubpackage(BUILTINS_ROOT_PACKAGE);
kind = 'builtin declarations';
} else {
Expand Down
14 changes: 14 additions & 0 deletions src/resources/builtins/safeds/lang/codeGeneration.sdsstub
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
package safeds.lang

/**
* The specification of a corresponding function call in Python. By default, the function is called as specified in the
* stub.
*
* @param callSpecification
* The specification of corresponding Python call. The specification can contain template expression, which are
* replaced by the corresponding arguments of the function call. `$this` is replaced by the receiver of the call.
* `$param` is replaced by the value of the parameter called `param`. Otherwise, the string is used as-is.
*/
@Target([AnnotationTarget.Function])
annotation PythonCall(
callSpecification: String
)

/**
* The qualified name of the corresponding Python module. By default, this is the qualified name of the package.
*/
Expand Down
12 changes: 12 additions & 0 deletions tests/resources/generation/expressions/call/input.sdstest
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,21 @@ fun h(
@PythonName("param_2") param2: Int = 0
) -> result: Boolean

@PythonCall("$param.i()")
fun i(param: Any?)

@PythonCall("$param.j($param2)")
fun j(param: Any?, param2: Any?)

@PythonCall("k($param2, $param)")
fun k(param: Any?, param2: Any?)

pipeline test {
f((g(1, 2)));
f((g(param2 = 1, param1 = 2)));
f((h(1, 2)));
f((h(param2 = 1, param1 = 2)));
i("abc");
j("abc", 123);
k(1.23, 456);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ def test():
f(g(2, param2=1))
f(h(1, param_2=2))
f(h(2, param_2=1))
'abc'.i()
'abc'.j(123)
k(456, 1.23)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ fun h() -> (result1: Boolean, result2: Boolean)
class C() {
attr a: Int
@PythonName("c") attr b: Int

@PythonCall("$param.i($this)") fun i(param: Any?)
}

fun factory() -> instance: C?
Expand All @@ -21,4 +23,5 @@ pipeline test {
f(C().b);
f(factory()?.a);
f(factory()?.b);
f(C().i(1));
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ def test():
f(C().c)
f(safeds_runner.codegen.safe_access(factory(), 'a'))
f(safeds_runner.codegen.safe_access(factory(), 'c'))
f(1.i(C()))