Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[7.2.0] Support key callback in Starlark min/max builtins #21960

Merged
merged 1 commit into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 144 additions & 29 deletions src/main/java/net/starlark/java/eval/MethodLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@

package net.starlark.java.eval;

import static com.google.common.collect.Streams.stream;
import static java.util.Comparator.comparing;

import com.google.common.base.Ascii;
import com.google.common.base.Throwables;
import com.google.common.collect.Ordering;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Optional;
import net.starlark.java.annot.Param;
import net.starlark.java.annot.ParamType;
import net.starlark.java.annot.StarlarkBuiltin;
Expand All @@ -31,46 +36,151 @@ class MethodLibrary {
@StarlarkMethod(
name = "min",
doc =
"Returns the smallest one of all given arguments. "
+ "If only one argument is provided, it must be a non-empty iterable. "
+ "It is an error if elements are not comparable (for example int with string), "
+ "or if no arguments are given. "
+ "<pre class=\"language-python\">min(2, 5, 4) == 2\n"
+ "min([5, 6, 3]) == 3</pre>",
extraPositionals = @Param(name = "args", doc = "The elements to be checked."))
public Object min(Sequence<?> args) throws EvalException {
return findExtreme(args, Starlark.ORDERING.reverse());
"Returns the smallest one of all given arguments. If only one positional argument is"
+ " provided, it must be a non-empty iterable. It is an error if elements are not"
+ " comparable (for example int with string), or if no arguments are given."
+ "<pre class=\"language-python\">\n" //
+ "min(2, 5, 4) == 2\n"
+ "min([5, 6, 3]) == 3\n"
+ "min(\"six\", \"three\", \"four\", key = len) == \"six\" # the shortest\n"
+ "min([2, -2, -1, 1], key = abs) == -1 # the first encountered with minimal key"
+ " value\n"
+ "</pre>",
extraPositionals = @Param(name = "args", doc = "The elements to be checked."),
parameters = {
@Param(
name = "key",
named = true,
positional = false,
allowedTypes = {
@ParamType(type = StarlarkCallable.class),
@ParamType(type = NoneType.class),
},
doc = "An optional function applied to each element before comparison.",
defaultValue = "None")
},
useStarlarkThread = true)
// Implements the Starlark `min` builtin. `key` is declared above as keyword-only and is either
// None or a callable; it is passed to findExtreme as an Optional<StarlarkCallable>
// (presumably empty when `key` is None -- confirm against Starlark.toJavaOptional).
public Object min(Object key, Sequence<?> args, StarlarkThread thread)
throws EvalException, InterruptedException {
return findExtreme(
args,
Starlark.toJavaOptional(key, StarlarkCallable.class),
// findExtreme returns the maximum under the ordering it is given, so passing the reversed
// ordering makes it return the minimum.
Starlark.ORDERING.reverse(),
thread);
}

@StarlarkMethod(
name = "max",
doc =
"Returns the largest one of all given arguments. "
+ "If only one argument is provided, it must be a non-empty iterable."
+ "It is an error if elements are not comparable (for example int with string), "
+ "or if no arguments are given. "
+ "<pre class=\"language-python\">max(2, 5, 4) == 5\n"
+ "max([5, 6, 3]) == 6</pre>",
extraPositionals = @Param(name = "args", doc = "The elements to be checked."))
public Object max(Sequence<?> args) throws EvalException {
return findExtreme(args, Starlark.ORDERING);
"Returns the largest one of all given arguments. If only one positional argument is"
+ " provided, it must be a non-empty iterable. It is an error if elements are not"
+ " comparable (for example int with string), or if no arguments are given."
+ "<pre class=\"language-python\">\n" //
+ "max(2, 5, 4) == 5\n"
+ "max([5, 6, 3]) == 6\n"
+ "max(\"two\", \"three\", \"four\", key = len) == \"three\" # the longest\n"
+ "max([1, -1, -2, 2], key = abs) == -2 # the first encountered with maximal key"
+ " value\n"
+ "</pre>",
extraPositionals = @Param(name = "args", doc = "The elements to be checked."),
parameters = {
@Param(
name = "key",
named = true,
positional = false,
allowedTypes = {
@ParamType(type = StarlarkCallable.class),
@ParamType(type = NoneType.class),
},
doc = "An optional function applied to each element before comparison.",
defaultValue = "None")
},
useStarlarkThread = true)
// Implements the Starlark `max` builtin. `key` is declared above as keyword-only and is either
// None or a callable; it is passed to findExtreme as an Optional<StarlarkCallable>
// (presumably empty when `key` is None -- confirm against Starlark.toJavaOptional).
public Object max(Object key, Sequence<?> args, StarlarkThread thread)
throws EvalException, InterruptedException {
// findExtreme returns the maximum under the ordering it is given, so the natural Starlark
// ordering is passed through unchanged.
return findExtreme(
args, Starlark.toJavaOptional(key, StarlarkCallable.class), Starlark.ORDERING, thread);
}

/**
 * Returns the maximum element of {@code args} as determined by {@code maxOrdering}.
 *
 * <p>If {@code args} has exactly one element, that element is converted with {@link
 * Starlark#toIterable} and its items are compared; otherwise the elements of {@code args}
 * themselves are compared. When {@code keyFn} is present, it is called on each item and the
 * returned keys are compared instead of the items; the item (not its key) is returned.
 *
 * @throws EvalException if there are no items to compare, if a comparison throws {@link
 *     ClassCastException}, or if the key function call throws {@link EvalException}
 * @throws InterruptedException if the key function call throws {@link InterruptedException}
 */
private static Object findExtreme(Sequence<?> args, Ordering<Object> maxOrdering)
throws EvalException {
private static Object findExtreme(
Sequence<?> args,
Optional<StarlarkCallable> keyFn,
Ordering<Object> maxOrdering,
StarlarkThread thread)
throws EvalException, InterruptedException {
// Args can either be a list of items to compare, or a singleton list whose element is an
// iterable of items to compare. In either case, there must be at least one item to compare.
Iterable<?> items = (args.size() == 1) ? Starlark.toIterable(args.get(0)) : args;
try {
return maxOrdering.max(items);
if (keyFn.isPresent()) {
try {
// Decorate each item with its comparison key, pick the maximum by key, then undecorate.
return stream(items)
.map(value -> ValueWithComparisonKey.make(value, keyFn.get(), thread))
.max(comparing(ValueWithComparisonKey::getComparisonKey, maxOrdering))
// Optional.get() throws NoSuchElementException when there were no items; that is
// translated to an EvalException by the outer catch below.
.get()
.getValue();
} catch (ValueWithComparisonKey.KeyCallException ex) {
// make() wraps the checked exceptions of the key call in an unchecked exception so they
// can leave the stream lambda; unwrap and rethrow the original exception here.
Throwables.throwIfInstanceOf(ex.getCause(), EvalException.class);
Throwables.throwIfInstanceOf(ex.getCause(), InterruptedException.class);
// Not expected: make() only wraps EvalException and InterruptedException.
throw new AssertionError("Got invalid ValueWithComparisonKey.KeyCallException", ex);
}
} else {
return maxOrdering.max(items);
}
} catch (ClassCastException ex) {
throw new EvalException(ex.getMessage()); // e.g. unsupported comparison: int <=> string
} catch (NoSuchElementException ex) {
throw new EvalException("expected at least one item", ex);
}
}

/**
 * Pairs an original value with the comparison key computed for it.
 *
 * <p>Keeping the key next to the value means the key function is invoked exactly once per value,
 * which matters when that function has side effects.
 */
private static final class ValueWithComparisonKey {

  /** Unchecked carrier for an exception thrown by {@link Starlark#fastcall}. */
  private static final class KeyCallException extends RuntimeException {
    KeyCallException(Exception cause) {
      super(cause);
    }
  }

  private final Object value;
  private final Object comparisonKey;

  private ValueWithComparisonKey(Object value, Object comparisonKey) {
    this.value = value;
    this.comparisonKey = comparisonKey;
  }

  /** Returns the undecorated original value. */
  Object getValue() {
    return value;
  }

  /** Returns the key that was computed for the value when this instance was made. */
  Object getComparisonKey() {
    return comparisonKey;
  }

  /**
   * Calls {@code keyFn} with {@code value} as its only positional argument and no named arguments,
   * and pairs the result with {@code value}.
   *
   * @throws KeyCallException wrapping the exception thrown by the underlying {@link
   *     Starlark#fastcall} call if it threw.
   */
  static ValueWithComparisonKey make(
      Object value, StarlarkCallable keyFn, StarlarkThread thread) {
    final Object computedKey;
    try {
      computedKey = Starlark.fastcall(thread, keyFn, new Object[] {value}, new Object[0]);
    } catch (EvalException | InterruptedException failure) {
      // Checked exceptions cannot propagate out of a stream lambda, so wrap them unchecked.
      throw new KeyCallException(failure);
    }
    return new ValueWithComparisonKey(value, computedKey);
  }
}

@StarlarkMethod(
name = "abs",
doc =
Expand Down Expand Up @@ -140,16 +250,24 @@ private static boolean hasElementWithBooleanValue(Object seq, boolean value)
+ " using x < y. The elements are sorted into ascending order, unless the reverse"
+ " argument is True, in which case the order is descending.\n"
+ " Sorting is stable: elements that compare equal retain their original relative"
+ " order.\n"
+ "<pre class=\"language-python\">sorted([3, 5, 4]) == [3, 4, 5]</pre>",
+ " order.\n" //
+ "<pre class=\"language-python\">\n" //
+ "sorted([3, 5, 4]) == [3, 4, 5]\n" //
+ "sorted([3, 5, 4], reverse = True) == [5, 4, 3]\n" //
+ "sorted([\"two\", \"three\", \"four\"], key = len) == [\"two\", \"four\","
+ " \"three\"] # sort by length\n" //
+ "</pre>",
parameters = {
@Param(name = "iterable", doc = "The iterable sequence to sort."),
@Param(
name = "key",
doc = "An optional function applied to each element before comparison.",
named = true,
defaultValue = "None",
positional = false),
allowedTypes = {
@ParamType(type = StarlarkCallable.class),
@ParamType(type = NoneType.class),
},
doc = "An optional function applied to each element before comparison.",
defaultValue = "None"),
@Param(
name = "reverse",
doc = "Return results in descending order.",
Expand Down Expand Up @@ -177,9 +295,6 @@ public StarlarkList<?> sorted(
// The user provided a key function.
// We must call it exactly once per element, in order,
// so use the decorate/sort/undecorate pattern.
if (!(key instanceof StarlarkCallable)) {
throw Starlark.errorf("for key, got %s, want callable", Starlark.type(key));
}
StarlarkCallable keyfn = (StarlarkCallable) key;

// decorate
Expand Down
90 changes: 86 additions & 4 deletions src/test/java/net/starlark/java/eval/testdata/min_max.star
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ assert_eq(min([1, 2], [3]), [1, 2])
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6]), [0, 6])
assert_eq(min([-1]), -1)
assert_eq(min([5, 2, 3]), 2)
assert_eq(min({1: 2, -1: 3}), -1)
assert_eq(min({2: None}), 2)
assert_eq(min({1: 2, -1: 3}), -1) # a single dict argument is treated as its sequence of keys
assert_eq(min({2: None}), 2) # a single dict argument is treated as its sequence of keys
assert_eq(min(-1, 2), -1)
assert_eq(min(5, 2, 3), 2)
assert_eq(min(1, 1, 1, 1, 1, 1), 1)
Expand All @@ -21,15 +21,62 @@ assert_fails(lambda: min([]), "expected at least one item")
assert_fails(lambda: min(1, "2", True), "unsupported comparison: int <=> string")
assert_fails(lambda: min([1, "2", True]), "unsupported comparison: int <=> string")

# min with key
assert_eq(min("aBcDeFXyZ".elems(), key = lambda s: s.upper()), "a")
assert_eq(min("test", "xyz", key = len), "xyz")
assert_eq(min([4, 5], [1], key = lambda x: x), [1])
assert_eq(min([1, 2], [3], key = lambda x: x), [1, 2])
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6], key = lambda x: x), [0, 6])
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6], key = lambda x: x[1]), [2, 4])
assert_eq(min([-1], key = lambda x: x), -1)
assert_eq(min([5, 2, 3], key = lambda x: x), 2)
assert_eq(min({1: 2, -1: 3}, key = lambda x: x), -1) # a single dict argument is treated as its sequence of keys
assert_eq(min({2: None}, key = lambda x: x), 2) # a single dict argument is treated as its sequence of keys
assert_eq(min(-1, 2, key = lambda x: x), -1)
assert_eq(min(5, 2, 3, key = lambda x: x), 2)
assert_eq(min(1, 1, 1, 1, 1, 1, key = lambda x: -x), 1)
assert_eq(min([1, 1, 1, 1, 1, 1], key = lambda x: -x), 1)
assert_fails(lambda: min(1, key = lambda x: x), "type 'int' is not iterable")
assert_fails(lambda: min(key = lambda x: x), "expected at least one item")
assert_fails(lambda: min([], key = lambda x: x), "expected at least one item")
assert_fails(lambda: min([1], ["2"], [True], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")
assert_fails(lambda: min([[1], ["2"], [True]], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")

# verify min with key chooses first value with minimal key
assert_eq(min(1, -1, -2, 2, key = abs), 1)
assert_eq(min([1, -1, -2, 2], key = abs), 1)

# min with failing key
assert_fails(lambda: min(0, 1, 2, 3, 4, key = lambda x: "foo".elems()[x]), "index out of range \\(index is 3, but sequence has 3 elements\\)")
assert_fails(lambda: min([0, 1, 2, 3, 4], key = lambda x: "foo".elems()[x]), "index out of range \\(index is 3, but sequence has 3 elements\\)")

# min with non-callable key
assert_fails(lambda: min(1, 2, 3, key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")
assert_fails(lambda: min([1, 2, 3], key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")

# verify min with key invokes key callback exactly once per item
def make_counting_identity():
    """Returns an identity function paired with a dict that records its calls.

    The dict maps each argument the function was called with to the number of calls made with
    that argument.
    """
    tally = {}

    def counting_identity(x):
        calls_so_far = tally.get(x, 0)
        tally[x] = calls_so_far + 1
        return x

    return counting_identity, tally

min_counting_identity, min_call_count = make_counting_identity()
assert_eq(min("min".elems(), key = min_counting_identity), "i")
assert_eq(min_call_count, {"m": 1, "i": 1, "n": 1})

# max
assert_eq(max("abcdefxyz".elems()), "z")
assert_eq(max("test", "xyz"), "xyz")
assert_eq(max("test", "xyz"), "xyz")
assert_eq(max([1, 2], [5]), [5])
assert_eq(max([-1]), -1)
assert_eq(max([5, 2, 3]), 5)
assert_eq(max({1: 2, -1: 3}), 1)
assert_eq(max({2: None}), 2)
assert_eq(max({1: 2, -1: 3}), 1) # a single dict argument is treated as its sequence of keys
assert_eq(max({2: None}), 2) # a single dict argument is treated as its sequence of keys
assert_eq(max(-1, 2), 2)
assert_eq(max(5, 2, 3), 5)
assert_eq(max(1, 1, 1, 1, 1, 1), 1)
Expand All @@ -40,3 +87,38 @@ assert_fails(lambda: max(), "expected at least one item")
assert_fails(lambda: max([]), "expected at least one item")
assert_fails(lambda: max(1, "2", True), "unsupported comparison: int <=> string")
assert_fails(lambda: max([1, "2", True]), "unsupported comparison: int <=> string")

# max with key
assert_eq(max("aBcDeFXyZ".elems(), key = lambda s: s.lower()), "Z")
assert_eq(max("test", "xyz", key = len), "test")
assert_eq(max([1, 2], [5], key = lambda x: x), [5])
assert_eq(max([-1], key = lambda x: x), -1)
assert_eq(max([5, 2, 3], key = lambda x: x), 5)
assert_eq(max({1: 2, -1: 3}, key = lambda x: x), 1) # a single dict argument is treated as its sequence of keys
assert_eq(max({2: None}, key = lambda x: x), 2) # a single dict argument is treated as its sequence of keys
assert_eq(max(-1, 2, key = lambda x: x), 2)
assert_eq(max(5, 2, 3, key = lambda x: x), 5)
assert_eq(max(1, 1, 1, 1, 1, 1, key = lambda x: -x), 1)
assert_eq(max([1, 1, 1, 1, 1, 1], key = lambda x: -x), 1)
assert_fails(lambda: max(1, key = lambda x: x), "type 'int' is not iterable")
assert_fails(lambda: max(key = lambda x: x), "expected at least one item")
assert_fails(lambda: max([], key = lambda x: x), "expected at least one item")
assert_fails(lambda: max([1], ["2"], [True], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")
assert_fails(lambda: max([[1], ["2"], [True]], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")

# verify max with key chooses first value with maximal key
assert_eq(max(1, -1, -2, 2, key = abs), -2)
assert_eq(max([1, -1, -2, 2], key = abs), -2)

# max with failing key
assert_fails(lambda: max(0, 1, 2, 3, 4, key = lambda i: "xyz".elems()[i]), "index out of range \\(index is 3, but sequence has 3 elements\\)")
assert_fails(lambda: max([0, 1, 2, 3, 4], key = lambda i: "xyz".elems()[i]), "index out of range \\(index is 3, but sequence has 3 elements\\)")

# max with non-callable key
assert_fails(lambda: max(1, 2, 3, key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")
assert_fails(lambda: max([1, 2, 3], key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")

# verify max with key invokes key callback exactly once per item
max_counting_identity, max_call_count = make_counting_identity()
assert_eq(max("max".elems(), key = max_counting_identity), "x")
assert_eq(max_call_count, {"m": 1, "a": 1, "x": 1})
Loading
Loading