Skip to content

Commit

Permalink
[SPARK-48386][TESTS] Replace JVM assert with JUnit Assert in tests
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
The pr aims to replace `JVM assert` with `JUnit Assert` in tests.

### Why are the changes needed?
assert() statements do not produce as useful errors when they fail, and, if they were somehow disabled, would fail to test anything.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- Manually test.
- Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46698 from panbingkun/minor_assert.

Authored-by: panbingkun <panbingkun@baidu.com>
Signed-off-by: yangjie01 <yangjie01@baidu.com>
  • Loading branch information
panbingkun authored and LuciferYang committed May 23, 2024
1 parent a48365d commit 5df9a08
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public void testChunkedStream() throws Exception {

// Validate we read data correctly
assertEquals(bodyResult.readableBytes(), chunkSize);
assert(bodyResult.readableBytes() < (randomData.length - readIndex));
assertTrue(bodyResult.readableBytes() < (randomData.length - readIndex));
while (bodyResult.readableBytes() > 0) {
assertEquals(bodyResult.readByte(), randomData[readIndex++]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException {
verify(listener, timeout(5000)).onBlockTransferSuccess("b0", block0);
verify(listener).getTransferType();
verifyNoMoreInteractions(listener);
assert(_retryingBlockTransferor.getRetryCount() == 0);
assertEquals(0, _retryingBlockTransferor.getRetryCount());
}

@Test
Expand All @@ -310,7 +310,7 @@ public void testRepeatedSaslRetryFailures() throws IOException, InterruptedExcep
verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException);
verify(listener, times(3)).getTransferType();
verifyNoMoreInteractions(listener);
assert(_retryingBlockTransferor.getRetryCount() == MAX_RETRIES);
assertEquals(MAX_RETRIES, _retryingBlockTransferor.getRetryCount());
}

@Test
Expand Down Expand Up @@ -339,7 +339,7 @@ public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedE
// This should be equal to 1 because after the SASL exception is retried,
// retryCount should be set back to 0. Then after that b1 encounters an
// exception that is retried.
assert(_retryingBlockTransferor.getRetryCount() == 1);
assertEquals(1, _retryingBlockTransferor.getRetryCount());
}

@Test
Expand Down Expand Up @@ -368,7 +368,7 @@ public void testIOExceptionFailsConnectionEvenWithSaslException()
verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslExceptionFinal);
verify(listener, atLeastOnce()).getTransferType();
verifyNoMoreInteractions(listener);
assert(_retryingBlockTransferor.getRetryCount() == MAX_RETRIES);
assertEquals(MAX_RETRIES, _retryingBlockTransferor.getRetryCount());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import org.apache.spark.internal.LogKeys;
import org.apache.spark.internal.MDC;

import static org.junit.jupiter.api.Assertions.assertTrue;

