Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Primitive Comparator.comparing(keyExtractor), like standrad Comparator #313

Merged
merged 1 commit into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions drv/Comparator.drv
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package PACKAGE;

import java.util.Comparator;
import java.util.Objects;
import java.io.Serializable;

/** A type-specific {@link Comparator}; provides methods to compare two primitive types both as objects
* and as primitive types.
Expand Down Expand Up @@ -76,4 +78,112 @@ public interface KEY_COMPARATOR KEY_GENERIC extends Comparator<KEY_GENERIC_CLASS
return Comparator.super.thenComparing(second);
}
#endif

#define CONCAT_(A, B) A ## B
#define CONCAT(A, B) CONCAT_(A, B)
#define KEY_TO_OBJ_FUNCTION CONCAT(KEY_TYPE_CAP, 2ObjectFunction)
#define KEY_TO_INT_FUNCTION CONCAT(KEY_TYPE_CAP, 2IntFunction)
#define KEY_TO_LONG_FUNCTION CONCAT(KEY_TYPE_CAP, 2LongFunction)
#define KEY_TO_DOUBLE_FUNCTION CONCAT(KEY_TYPE_CAP, 2DoubleFunction)


/**
* Accepts a function that extracts a {@link java.lang.Comparable Comparable} sort key from
* a primitive key, and returns a comparator that compares by that sort key.
*
* <p>
* The returned comparator is serializable if the specified function is also serializable.
*
* @param keyExtractor the function used to extract the {@link Comparable} sort key
* @return a comparator that compares by an extracted key
* @throws NullPointerException if {@code keyExtractor} is {@code null}
*/
#if KEYS_PRIMITIVE
static <U extends Comparable<? super U>> KEY_COMPARATOR KEY_GENERIC comparing(KEY_TO_OBJ_FUNCTION <? extends U> keyExtractor) {
#else
static <K, U extends Comparable<? super U>> KEY_COMPARATOR KEY_GENERIC comparing(KEY_TO_OBJ_FUNCTION <? super K, ? extends U> keyExtractor) {
#endif
Objects.requireNonNull(keyExtractor);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> keyExtractor.get(k1).compareTo(keyExtractor.get(k2));
}

/**
* Accepts a function that extracts a sort key from a primitive key, and returns a
* comparator that compares by that sort key using the specified {@link Comparator}.
*
* <p>
* The returned comparator is serializable if the specified function and comparator are
* both serializable.
*
* @param keyExtractor the function used to extract the sort key
* @param keyComparator the {@code Comparator} used to compare the sort key
* @return a comparator that compares by an extracted key using the specified {@code Comparator}
* @throws NullPointerException if {@code keyExtractor} or {@code keyComparator} are {@code null}
*/
#if KEYS_PRIMITIVE
static <U extends Comparable<? super U>> KEY_COMPARATOR KEY_GENERIC comparing(KEY_TO_OBJ_FUNCTION <? extends U> keyExtractor, Comparator<? super U> keyComparator) {
#else
static <K, U extends Comparable<? super U>> KEY_COMPARATOR KEY_GENERIC comparing(KEY_TO_OBJ_FUNCTION <? super K, ? extends U> keyExtractor, Comparator<? super U> keyComparator) {
#endif
Objects.requireNonNull(keyExtractor);
Objects.requireNonNull(keyComparator);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> keyComparator.compare(keyExtractor.get(k1), keyExtractor.get(k2));
}

/**
* Accepts a function that extracts an {@code int} sort key from a primitive key,
* and returns a comparator that compares by that sort key.
*
* <p>
* The returned comparator is serializable if the specified function
* is also serializable.
*
* @param keyExtractor the function used to extract the integer sort key
* @return a comparator that compares by an extracted key
* @throws NullPointerException if {@code keyExtractor} is {@code null}
*/
static KEY_GENERIC KEY_COMPARATOR KEY_GENERIC comparingInt(KEY_TO_INT_FUNCTION KEY_SUPER_GENERIC keyExtractor) {
Objects.requireNonNull(keyExtractor);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> Integer.compare(keyExtractor.get(k1), keyExtractor.get(k2));
}

/**
* Accepts a function that extracts an {@code long} sort key from a primitive key,
* and returns a comparator that compares by that sort key.
*
* <p>
* The returned comparator is serializable if the specified function
* is also serializable.
*
* @param keyExtractor the function used to extract the long sort key
* @return a comparator that compares by an extracted key
* @throws NullPointerException if {@code keyExtractor} is {@code null}
*/
static KEY_GENERIC KEY_COMPARATOR KEY_GENERIC comparingLong(KEY_TO_LONG_FUNCTION KEY_SUPER_GENERIC keyExtractor) {
Objects.requireNonNull(keyExtractor);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> Long.compare(keyExtractor.get(k1), keyExtractor.get(k2));
}

/**
* Accepts a function that extracts an {@code double} sort key from a primitive key,
* and returns a comparator that compares by that sort key.
*
* <p>
* The returned comparator is serializable if the specified function
* is also serializable.
*
* @param keyExtractor the function used to extract the double sort key
* @return a comparator that compares by an extracted key
* @throws NullPointerException if {@code keyExtractor} is {@code null}
*/
static KEY_GENERIC KEY_COMPARATOR KEY_GENERIC comparingDouble(KEY_TO_DOUBLE_FUNCTION KEY_SUPER_GENERIC keyExtractor) {
Objects.requireNonNull(keyExtractor);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> Double.compare(keyExtractor.get(k1), keyExtractor.get(k2));
}

}
69 changes: 69 additions & 0 deletions test/it/unimi/dsi/fastutil/ints/IntComparatorTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (C) 2003-2024 Barak Ugav and Sebastiano Vigna
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package it.unimi.dsi.fastutil.ints;

import static org.junit.Assert.assertEquals;

import org.junit.Test;

public class IntComparatorTest {

@Test
public void comparing() {
String[] array = new String[] { "68", "98", "30", "62", "81", "61", "80", "63", "62", "77", "10", "95", "40",
"73", "55", "45", "16", "10", "86", "28", "79", "44", "52", "92", "98", "28", "88", "70", "70", "10" };
IntComparator c = IntComparator.comparing(i -> array[i]);
for (int i = 0; i < array.length; i++) {
int j = ((i + 29) * 1337) % array.length;
assertEquals(c.compare(i, j), array[i].compareTo(array[j]));
}
}

@Test
public void comparingInt() {
int[] array = new int[] { 81, 87, 70, 54, 40, 79, 16, 8, 84, 39, 37, 84, 64, 60, 31, 44, 95, 15, 52, 48, 19, 20,
75, 31, 46, 61, 38, 27, 32, 84 };
IntComparator c = IntComparator.comparingInt(i -> array[i]);
for (int i = 0; i < array.length; i++) {
int j = ((i + 17) * 1337) % array.length;
assertEquals(c.compare(i, j), Integer.compare(array[i], array[j]));
}
}

@Test
public void comparingLong() {
long[] array = new long[] { 26, 49, 49, 24, 15, 71, 10, 88, 78, 4, 42, 79, 75, 69, 63, 16, 71, 47, 54, 39, 89,
10, 64, 37, 38, 59, 81, 59, 58, 33 };
IntComparator c = IntComparator.comparingLong(i -> array[i]);
for (int i = 0; i < array.length; i++) {
int j = ((i + 19) * 1337) % array.length;
assertEquals(c.compare(i, j), Long.compare(array[i], array[j]));
}
}

@Test
public void comparingDouble() {
double[] array = new double[] { 0.61, 0.97, 0.97, 0.75, 0.73, 0.36, 0.72, 0.14, 0.93, 0.18, 0.45, 0.03, 0.62,
0.05, 0.04, 0.05, 0.38, 0.89, 0., 0.93, 0.83, 0.14, 0.21, 0.79, 0.5, 0.17, 0.46, 0.74, 0.88, 0.94 };
IntComparator c = IntComparator.comparingDouble(i -> array[i]);
for (int i = 0; i < array.length; i++) {
int j = ((i + 23) * 1337) % array.length;
assertEquals(c.compare(i, j), Double.compare(array[i], array[j]));
}
}

}