Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Figure Generator to allow nullable parameters #2687

Merged
merged 2 commits into from
Jul 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ private static class Key implements Comparable<Key> {
private final boolean isStatic;
private final boolean isPublic;


public Key(final Method m) {
this.name = m.getName();
this.isStatic = Modifier.isStatic(m.getModifiers());
Expand Down Expand Up @@ -370,24 +371,37 @@ private static class PyFunc implements Comparable<PyFunc> {
private final FunctionCallType functionCallType;
private final String[] javaFuncs;
private final String[] requiredParams;
private final String[] nullableParams;
private final String pydoc;
private final boolean generate;

public PyFunc(final String name, final FunctionCallType functionCallType, final String[] javaFuncs,
final String[] requiredParams, final String pydoc, final boolean generate) {
final String[] requiredParams, final String[] nullableParams, final String pydoc,
final boolean generate) {
this.name = name;
this.functionCallType = functionCallType;
this.javaFuncs = javaFuncs;
this.requiredParams = requiredParams == null ? new String[] {} : requiredParams;
this.nullableParams = nullableParams == null ? new String[] {} : nullableParams;
this.pydoc = pydoc;
this.generate = generate;
}

public PyFunc(final String name, final FunctionCallType functionCallType, final String[] javaFuncs,
final String[] requiredParams, final String pydoc, final boolean generate) {
this(name, functionCallType, javaFuncs, requiredParams, null, pydoc, generate);
}

public PyFunc(final String name, final FunctionCallType functionCallType, final String[] javaFuncs,
final String[] requiredParams, final String pydoc) {
this(name, functionCallType, javaFuncs, requiredParams, pydoc, true);
}

public PyFunc(final String name, final FunctionCallType functionCallType, final String[] javaFuncs,
final String[] requiredParams, final String[] nullableParams, final String pydoc) {
this(name, functionCallType, javaFuncs, requiredParams, nullableParams, pydoc, true);
}

@Override
public String toString() {
return "PyFunc{" +
Expand Down Expand Up @@ -443,12 +457,22 @@ public Map<Key, ArrayList<JavaFunction>> getSignatures(final Map<Key, ArrayList<
* Is the parameter required for the function?
*
* @param parameter python parameter
* @return is the parameter requried for the function?
* @return is the parameter required for the function?
*/
public boolean isRequired(final PyArg parameter) {
return Arrays.asList(requiredParams).contains(parameter.name);
}

/**
* Is the parameter nullable for the function?
*
* @param parameter python parameter
* @return is the parameter nullable for the function?
*/
public boolean isNullable(final PyArg parameter) {
return Arrays.asList(nullableParams).contains(parameter.name);
}

/**
* Gets the valid Java method argument name combinations.
*
Expand Down Expand Up @@ -476,15 +500,30 @@ private static Collection<String[]> javaArgNames(final ArrayList<JavaFunction> s
* Gets the valid Python method argument name combinations.
*
* @param signatures java functions with the same name.
* @param pyArgMap possible python function arguments
* @return valid Java method argument name combinations.
*/
private static List<String[]> pyArgNames(final ArrayList<JavaFunction> signatures,
final Map<String, PyArg> pyArgMap) {
final Set<Set<String>> seen = new HashSet<>();
return pyArgNames(signatures, pyArgMap, new String[] {});
}

/**
* Gets the valid Python method argument name combinations.
*
* @param signatures java functions with the same name.
* @param pyArgMap possible python function arguments
* @param excludeArgs arguments to exclude from the output
* @return valid Java method argument name combinations.
*/
private static List<String[]> pyArgNames(final ArrayList<JavaFunction> signatures,
final Map<String, PyArg> pyArgMap, String[] excludeArgs) {
final Set<Set<String>> seen = new HashSet<>();
return javaArgNames(signatures)
.stream()
.map(an -> Arrays.stream(an).map(s -> pyArgMap.get(s).name).toArray(String[]::new))
.map(an -> Arrays.stream(an).map(s -> pyArgMap.get(s).name)
.filter(s -> !Arrays.stream(excludeArgs).anyMatch(ex -> ex.equals(s)))
.toArray(String[]::new))
.filter(an -> seen.add(new HashSet<>(Arrays.asList(an))))
.sorted((first, second) -> {
final int c1 = Integer.compare(first.length, second.length);
Expand Down Expand Up @@ -766,42 +805,51 @@ private void generatePyFuncCallSingleton(final StringBuilder sb,
for (final Map.Entry<Key, ArrayList<JavaFunction>> entry : signatures.entrySet()) {
final Key key = entry.getKey();
final ArrayList<JavaFunction> sigs = entry.getValue();
final List<String[]> argNames = pyArgNames(sigs, pyArgMap);
final List<String[]> argNameList = pyArgNames(sigs, pyArgMap);
final List<String[]> nonNullableArgNameList = pyArgNames(sigs, pyArgMap, nullableParams);

if (argNameList.size() != nonNullableArgNameList.size()) {
throw new RuntimeException(
"Full argument list size " + argNameList.size() + " and non-nullable list size "
+ nonNullableArgNameList.size() + " do not match for " + key);
}

for (final String[] an : argNames) {
validateArgNames(an, alreadyGenerated, signatures, pyArgMap);
final String[] quoted_an = Arrays.stream(an).map(s -> "\"" + s + "\"").toArray(String[]::new);
for (int i = 0; i < nonNullableArgNameList.size(); i++) {
final String[] argNames = argNameList.get(i);
final String[] nonNullableArgNames = nonNullableArgNameList.get(i);
validateArgNames(argNames, alreadyGenerated, signatures, pyArgMap);
final String[] quotedNonNullableArgNames =
Arrays.stream(nonNullableArgNames).map(s -> "\"" + s + "\"").toArray(String[]::new);
final boolean hasNullables = nonNullableArgNames.length < argNames.length;

if (quoted_an.length == 0) {
if (argNames.length == 0) {
sb.append(INDENT)
.append(INDENT)
.append(isFirst ? "if" : "elif")
.append(" not non_null_args:\n")
.append(INDENT)
.append(INDENT)
.append(" not non_null_args:\n");
} else if (hasNullables) {
sb.append(INDENT)
.append(INDENT)
.append("return Figure(self.j_figure.")
.append(key.name)
.append("(")
.append(String.join(", ", an))
.append("))\n");
.append(isFirst ? "if" : "elif")
.append(" set({")
.append(String.join(", ", quotedNonNullableArgNames))
.append("}).issubset(non_null_args):\n");
} else {
sb.append(INDENT)
.append(INDENT)
.append(isFirst ? "if" : "elif")
.append(" non_null_args == {")
.append(String.join(", ", quoted_an))
.append("}:\n")
.append(INDENT)
.append(INDENT)
.append(INDENT)
.append("return Figure(self.j_figure.")
.append(key.name)
.append("(")
.append(String.join(", ", an))
.append("))\n");

.append(String.join(", ", quotedNonNullableArgNames))
.append("}:\n");
}
sb.append(INDENT)
.append(INDENT)
.append(INDENT)
.append("return Figure(self.j_figure.")
.append(key.name)
.append("(")
.append(String.join(", ", argNames))
.append("))\n");
isFirst = false;
}
}
Expand All @@ -824,7 +872,6 @@ private void generatePyFuncCallSingleton(final StringBuilder sb,
*/
private void generatePyFuncCallSequential(final StringBuilder sb,
final Map<Key, ArrayList<JavaFunction>> signatures, final Map<String, PyArg> pyArgMap) {

sb.append(INDENT)
.append(INDENT)
.append("f_called = False\n")
Expand Down