Skip to content

Commit

Permalink
[7.2.0] Support key callback in Starlark min/max builtins (#21960)
Browse files Browse the repository at this point in the history
This is required by the language spec, but was not implemented in Bazel.

See https://github.com/bazelbuild/starlark/blob/master/spec.md#max

Fixes #15022

Also take the opportunity to adjust sorted's signature for `key` to
match.

RELNOTES: Starlark `min` and `max` builtins now allow a `key` callback,
similarly to `sorted`.
PiperOrigin-RevId: 623547043
Change-Id: I71d44aa715793f9f2260f9b20b876694154ff352

Commit
cf66672

Co-authored-by: Googler <arostovtsev@google.com>
  • Loading branch information
bazel-io and tetromino authored Apr 10, 2024
1 parent 40fa762 commit 4729529
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 67 deletions.
173 changes: 144 additions & 29 deletions src/main/java/net/starlark/java/eval/MethodLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@

package net.starlark.java.eval;

import static com.google.common.collect.Streams.stream;
import static java.util.Comparator.comparing;

import com.google.common.base.Ascii;
import com.google.common.base.Throwables;
import com.google.common.collect.Ordering;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Optional;
import net.starlark.java.annot.Param;
import net.starlark.java.annot.ParamType;
import net.starlark.java.annot.StarlarkBuiltin;
Expand All @@ -31,46 +36,151 @@ class MethodLibrary {
@StarlarkMethod(
name = "min",
doc =
"Returns the smallest one of all given arguments. "
+ "If only one argument is provided, it must be a non-empty iterable. "
+ "It is an error if elements are not comparable (for example int with string), "
+ "or if no arguments are given. "
+ "<pre class=\"language-python\">min(2, 5, 4) == 2\n"
+ "min([5, 6, 3]) == 3</pre>",
extraPositionals = @Param(name = "args", doc = "The elements to be checked."))
public Object min(Sequence<?> args) throws EvalException {
return findExtreme(args, Starlark.ORDERING.reverse());
"Returns the smallest one of all given arguments. If only one positional argument is"
+ " provided, it must be a non-empty iterable. It is an error if elements are not"
+ " comparable (for example int with string), or if no arguments are given."
+ "<pre class=\"language-python\">\n" //
+ "min(2, 5, 4) == 2\n"
+ "min([5, 6, 3]) == 3\n"
+ "min(\"six\", \"three\", \"four\", key = len) == \"six\" # the shortest\n"
+ "min([2, -2, -1, 1], key = abs) == -1 # the first encountered with minimal key"
+ " value\n"
+ "</pre>",
extraPositionals = @Param(name = "args", doc = "The elements to be checked."),
parameters = {
@Param(
name = "key",
named = true,
positional = false,
allowedTypes = {
@ParamType(type = StarlarkCallable.class),
@ParamType(type = NoneType.class),
},
doc = "An optional function applied to each element before comparison.",
defaultValue = "None")
},
useStarlarkThread = true)
public Object min(Object key, Sequence<?> args, StarlarkThread thread)
throws EvalException, InterruptedException {
return findExtreme(
args,
Starlark.toJavaOptional(key, StarlarkCallable.class),
Starlark.ORDERING.reverse(),
thread);
}

@StarlarkMethod(
name = "max",
doc =
"Returns the largest one of all given arguments. "
+ "If only one argument is provided, it must be a non-empty iterable."
+ "It is an error if elements are not comparable (for example int with string), "
+ "or if no arguments are given. "
+ "<pre class=\"language-python\">max(2, 5, 4) == 5\n"
+ "max([5, 6, 3]) == 6</pre>",
extraPositionals = @Param(name = "args", doc = "The elements to be checked."))
public Object max(Sequence<?> args) throws EvalException {
return findExtreme(args, Starlark.ORDERING);
"Returns the largest one of all given arguments. If only one positional argument is"
+ " provided, it must be a non-empty iterable.It is an error if elements are not"
+ " comparable (for example int with string), or if no arguments are given."
+ "<pre class=\"language-python\">\n" //
+ "max(2, 5, 4) == 5\n"
+ "max([5, 6, 3]) == 6\n"
+ "max(\"two\", \"three\", \"four\", key = len) ==\"three\" # the longest\n"
+ "max([1, -1, -2, 2], key = abs) == -2 # the first encountered with maximal key"
+ " value\n"
+ "</pre>",
extraPositionals = @Param(name = "args", doc = "The elements to be checked."),
parameters = {
@Param(
name = "key",
named = true,
positional = false,
allowedTypes = {
@ParamType(type = StarlarkCallable.class),
@ParamType(type = NoneType.class),
},
doc = "An optional function applied to each element before comparison.",
defaultValue = "None")
},
useStarlarkThread = true)
public Object max(Object key, Sequence<?> args, StarlarkThread thread)
throws EvalException, InterruptedException {
return findExtreme(
args, Starlark.toJavaOptional(key, StarlarkCallable.class), Starlark.ORDERING, thread);
}

/** Returns the maximum element from this list, as determined by maxOrdering. */
private static Object findExtreme(Sequence<?> args, Ordering<Object> maxOrdering)
throws EvalException {
private static Object findExtreme(
Sequence<?> args,
Optional<StarlarkCallable> keyFn,
Ordering<Object> maxOrdering,
StarlarkThread thread)
throws EvalException, InterruptedException {
// Args can either be a list of items to compare, or a singleton list whose element is an
// iterable of items to compare. In either case, there must be at least one item to compare.
Iterable<?> items = (args.size() == 1) ? Starlark.toIterable(args.get(0)) : args;
try {
return maxOrdering.max(items);
if (keyFn.isPresent()) {
try {
return stream(items)
.map(value -> ValueWithComparisonKey.make(value, keyFn.get(), thread))
.max(comparing(ValueWithComparisonKey::getComparisonKey, maxOrdering))
.get()
.getValue();
} catch (ValueWithComparisonKey.KeyCallException ex) {
Throwables.throwIfInstanceOf(ex.getCause(), EvalException.class);
Throwables.throwIfInstanceOf(ex.getCause(), InterruptedException.class);
throw new AssertionError("Got invalid ValueWithComparisonKey.KeyCallException", ex);
}
} else {
return maxOrdering.max(items);
}
} catch (ClassCastException ex) {
throw new EvalException(ex.getMessage()); // e.g. unsupported comparison: int <=> string
} catch (NoSuchElementException ex) {
throw new EvalException("expected at least one item", ex);
}
}

/**
 * A value paired with the comparison key computed for it. Because the key is stored next to the
 * value, it never has to be recomputed while comparing; {@link #make} invokes the key function
 * exactly once for the value it is given (which matters when that function has side effects).
 */
private static final class ValueWithComparisonKey {
  private final Object value;
  private final Object comparisonKey;

  private ValueWithComparisonKey(Object value, Object comparisonKey) {
    this.comparisonKey = comparisonKey;
    this.value = value;
  }

  /**
   * Calls {@code keyFn} on {@code value} through {@link Starlark#fastcall}, passing the value as
   * the only positional argument and no named arguments, and pairs the result with the value.
   *
   * @throws KeyCallException wrapping the {@link EvalException} or {@link InterruptedException}
   *     thrown by the underlying {@link Starlark#fastcall} call, if it threw.
   */
  static ValueWithComparisonKey make(
      Object value, StarlarkCallable keyFn, StarlarkThread thread) {
    Object computedKey;
    try {
      computedKey = Starlark.fastcall(thread, keyFn, new Object[] {value}, new Object[0]);
    } catch (EvalException | InterruptedException ex) {
      // Checked exceptions cannot cross a stream lambda, so carry them out unchecked.
      throw new KeyCallException(ex);
    }
    return new ValueWithComparisonKey(value, computedKey);
  }

  /** Returns the original, undecorated value. */
  Object getValue() {
    return value;
  }

  /** Returns the key that was computed for the value by {@link #make}. */
  Object getComparisonKey() {
    return comparisonKey;
  }

  /** An unchecked carrier for an exception thrown by {@link Starlark#fastcall}. */
  private static final class KeyCallException extends RuntimeException {
    KeyCallException(Exception cause) {
      super(cause);
    }
  }
}

@StarlarkMethod(
name = "abs",
doc =
Expand Down Expand Up @@ -140,16 +250,24 @@ private static boolean hasElementWithBooleanValue(Object seq, boolean value)
+ " using x < y. The elements are sorted into ascending order, unless the reverse"
+ " argument is True, in which case the order is descending.\n"
+ " Sorting is stable: elements that compare equal retain their original relative"
+ " order.\n"
+ "<pre class=\"language-python\">sorted([3, 5, 4]) == [3, 4, 5]</pre>",
+ " order.\n" //
+ "<pre class=\"language-python\">\n" //
+ "sorted([3, 5, 4]) == [3, 4, 5]\n" //
+ "sorted([3, 5, 4], reverse = True) == [5, 4, 3]\n" //
+ "sorted([\"two\", \"three\", \"four\"], key = len) == [\"two\", \"four\","
+ " \"three\"] # sort by length\n" //
+ "</pre>",
parameters = {
@Param(name = "iterable", doc = "The iterable sequence to sort."),
@Param(
name = "key",
doc = "An optional function applied to each element before comparison.",
named = true,
defaultValue = "None",
positional = false),
allowedTypes = {
@ParamType(type = StarlarkCallable.class),
@ParamType(type = NoneType.class),
},
doc = "An optional function applied to each element before comparison.",
defaultValue = "None"),
@Param(
name = "reverse",
doc = "Return results in descending order.",
Expand Down Expand Up @@ -177,9 +295,6 @@ public StarlarkList<?> sorted(
// The user provided a key function.
// We must call it exactly once per element, in order,
// so use the decorate/sort/undecorate pattern.
if (!(key instanceof StarlarkCallable)) {
throw Starlark.errorf("for key, got %s, want callable", Starlark.type(key));
}
StarlarkCallable keyfn = (StarlarkCallable) key;

// decorate
Expand Down
90 changes: 86 additions & 4 deletions src/test/java/net/starlark/java/eval/testdata/min_max.star
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ assert_eq(min([1, 2], [3]), [1, 2])
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6]), [0, 6])
assert_eq(min([-1]), -1)
assert_eq(min([5, 2, 3]), 2)
assert_eq(min({1: 2, -1: 3}), -1)
assert_eq(min({2: None}), 2)
assert_eq(min({1: 2, -1: 3}), -1) # a single dict argument is treated as its sequence of keys
assert_eq(min({2: None}), 2) # a single dict argument is treated as its sequence of keys
assert_eq(min(-1, 2), -1)
assert_eq(min(5, 2, 3), 2)
assert_eq(min(1, 1, 1, 1, 1, 1), 1)
Expand All @@ -21,15 +21,62 @@ assert_fails(lambda: min([]), "expected at least one item")
assert_fails(lambda: min(1, "2", True), "unsupported comparison: int <=> string")
assert_fails(lambda: min([1, "2", True]), "unsupported comparison: int <=> string")

# min with key
# The key callback is applied to each element; the element itself (not its key) is returned.
assert_eq(min("aBcDeFXyZ".elems(), key = lambda s: s.upper()), "a")
assert_eq(min("test", "xyz", key = len), "xyz")
assert_eq(min([4, 5], [1], key = lambda x: x), [1])
assert_eq(min([1, 2], [3], key = lambda x: x), [1, 2])
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6], key = lambda x: x), [0, 6])
# Here the lists are compared by their second item only.
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6], key = lambda x: x[1]), [2, 4])
assert_eq(min([-1], key = lambda x: x), -1)
assert_eq(min([5, 2, 3], key = lambda x: x), 2)
assert_eq(min({1: 2, -1: 3}, key = lambda x: x), -1) # a single dict argument is treated as its sequence of keys
assert_eq(min({2: None}, key = lambda x: x), 2) # a single dict argument is treated as its sequence of keys
assert_eq(min(-1, 2, key = lambda x: x), -1)
assert_eq(min(5, 2, 3, key = lambda x: x), 2)
assert_eq(min(1, 1, 1, 1, 1, 1, key = lambda x: -x), 1)
assert_eq(min([1, 1, 1, 1, 1, 1], key = lambda x: -x), 1)
# Passing key does not change the argument validation errors.
assert_fails(lambda: min(1, key = lambda x: x), "type 'int' is not iterable")
assert_fails(lambda: min(key = lambda x: x), "expected at least one item")
assert_fails(lambda: min([], key = lambda x: x), "expected at least one item")
# Incomparable keys fail; the expected pattern accepts either operand order in the message.
assert_fails(lambda: min([1], ["2"], [True], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")
assert_fails(lambda: min([[1], ["2"], [True]], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")

# verify min with key chooses first value with minimal key
assert_eq(min(1, -1, -2, 2, key = abs), 1)
assert_eq(min([1, -1, -2, 2], key = abs), 1)

# min with failing key
# "foo".elems() has 3 items, so the callback fails on element 3 and that error is reported.
assert_fails(lambda: min(0, 1, 2, 3, 4, key = lambda x: "foo".elems()[x]), "index out of range \\(index is 3, but sequence has 3 elements\\)")
assert_fails(lambda: min([0, 1, 2, 3, 4], key = lambda x: "foo".elems()[x]), "index out of range \\(index is 3, but sequence has 3 elements\\)")

# min with non-callable key
assert_fails(lambda: min(1, 2, 3, key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")
assert_fails(lambda: min([1, 2, 3], key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")

# verify min with key invokes key callback exactly once per item
def make_counting_identity():
    """Creates an identity function that records how often it is called per argument.

    Returns:
      A pair `(counting_identity, call_count)`. `counting_identity(x)` returns `x` unchanged
      and increments `call_count[x]`; `call_count` is the dict holding those per-argument
      counts, and it starts out empty. Arguments are used as dict keys.
    """
    call_count = {}

    def counting_identity(x):
        # A first call for x starts from 0, so every observed argument ends up with a count >= 1.
        call_count[x] = call_count.get(x, 0) + 1
        return x

    return counting_identity, call_count

# Use a fresh counter for min, then check each of the three elements was passed to the key once.
min_counting_identity, min_call_count = make_counting_identity()
assert_eq(min("min".elems(), key = min_counting_identity), "i")
assert_eq(min_call_count, {"m": 1, "i": 1, "n": 1})

# max
assert_eq(max("abcdefxyz".elems()), "z")
assert_eq(max("test", "xyz"), "xyz")
assert_eq(max("test", "xyz"), "xyz")
assert_eq(max([1, 2], [5]), [5])
assert_eq(max([-1]), -1)
assert_eq(max([5, 2, 3]), 5)
assert_eq(max({1: 2, -1: 3}), 1)
assert_eq(max({2: None}), 2)
assert_eq(max({1: 2, -1: 3}), 1) # a single dict argument is treated as its sequence of keys
assert_eq(max({2: None}), 2) # a single dict argument is treated as its sequence of keys
assert_eq(max(-1, 2), 2)
assert_eq(max(5, 2, 3), 5)
assert_eq(max(1, 1, 1, 1, 1, 1), 1)
Expand All @@ -40,3 +87,38 @@ assert_fails(lambda: max(), "expected at least one item")
assert_fails(lambda: max([]), "expected at least one item")
assert_fails(lambda: max(1, "2", True), "unsupported comparison: int <=> string")
assert_fails(lambda: max([1, "2", True]), "unsupported comparison: int <=> string")

# max with key
# The key callback is applied to each element; the element itself (not its key) is returned.
assert_eq(max("aBcDeFXyZ".elems(), key = lambda s: s.lower()), "Z")
assert_eq(max("test", "xyz", key = len), "test")
assert_eq(max([1, 2], [5], key = lambda x: x), [5])
assert_eq(max([-1], key = lambda x: x), -1)
assert_eq(max([5, 2, 3], key = lambda x: x), 5)
assert_eq(max({1: 2, -1: 3}, key = lambda x: x), 1) # a single dict argument is treated as its sequence of keys
assert_eq(max({2: None}, key = lambda x: x), 2) # a single dict argument is treated as its sequence of keys
assert_eq(max(-1, 2, key = lambda x: x), 2)
assert_eq(max(5, 2, 3, key = lambda x: x), 5)
assert_eq(max(1, 1, 1, 1, 1, 1, key = lambda x: -x), 1)
assert_eq(max([1, 1, 1, 1, 1, 1], key = lambda x: -x), 1)
# Passing key does not change the argument validation errors.
assert_fails(lambda: max(1, key = lambda x: x), "type 'int' is not iterable")
assert_fails(lambda: max(key = lambda x: x), "expected at least one item")
assert_fails(lambda: max([], key = lambda x: x), "expected at least one item")
# Incomparable keys fail; the expected pattern accepts either operand order in the message.
assert_fails(lambda: max([1], ["2"], [True], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")
assert_fails(lambda: max([[1], ["2"], [True]], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")

# verify max with key chooses first value with maximal key
assert_eq(max(1, -1, -2, 2, key = abs), -2)
assert_eq(max([1, -1, -2, 2], key = abs), -2)

# max with failing key
# "xyz".elems() has 3 items, so the callback fails on element 3 and that error is reported.
assert_fails(lambda: max(0, 1, 2, 3, 4, key = lambda i: "xyz".elems()[i]), "index out of range \\(index is 3, but sequence has 3 elements\\)")
assert_fails(lambda: max([0, 1, 2, 3, 4], key = lambda i: "xyz".elems()[i]), "index out of range \\(index is 3, but sequence has 3 elements\\)")

# max with non-callable key
assert_fails(lambda: max(1, 2, 3, key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")
assert_fails(lambda: max([1, 2, 3], key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")

# verify max with key invokes key callback exactly once per item
# Use a fresh counter for max, then check each of the three elements was passed to the key once.
max_counting_identity, max_call_count = make_counting_identity()
assert_eq(max("max".elems(), key = max_counting_identity), "x")
assert_eq(max_call_count, {"m": 1, "a": 1, "x": 1})
Loading

0 comments on commit 4729529

Please sign in to comment.