Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPARK-5984: Fix TimSort bug causes ArrayOutOfBoundsException #4804

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,14 @@ private void pushRun(int runBase, int runLen) {
private void mergeCollapse() {
while (stackSize > 1) {
int n = stackSize - 2;
if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])
|| (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) {
if (runLen[n - 1] < runLen[n + 1])
n--;
mergeAt(n);
} else if (runLen[n] <= runLen[n + 1]) {
mergeAt(n);
} else {
} else if (runLen[n] > runLen[n + 1]) {
break; // Invariant is established
}
mergeAt(n);
}
}

Expand Down
111 changes: 111 additions & 0 deletions core/src/test/java/org/apache/spark/util/collection/TestTimSort.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package org.apache.spark.util.collection;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think you still need to include the apache license header here.


import java.util.*;

/**
* This codes generates a int array which fails the standard TimSort, Borrowed from
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

borrowed -> obtained

* the reporter of this bug.
*
* http://www.envisage-project.eu/timsort-specification-and-verification/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should also link to the github repository with an explicit license https://github.com/abstools/java-timsort-bug

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's good practice. For completeness, because this AL2 and because there is no additional NOTICE file, we don't need to do anything else to our LICENSE or NOTICE file.

*/
public class TestTimSort {

private static final int MIN_MERGE = 32;

/**
* Returns an array of integers that demonstrate the bug in TimSort
*/
public static int[] getTimSortBugTestSet(int length) {
int minRun = minRunLength(length);
List<Long> runs = runsJDKWorstCase(minRun, length);
return createArray(runs, length);
}

private static int minRunLength(int n) {
int r = 0; // Becomes 1 if any 1 bits are shifted off
while (n >= MIN_MERGE) {
r |= (n & 1);
n >>= 1;
}
return n + r;
}

private static int[] createArray(List<Long> runs, int length) {
int[] a = new int[length];
Arrays.fill(a, 0);
int endRun = -1;
for (long len : runs)
a[endRun += len] = 1;
a[length - 1] = 0;
return a;
}

/**
* Fills <code>runs</code> with a sequence of run lengths of the form<br>
* Y_n x_{n,1} x_{n,2} ... x_{n,l_n} <br>
* Y_{n-1} x_{n-1,1} x_{n-1,2} ... x_{n-1,l_{n-1}} <br>
* ... <br>
* Y_1 x_{1,1} x_{1,2} ... x_{1,l_1}<br>
* The Y_i's are chosen to satisfy the invariant throughout execution,
* but the x_{i,j}'s are merged (by <code>TimSort.mergeCollapse</code>)
* into an X_i that violates the invariant.
*
* @param length The sum of all run lengths that will be added to <code>runs</code>.
*/
private static List<Long> runsJDKWorstCase(int minRun, int length) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most importantly, this file doesn't contain tests that JUnit will run. Have a look at how other files declare test code with @Test annotations.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sean - I think this is used in the scalatest.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh! can't believe I missed the file at the end here. OK. I see that the test case that gets generated is really big, not something you can paste into the source. Hm, OK well I suggest fixing the license situation here at a minimum

List<Long> runs = new ArrayList<Long>();

long runningTotal = 0, Y = minRun + 4, X = minRun;

while (runningTotal + Y + X <= length) {
runningTotal += X + Y;
generateJDKWrongElem(runs, minRun, X);
runs.add(0, Y);
// X_{i+1} = Y_i + x_{i,1} + 1, since runs.get(1) = x_{i,1}
X = Y + runs.get(1) + 1;
// Y_{i+1} = X_{i+1} + Y_i + 1
Y += X + 1;
}

if (runningTotal + X <= length) {
runningTotal += X;
generateJDKWrongElem(runs, minRun, X);
}

runs.add(length - runningTotal);
return runs;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, is there any test at all in this file? it seems like it just generates the problem test case.

Maybe you can use it to generate a short test exposing the bug, and create a new, actual test that shows the sort works on it.

Then this code need not exist in Spark.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In SorterSuite I added a test that uses TestTimSort.java

Yes TestTimSort just generate a int[], but the the array has to be at least 67108864 long, so I thought just posting a huge int[] is not as useful as knowing how the array is generated.

The original codes was written to demonstrate the bug so it had a main(), and some other stuffs, I get rid of those.

I am fine with fixing the license here, if you guys bear with me a bit. I am not that experienced with open source licenses

}

/**
* Adds a sequence x_1, ..., x_n of run lengths to <code>runs</code> such that:<br>
* 1. X = x_1 + ... + x_n <br>
* 2. x_j >= minRun for all j <br>
* 3. x_1 + ... + x_{j-2} < x_j < x_1 + ... + x_{j-1} for all j <br>
* These conditions guarantee that TimSort merges all x_j's one by one
* (resulting in X) using only merges on the second-to-last element.
*
* @param X The sum of the sequence that should be added to runs.
*/
private static void generateJDKWrongElem(List<Long> runs, int minRun, long X) {
for (long newTotal; X >= 2 * minRun + 1; X = newTotal) {
//Default strategy
newTotal = X / 2 + 1;
//Specialized strategies
if (3 * minRun + 3 <= X && X <= 4 * minRun + 1) {
// add x_1=MIN+1, x_2=MIN, x_3=X-newTotal to runs
newTotal = 2 * minRun + 1;
} else if (5 * minRun + 5 <= X && X <= 6 * minRun + 5) {
// add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=X-newTotal to runs
newTotal = 3 * minRun + 3;
} else if (8 * minRun + 9 <= X && X <= 10 * minRun + 9) {
// add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=X-newTotal to runs
newTotal = 5 * minRun + 5;
} else if (13 * minRun + 15 <= X && X <= 16 * minRun + 17) {
// add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=3MIN+4, x_6=X-newTotal to runs
newTotal = 8 * minRun + 9;
}
runs.add(0, X - newTotal);
}
runs.add(0, X);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class SorterSuite extends FunSuite {
}
}

test("bug of TimSort") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be good to provide a link to the original blog post

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and would be better to include the JIRA ticket number, i.e.

test("SPARK-5984 TimSort bug") {
  ...
}

val data = TestTimSort.getTimSortBugTestSet(67108864)
new Sorter(new IntArraySortDataFormat).sort(data, 0, data.length, Ordering.Int)
(0 to data.length - 2).foreach(i => assert(data(i) <= data(i+1)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a space before + and after +

}

/** Runs an experiment several times. */
def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = {
if (skip) {
Expand Down