[MINOR][TESTS] Replace JVM assert with JUnit Assert in tests
### What changes were proposed in this pull request?

Use JUnit assertions in tests uniformly, not JVM assert() statements.

### Why are the changes needed?

JVM assert() statements do not produce errors as useful as JUnit's when they fail, and, if assertions were somehow disabled, they would not test anything at all.
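
As a minimal illustrative sketch (not part of this patch; class name and values are hypothetical), the Java side of the difference looks like this — the `assert` keyword is a no-op unless the JVM is started with `-ea`, and JUnit's `Assert` both always runs and reports the mismatched values:

```java
import org.junit.Assert;

public class AssertDemo {
    public static void main(String[] args) {
        int expected = 3;
        int actual = 2 + 2;

        // JVM assert: silently skipped unless the JVM runs with -ea,
        // and even then it fails with a bare AssertionError.
        assert actual == expected;

        // JUnit assert: always runs; fails with
        // "java.lang.AssertionError: expected:<3> but was:<4>".
        Assert.assertEquals(expected, actual);
    }
}
```

The Scala files in this diff get the same benefit from `import org.scalatest.Assertions._`: outside a ScalaTest suite, a bare `assert` is `scala.Predef.assert`, which fails with only "assertion failed", while ScalaTest's `assert` macro reports the values in the failing expression.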

### Does this PR introduce any user-facing change?

No. The assertion logic should be identical.

### How was this patch tested?

Existing tests.

Closes #26581 from srowen/assertToJUnit.

Authored-by: Sean Owen <sean.owen@databricks.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
srowen committed Nov 20, 2019
1 parent 23b3c4f commit 1febd37
Showing 41 changed files with 102 additions and 47 deletions.
@@ -28,6 +28,7 @@

import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.JavaUtils;
+import org.junit.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -122,7 +123,7 @@ private void insertFile(String filename) throws IOException {
private void insertFile(String filename, byte[] block) throws IOException {
OutputStream dataStream = null;
File file = ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, filename);
-assert(!file.exists()) : "this test file has been already generated";
+Assert.assertFalse("this test file has been already generated", file.exists());
try {
dataStream = new FileOutputStream(
ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, filename));

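One detail worth noting in replacements like the one above (a hedged aside, not from the patch; names are hypothetical): in JUnit 4's `Assert`, the optional failure message is always the first parameter, ahead of the condition or the expected/actual pair:

```java
import org.junit.Assert;

public class MessageOrderDemo {
    public static void main(String[] args) {
        boolean fileExists = false; // stand-in for file.exists()

        // assertFalse(String message, boolean condition)
        Assert.assertFalse("this test file has been already generated", fileExists);

        // assertEquals(String message, Object expected, Object actual)
        Assert.assertEquals("unexpected file name", "shuffle_0_0_0", "shuffle_0_0_0");
    }
}
```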
@@ -17,6 +17,7 @@

package org.apache.spark

+import org.scalatest.Assertions._
import org.scalatest.Matchers
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.{Millis, Span}

@@ -17,6 +17,8 @@

package org.apache.spark

+import org.scalatest.Assertions._
+
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.benchmark.BenchmarkBase
import org.apache.spark.scheduler.CompressedMapStatus

@@ -26,6 +26,7 @@ import scala.util.Try

import org.apache.commons.io.output.TeeOutputStream
import org.apache.commons.lang3.SystemUtils
+import org.scalatest.Assertions._

import org.apache.spark.util.Utils

@@ -33,6 +33,7 @@ import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
+import org.scalatest.Assertions._
import org.scalatest.PrivateMethodTester
import org.scalatest.concurrent.Eventually
import org.scalatestplus.mockito.MockitoSugar

@@ -20,6 +20,7 @@ package org.apache.spark.rpc
import scala.collection.mutable.ArrayBuffer

import org.scalactic.TripleEquals
+import org.scalatest.Assertions._

class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals {

@@ -26,6 +26,7 @@ import scala.concurrent.duration.{Duration, SECONDS}
import scala.reflect.ClassTag

import org.scalactic.TripleEquals
+import org.scalatest.Assertions
import org.scalatest.Assertions.AssertionsHelper
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

@@ -463,7 +464,7 @@ class MockRDD(
override def toString: String = "MockRDD " + id
}

-object MockRDD extends AssertionsHelper with TripleEquals {
+object MockRDD extends AssertionsHelper with TripleEquals with Assertions {
/**
* make sure all the shuffle dependencies have a consistent number of output partitions
* (mostly to make sure the test setup makes sense, not that Spark itself would get this wrong)

@@ -29,6 +29,7 @@ import com.google.common.util.concurrent.MoreExecutors
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.{any, anyLong}
import org.mockito.Mockito.{spy, times, verify}
+import org.scalatest.Assertions._
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

@@ -26,6 +26,7 @@ import org.apache.hadoop.fs.FileAlreadyExistsException
import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt, anyString}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
+import org.scalatest.Assertions._

import org.apache.spark._
import org.apache.spark.internal.Logging

@@ -128,7 +129,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
def removeExecutor(execId: String): Unit = {
executors -= execId
val host = executorIdToHost.get(execId)
-assert(host != None)
+assert(host.isDefined)
val hostId = host.get
val executorsOnHost = hostToExecutors(hostId)
executorsOnHost -= execId

@@ -21,6 +21,7 @@ import java.util.{Map => JMap}
import java.util.concurrent.atomic.AtomicBoolean

import com.google.common.collect.ImmutableMap
+import org.scalatest.Assertions._
import org.scalatest.BeforeAndAfterEach

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}

@@ -18,6 +18,7 @@

package org.apache.spark.util

import org.apache.hadoop.fs.Path
+import org.scalatest.Assertions._

import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD

@@ -18,10 +18,9 @@

package org.apache.spark.sql.kafka010

import java.io.{File, IOException}
-import java.lang.{Integer => JInt}
import java.net.{InetAddress, InetSocketAddress}
import java.nio.charset.StandardCharsets
-import java.util.{Collections, Map => JMap, Properties, UUID}
+import java.util.{Collections, Properties, UUID}
import java.util.concurrent.TimeUnit
import javax.security.auth.login.Configuration

@@ -41,13 +40,12 @@ import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.producer._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.config.SaslConfigs
-import org.apache.kafka.common.header.Header
-import org.apache.kafka.common.header.internals.RecordHeader
import org.apache.kafka.common.network.ListenerName
import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT}
import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}
import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer}
import org.apache.zookeeper.server.auth.SASLAuthenticationProvider
+import org.scalatest.Assertions._
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

@@ -18,12 +18,14 @@

package org.apache.spark.streaming.kinesis;

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
+import org.junit.Assert;
+import org.junit.Test;

import org.apache.spark.streaming.kinesis.KinesisInitialPositions.TrimHorizon;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.LocalJavaStreamingContext;
import org.apache.spark.streaming.Seconds;
-import org.junit.Test;

public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext {
/**
@@ -49,13 +51,14 @@ public void testJavaKinesisDStreamBuilder() {
.checkpointInterval(checkpointInterval)
.storageLevel(storageLevel)
.build();
-assert(kinesisDStream.streamName() == streamName);
-assert(kinesisDStream.endpointUrl() == endpointUrl);
-assert(kinesisDStream.regionName() == region);
-assert(kinesisDStream.initialPosition().getPosition() == initialPosition.getPosition());
-assert(kinesisDStream.checkpointAppName() == appName);
-assert(kinesisDStream.checkpointInterval() == checkpointInterval);
-assert(kinesisDStream._storageLevel() == storageLevel);
+Assert.assertEquals(streamName, kinesisDStream.streamName());
+Assert.assertEquals(endpointUrl, kinesisDStream.endpointUrl());
+Assert.assertEquals(region, kinesisDStream.regionName());
+Assert.assertEquals(initialPosition.getPosition(),
+    kinesisDStream.initialPosition().getPosition());
+Assert.assertEquals(appName, kinesisDStream.checkpointAppName());
+Assert.assertEquals(checkpointInterval, kinesisDStream.checkpointInterval());
+Assert.assertEquals(storageLevel, kinesisDStream._storageLevel());
ssc.stop();
}

@@ -83,13 +86,14 @@ public void testJavaKinesisDStreamBuilderOldApi() {
.checkpointInterval(checkpointInterval)
.storageLevel(storageLevel)
.build();
-assert(kinesisDStream.streamName() == streamName);
-assert(kinesisDStream.endpointUrl() == endpointUrl);
-assert(kinesisDStream.regionName() == region);
-assert(kinesisDStream.initialPosition().getPosition() == InitialPositionInStream.LATEST);
-assert(kinesisDStream.checkpointAppName() == appName);
-assert(kinesisDStream.checkpointInterval() == checkpointInterval);
-assert(kinesisDStream._storageLevel() == storageLevel);
+Assert.assertEquals(streamName, kinesisDStream.streamName());
+Assert.assertEquals(endpointUrl, kinesisDStream.endpointUrl());
+Assert.assertEquals(region, kinesisDStream.regionName());
+Assert.assertEquals(InitialPositionInStream.LATEST,
+    kinesisDStream.initialPosition().getPosition());
+Assert.assertEquals(appName, kinesisDStream.checkpointAppName());
+Assert.assertEquals(checkpointInterval, kinesisDStream.checkpointInterval());
+Assert.assertEquals(storageLevel, kinesisDStream._storageLevel());
ssc.stop();
}
}
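
The Kinesis changes above also quietly fix a fragility: in Java, `==` on objects compares references, so assertions like `assert(kinesisDStream.streamName() == streamName)` passed only as long as the builder returned the very same instance for the String and object comparisons. `Assert.assertEquals` uses value equality instead. A minimal sketch (hypothetical values, not from the patch):

```java
import org.junit.Assert;

public class EqualityDemo {
    public static void main(String[] args) {
        String a = new String("stream-name");
        String b = new String("stream-name");

        Assert.assertNotSame(a, b);  // different references: a == b is false
        Assert.assertEquals(a, b);   // equal values: a.equals(b) is true
    }
}
```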
@@ -18,6 +18,7 @@

package org.apache.spark.graphx.util

import org.apache.hadoop.fs.Path
+import org.scalatest.Assertions

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext}

@@ -88,7 +89,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContex
}
}

-private object PeriodicGraphCheckpointerSuite {
+private object PeriodicGraphCheckpointerSuite extends Assertions {
private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER

case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)

@@ -23,6 +23,7 @@

import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.spark.sql.Encoders;
+import org.junit.Assert;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;

@@ -60,7 +61,7 @@ public void testKSTestCDF() {
.test(dataset, "sample", stdNormalCDF).head();
double pValue1 = results.getDouble(0);
// Cannot reject null hypothesis
-assert(pValue1 > pThreshold);
+Assert.assertTrue(pValue1 > pThreshold);
}

@Test
@@ -72,6 +73,6 @@ public void testKSTestNamedDistribution() {
.test(dataset, "sample", "norm", 0.0, 1.0).head();
double pValue1 = results.getDouble(0);
// Cannot reject null hypothesis
-assert(pValue1 > pThreshold);
+Assert.assertTrue(pValue1 > pThreshold);
}
}
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification
import scala.util.Random

import breeze.linalg.{DenseVector => BDV}
+import org.scalatest.Assertions._

import org.apache.spark.ml.classification.LinearSVCSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}

@@ -21,6 +21,8 @@ import scala.collection.JavaConverters._
import scala.util.Random
import scala.util.control.Breaks._

+import org.scalatest.Assertions._
+
import org.apache.spark.SparkException
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._

@@ -17,6 +17,8 @@

package org.apache.spark.ml.classification

+import org.scalatest.Assertions._
+
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.LabeledPoint

@@ -17,6 +17,8 @@

package org.apache.spark.ml.classification

+import org.scalatest.Assertions._
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap

@@ -17,6 +17,8 @@

package org.apache.spark.ml.feature

+import org.scalatest.Assertions._
+
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils}
import org.apache.spark.sql.Dataset

@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
import scala.util.Random
import scala.util.control.Breaks._

+import org.scalatest.Assertions._
import org.scalatest.Matchers

import org.apache.spark.SparkFunSuite

@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import java.util.{ArrayList => JArrayList}

import breeze.linalg.{argmax, argtopk, max, DenseMatrix => BDM}
+import org.scalatest.Assertions

import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.Edge

@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree

import scala.collection.mutable

+import org.scalatest.Assertions._
+
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel

@@ -28,6 +28,7 @@ import org.apache.mesos.protobuf.ByteString
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{times, verify}
+import org.scalatest.Assertions._

import org.apache.spark.deploy.mesos.config.MesosSecretConfig

@@ -161,12 +162,14 @@ object Utils {
val variableOne = envVars.filter(_.getName == "USER").head
assert(variableOne.getSecret.isInitialized)
assert(variableOne.getSecret.getType == Secret.Type.VALUE)
-assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes))
+assert(variableOne.getSecret.getValue.getData ==
+    ByteString.copyFrom("user".getBytes))
assert(variableOne.getType == Environment.Variable.Type.SECRET)
val variableTwo = envVars.filter(_.getName == "PASSWORD").head
assert(variableTwo.getSecret.isInitialized)
assert(variableTwo.getSecret.getType == Secret.Type.VALUE)
-assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes))
+assert(variableTwo.getSecret.getValue.getData ==
+    ByteString.copyFrom("password".getBytes))
assert(variableTwo.getType == Environment.Variable.Type.SECRET)
}

@@ -19,15 +19,16 @@

import java.util.Locale;

+import org.junit.Assert;
import org.junit.Test;

public class JavaOutputModeSuite {

@Test
public void testOutputModes() {
OutputMode o1 = OutputMode.Append();
-assert(o1.toString().toLowerCase(Locale.ROOT).contains("append"));
+Assert.assertTrue(o1.toString().toLowerCase(Locale.ROOT).contains("append"));
OutputMode o2 = OutputMode.Complete();
-assert (o2.toString().toLowerCase(Locale.ROOT).contains("complete"));
+Assert.assertTrue(o2.toString().toLowerCase(Locale.ROOT).contains("complete"));
}
}
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.analysis

+import org.scalatest.Assertions._
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

@@ -22,6 +22,7 @@ import java.time.{Duration, Instant, LocalDate}
import java.util.concurrent.TimeUnit

import org.scalacheck.{Arbitrary, Gen}
+import org.scalatest.Assertions._

import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY
import org.apache.spark.sql.types._

@@ -22,6 +22,8 @@ import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable

+import org.scalatest.Assertions._
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.{IdentityTransform, Transform}