public abstract class SparkLoggerSuiteBase {

abstract SparkLogger logger();
Expand Down Expand Up @@ -104,8 +106,8 @@ public void testBasicMsgLogger() {
Pair.of(Level.DEBUG, debugFn),
Pair.of(Level.TRACE, traceFn)).forEach(pair -> {
try {
assert (captureLogOutput(pair.getRight()).matches(
expectedPatternForBasicMsg(pair.getLeft())));
assertTrue(captureLogOutput(pair.getRight()).matches(
expectedPatternForBasicMsg(pair.getLeft())));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -127,8 +129,8 @@ public void testBasicLoggerWithException() {
Pair.of(Level.DEBUG, debugFn),
Pair.of(Level.TRACE, traceFn)).forEach(pair -> {
try {
assert (captureLogOutput(pair.getRight()).matches(
expectedPatternForBasicMsgWithException(pair.getLeft())));
assertTrue(captureLogOutput(pair.getRight()).matches(
expectedPatternForBasicMsgWithException(pair.getLeft())));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -147,8 +149,8 @@ public void testLoggerWithMDC() {
Pair.of(Level.WARN, warnFn),
Pair.of(Level.INFO, infoFn)).forEach(pair -> {
try {
assert (captureLogOutput(pair.getRight()).matches(
expectedPatternForMsgWithMDC(pair.getLeft())));
assertTrue(captureLogOutput(pair.getRight()).matches(
expectedPatternForMsgWithMDC(pair.getLeft())));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -165,8 +167,8 @@ public void testLoggerWithMDCs() {
Pair.of(Level.WARN, warnFn),
Pair.of(Level.INFO, infoFn)).forEach(pair -> {
try {
assert (captureLogOutput(pair.getRight()).matches(
expectedPatternForMsgWithMDCs(pair.getLeft())));
assertTrue(captureLogOutput(pair.getRight()).matches(
expectedPatternForMsgWithMDCs(pair.getLeft())));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -184,8 +186,8 @@ public void testLoggerWithMDCsAndException() {
Pair.of(Level.WARN, warnFn),
Pair.of(Level.INFO, infoFn)).forEach(pair -> {
try {
assert (captureLogOutput(pair.getRight()).matches(
expectedPatternForMsgWithMDCsAndException(pair.getLeft())));
assertTrue(captureLogOutput(pair.getRight()).matches(
expectedPatternForMsgWithMDCsAndException(pair.getLeft())));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -202,8 +204,8 @@ public void testLoggerWithMDCValueIsNull() {
Pair.of(Level.WARN, warnFn),
Pair.of(Level.INFO, infoFn)).forEach(pair -> {
try {
assert (captureLogOutput(pair.getRight()).matches(
expectedPatternForMsgWithMDCValueIsNull(pair.getLeft())));
assertTrue(captureLogOutput(pair.getRight()).matches(
expectedPatternForMsgWithMDCValueIsNull(pair.getLeft())));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -220,8 +222,8 @@ public void testLoggerWithExternalSystemCustomLogKey() {
Pair.of(Level.WARN, warnFn),
Pair.of(Level.INFO, infoFn)).forEach(pair -> {
try {
assert (captureLogOutput(pair.getRight()).matches(
expectedPatternForExternalSystemCustomLogKey(pair.getLeft())));
assertTrue(captureLogOutput(pair.getRight()).matches(
expectedPatternForExternalSystemCustomLogKey(pair.getLeft())));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.streaming.*;

import static org.junit.jupiter.api.Assertions.*;

/**
* A test stateful processor used with transformWithState arbitrary stateful operator in
* Structured Streaming. The processor primarily aims to test various functionality of the Java API
Expand Down Expand Up @@ -74,21 +76,21 @@ public scala.collection.Iterator<String> handleInputRows(
} else {
keyCountMap.updateValue(value, 1L);
}
assert(keyCountMap.containsKey(value));
assertTrue(keyCountMap.containsKey(value));
keysList.appendValue(value);
sb.append(value);
}

scala.collection.Iterator<String> keys = keysList.get();
while (keys.hasNext()) {
String keyVal = keys.next();
assert(keyCountMap.containsKey(keyVal));
assert(keyCountMap.getValue(keyVal) > 0);
assertTrue(keyCountMap.containsKey(keyVal));
assertTrue(keyCountMap.getValue(keyVal) > 0);
}

count += numRows;
countState.update(count);
assert (countState.get() == count);
assertEquals(count, (long) countState.get());

result.add(sb.toString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.streaming.*;

import static org.junit.jupiter.api.Assertions.assertFalse;

/**
* A test stateful processor concatenates all input rows for a key and emits the result.
* Primarily used for testing the Java API for arbitrary stateful operator in structured streaming
Expand Down Expand Up @@ -71,7 +73,7 @@ public scala.collection.Iterator<String> handleInputRows(
}

testState.clear();
assert(testState.exists() == false);
assertFalse(testState.exists());

result.add(sb.toString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

import static org.junit.jupiter.api.Assertions.assertInstanceOf;

public class JavaAdvancedDataSourceV2WithV2Filter implements TestingV2Source {

@Override
Expand Down Expand Up @@ -66,9 +68,9 @@ public StructType readSchema() {
public Predicate[] pushPredicates(Predicate[] predicates) {
Predicate[] supported = Arrays.stream(predicates).filter(f -> {
if (f.name().equals(">")) {
assert(f.children()[0] instanceof FieldReference);
assertInstanceOf(FieldReference.class, f.children()[0]);
FieldReference column = (FieldReference) f.children()[0];
assert(f.children()[1] instanceof LiteralValue);
assertInstanceOf(LiteralValue.class, f.children()[1]);
Literal value = (Literal) f.children()[1];
return column.describe().equals("i") && value.value() instanceof Integer;
} else {
Expand All @@ -78,9 +80,9 @@ public Predicate[] pushPredicates(Predicate[] predicates) {

Predicate[] unsupported = Arrays.stream(predicates).filter(f -> {
if (f.name().equals(">")) {
assert(f.children()[0] instanceof FieldReference);
assertInstanceOf(FieldReference.class, f.children()[0]);
FieldReference column = (FieldReference) f.children()[0];
assert(f.children()[1] instanceof LiteralValue);
assertInstanceOf(LiteralValue.class, f.children()[1]);
Literal value = (LiteralValue) f.children()[1];
return !column.describe().equals("i") || !(value.value() instanceof Integer);
} else {
Expand Down Expand Up @@ -125,9 +127,9 @@ public InputPartition[] planInputPartitions() {
Integer lowerBound = null;
for (Predicate predicate : predicates) {
if (predicate.name().equals(">")) {
assert(predicate.children()[0] instanceof FieldReference);
assertInstanceOf(FieldReference.class, predicate.children()[0]);
FieldReference column = (FieldReference) predicate.children()[0];
assert(predicate.children()[1] instanceof LiteralValue);
assertInstanceOf(LiteralValue.class, predicate.children()[1]);
Literal value = (Literal) predicate.children()[1];
if ("i".equals(column.describe()) && value.value() instanceof Integer integer) {
lowerBound = integer;
Expand Down

0 comments on commit 5df9a08

Please sign in to comment.