Make it work for tests
Co-Authored-By: Jonatan Jäderberg <jonatan.jaderberg@neo4j.com>
Co-Authored-By: Martin Junghanns <martin.junghanns@neo4j.com>
3 people committed Jun 6, 2024
1 parent 2e1741c commit 4c85d72
Showing 52 changed files with 264 additions and 161 deletions.
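
Most of the hunks below apply one migration: the test suites move from ScalaTest 2.x, where suites extended org.scalatest.FunSuite, to ScalaTest 3.x, where that style is provided by org.scalatest.funsuite.AnyFunSuite. The recurring import change, sketched below with an illustrative suite name, aliases the new class to the old name so the suite declarations themselves do not have to change:

    // before: ScalaTest 2.x
    // import org.scalatest.FunSuite

    // after: ScalaTest 3.x, aliased so `extends FunSuite` declarations keep compiling unchanged
    import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

    class ExampleSuite extends FunSuite {
      test("addition") {
        assert(1 + 1 == 2) // ScalaTest's assert returns an Assertion
      }
    }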
@@ -1,7 +1,7 @@
package com.snowflake.code_verification

import com.snowflake.snowpark.{CodeVerification, DataFrame}
import org.scalatest.FunSuite
import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

// verify API Java and Scala API contain same functions
@CodeVerification
@@ -1,7 +1,7 @@
package com.snowflake.code_verification

import com.snowflake.snowpark.CodeVerification
import org.scalatest.FunSuite
import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

import scala.collection.mutable

6 changes: 5 additions & 1 deletion src/test/scala/com/snowflake/perf/PerfBase.scala
@@ -77,14 +77,18 @@ trait PerfBase extends SNTestBase {
test(testName) {
try {
writeResult(testName, timer(func))
succeed
} catch {
case ex: Exception =>
writeResult(testName, -1.0) // -1.0 if failed
throw ex
}
}
} else {
ignore(testName)(func)
ignore(testName) {
func
succeed
}
}
}
}
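
The succeed calls added here and in the suites below give each test body an explicit final Assertion: writeResult and foreach-style loops return Unit, so without it the body would end in a discarded value. succeed comes from org.scalatest.Assertions and simply returns the Succeeded assertion. A minimal sketch, assuming a plain AnyFunSuite (names are illustrative):

    import org.scalatest.funsuite.AnyFunSuite

    class SucceedExampleSuite extends AnyFunSuite {
      test("loop of asserts") {
        // foreach returns Unit, so the body would not otherwise end in an Assertion
        (1 to 3).foreach(i => assert(i > 0))
        succeed // explicit Assertion as the last expression of the test body
      }
    }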
19 changes: 12 additions & 7 deletions src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala
@@ -18,6 +18,7 @@ import com.snowflake.snowpark.internal.analyzer.{
import com.snowflake.snowpark.types._
import net.snowflake.client.core.SFSessionProperty
import net.snowflake.client.jdbc.SnowflakeSQLException
import org.scalatest.Assertion

import java.nio.file.Files
import java.sql.{Date, Timestamp}
@@ -26,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.util.Random
import scala.language.postfixOps

class APIInternalSuite extends TestData {
private val userSchema: StructType = StructType(
@@ -578,10 +580,11 @@ class APIInternalSuite extends TestData {
testWithAlteredSessionParameter(() => {
import session.implicits._
val schema = StructType(Seq(StructField("ID", LongType)))
val largeData = new ArrayBuffer[Row]()
val largeDataBuf = new ArrayBuffer[Row]()
for (i <- 0 to 1024) {
largeData.append(Row(i.toLong))
largeDataBuf.append(Row(i.toLong))
}
val largeData = largeDataBuf.toSeq
// With specific schema
var df = session.createDataFrame(largeData, schema)
assert(df.snowflakePlan.queries.size == 3)
@@ -596,7 +599,7 @@ class APIInternalSuite extends TestData {
for (i <- 0 to 1024) {
inferData.append(i.toLong)
}
df = inferData.toDF("id2")
df = inferData.toSeq.toDF("id2")
assert(df.snowflakePlan.queries.size == 3)
assert(df.snowflakePlan.queries(0).sql.trim().startsWith("CREATE SCOPED TEMPORARY TABLE"))
assert(df.snowflakePlan.queries(1).sql.trim().startsWith("INSERT INTO"))
@@ -823,6 +826,7 @@ class APIInternalSuite extends TestData {
val (rows, meta) = session.conn.getResultAndMetadata(session.sql(query).snowflakePlan)
assert(rows.length == 0 || rows(0).length == meta.size)
}
succeed
}

// reader
@@ -895,7 +899,7 @@ class APIInternalSuite extends TestData {
assert(ex2.errorCode.equals("0321"))
}

def checkExecuteAndGetQueryId(df: DataFrame): Unit = {
def checkExecuteAndGetQueryId(df: DataFrame): Assertion = {
val query = Query.resultScanQuery(df.executeAndGetQueryId())
val res = query.runQueryGetResult(session.conn, mutable.HashMap.empty[String, String], false)
res.attributes
@@ -907,7 +911,7 @@
checkExecuteAndGetQueryIdWithStatementParameter(df)
}

private def checkExecuteAndGetQueryIdWithStatementParameter(df: DataFrame): Unit = {
private def checkExecuteAndGetQueryIdWithStatementParameter(df: DataFrame): Assertion = {
val testQueryTagValue = s"test_query_tag_${Random.nextLong().abs}"
val queryId = df.executeAndGetQueryId(Map("QUERY_TAG" -> testQueryTagValue))
val rows = session
@@ -1007,7 +1011,7 @@ class APIInternalSuite extends TestData {
largeData.append(
Row(1025, null, null, null, null, null, null, null, null, null, null, null, null))

val df = session.createDataFrame(largeData, schema)
val df = session.createDataFrame(largeData.toSeq, schema)
checkExecuteAndGetQueryId(df)

// Statement parameters are applied for all queries.
@@ -1039,6 +1043,7 @@ class APIInternalSuite extends TestData {
// case 2: test int/boolean parameter
multipleQueriesDF1.executeAndGetQueryId(
Map("STATEMENT_TIMEOUT_IN_SECONDS" -> 100, "USE_CACHED_RESULT" -> false))
succeed
}

test("VariantTypes.getType") {
@@ -1052,7 +1057,7 @@
assert(Variant.VariantTypes.getType("Timestamp") == Variant.VariantTypes.Timestamp)
assert(Variant.VariantTypes.getType("Array") == Variant.VariantTypes.Array)
assert(Variant.VariantTypes.getType("Object") == Variant.VariantTypes.Object)
intercept[Exception] { Variant.VariantTypes.getType("not_exist_type") }
assertThrows[Exception] { Variant.VariantTypes.getType("not_exist_type") }
}

test("HasCachedResult doesn't cache again") {
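
Two patterns account for most of the APIInternalSuite edits above. First, for Scala 2.13 a mutable ArrayBuffer no longer conforms to scala.Seq (now an alias for the immutable Seq), so buffers are converted with .toSeq before being passed to createDataFrame or toDF. Second, helper methods such as checkExecuteAndGetQueryId now return Assertion instead of Unit, and intercept is replaced by assertThrows, which returns an Assertion rather than the caught exception. A hedged sketch of both, kept independent of the Snowpark session (names and the exception type are illustrative):

    import org.scalatest.Assertion
    import org.scalatest.Assertions._
    import scala.collection.mutable.ArrayBuffer

    object MigrationSketch {
      def parse(s: String): Int = s.toInt

      def main(args: Array[String]): Unit = {
        val buf = new ArrayBuffer[Long]()
        for (i <- 0 to 1024) buf.append(i.toLong)
        val rows: Seq[Long] = buf.toSeq // 2.13: ArrayBuffer is no longer a scala.Seq, so convert explicitly
        assert(rows.length == 1025)

        // intercept returns the caught exception, useful when its fields are inspected afterwards
        val ex: NumberFormatException = intercept[NumberFormatException] { parse("oops") }
        // assertThrows returns an Assertion, so it can be the last expression of an Assertion-typed test body
        val ok: Assertion = assertThrows[NumberFormatException] { parse("oops") }
      }
    }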
@@ -109,16 +109,16 @@ class DropTempObjectsSuite extends SNTestBase {
TempObjectType.Table,
"db.schema.tempName1",
TempType.Temporary)
assertTrue(session.getTempObjectMap.contains("db.schema.tempName1"))
assert(session.getTempObjectMap.contains("db.schema.tempName1"))
session.recordTempObjectIfNecessary(
TempObjectType.Table,
"db.schema.tempName2",
TempType.ScopedTemporary)
assertFalse(session.getTempObjectMap.contains("db.schema.tempName2"))
assert(!session.getTempObjectMap.contains("db.schema.tempName2"))
session.recordTempObjectIfNecessary(
TempObjectType.Table,
"db.schema.tempName3",
TempType.Permanent)
assertFalse(session.getTempObjectMap.contains("db.schema.tempName3"))
assert(!session.getTempObjectMap.contains("db.schema.tempName3"))
}
}
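
The DropTempObjectsSuite change above swaps assertTrue/assertFalse for ScalaTest's own assert, whose macro expansion reports the failing expression and whose result is an Assertion. A small sketch of the replacement (the map is an illustrative stand-in for the session's temp-object bookkeeping):

    import org.scalatest.funsuite.AnyFunSuite

    class TempObjectSketchSuite extends AnyFunSuite {
      test("record only non-scoped temporary objects") {
        val recorded = Map("db.schema.tempName1" -> "TABLE") // illustrative stand-in
        assert(recorded.contains("db.schema.tempName1"))     // was assertTrue(...)
        assert(!recorded.contains("db.schema.tempName2"))    // was assertFalse(...); negate inside assert
      }
    }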
@@ -6,7 +6,7 @@ import com.snowflake.snowpark.internal.ParameterUtils.{
MIN_REQUEST_TIMEOUT_IN_SECONDS,
SnowparkRequestTimeoutInSeconds
}
import org.scalatest.FunSuite
import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

class ErrorMessageSuite extends FunSuite {

@@ -250,6 +250,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
emptyChecker(CurrentRow)
emptyChecker(UnspecifiedFrame)
binaryChecker(SpecifiedWindowFrame(RowFrame, _, _))
succeed
}

test("star children and dependent columns") {
@@ -475,6 +476,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
leafAnalyzerChecker(CurrentRow)
leafAnalyzerChecker(UnspecifiedFrame)
binaryAnalyzerChecker(SpecifiedWindowFrame(RowFrame, _, _))
succeed
}

test("star - analyze") {
@@ -489,6 +491,7 @@
assert(exp.analyze(x => x) == exp)
assert(exp.analyze(_ => att2) == att2)
leafAnalyzerChecker(Star(Seq.empty))
succeed
}

test("WindowSpecDefinition - analyze") {
@@ -988,6 +991,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
assert(key.name == "\"COL3\"")
assert(value.name == "\"COL3\"")
}
succeed
}

test("TableDelete - Analyzer") {
@@ -1125,5 +1129,6 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
leafSimplifierChecker(
SnowflakePlan(Seq.empty, "222", session, None, supportAsyncMode = false))

succeed
}
}
@@ -5,7 +5,7 @@ import java.util.jar.{JarFile, JarOutputStream}
import java.util.zip.ZipException

import com.snowflake.snowpark.internal.{FatJarBuilder, JavaCodeCompiler}
import org.scalatest.FunSuite
import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random
2 changes: 1 addition & 1 deletion src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala
@@ -1,6 +1,6 @@
package com.snowflake.snowpark

import org.scalatest.FunSuite
import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
import com.snowflake.snowpark_test._
import java.io.ByteArrayOutputStream

@@ -1,7 +1,7 @@
package com.snowflake.snowpark

import com.snowflake.snowpark.internal.{InMemoryClassObject, JavaCodeCompiler, UDFClassPath}
import org.scalatest.FunSuite
import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

class JavaCodeCompilerSuite extends FunSuite {

2 changes: 1 addition & 1 deletion src/test/scala/com/snowflake/snowpark/LoggingSuite.scala
@@ -1,7 +1,7 @@
package com.snowflake.snowpark

import com.snowflake.snowpark.internal.Logging
import org.scalatest.FunSuite
import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

class LoggingSuite extends FunSuite {

@@ -3,6 +3,7 @@ package com.snowflake.snowpark
import com.snowflake.snowpark.internal.ParameterUtils
import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.types.IntegerType
import org.scalatest.Assertion

import scala.language.implicitConversions

@@ -219,7 +220,9 @@ class NewColumnReferenceSuite extends SNTestBase {
case class TestInternalAlias(name: String) extends TestColumnName
implicit def stringToOriginalName(name: String): TestOriginalName =
TestOriginalName(name)
private def verifyOutputName(output: Seq[Attribute], columnNames: Seq[TestColumnName]): Unit = {
private def verifyOutputName(
output: Seq[Attribute],
columnNames: Seq[TestColumnName]): Assertion = {
assert(output.size == columnNames.size)
assert(output.map(_.name).zip(columnNames).forall {
case (name, TestOriginalName(n)) => name == quoteName(n)
@@ -280,6 +283,7 @@ class NewColumnReferenceSuite extends SNTestBase {
verifyUnaryNode(child => TableUpdate("a", Map.empty, None, Some(child)))
verifyUnaryNode(child => SnowflakeCreateTable("a", SaveMode.Append, Some(child)))
verifyBinaryNode((plan1, plan2) => SimplifiedUnion(Seq(plan1, plan2)))
succeed
}

private val project1 = Project(Seq(Attribute("a", IntegerType)), Range(1, 1, 1))
1 change: 1 addition & 0 deletions src/test/scala/com/snowflake/snowpark/ParameterSuite.scala
@@ -77,6 +77,7 @@ class ParameterSuite extends SNTestBase {
// no need to verify PKCS#1 format key additionally,
// since all Github Action tests use PKCS#1 key to authenticate with Snowflake server.
ParameterUtils.parsePrivateKey(generatePKCS8Key())
succeed
}

private def generatePKCS8Key(): String = {
9 changes: 4 additions & 5 deletions src/test/scala/com/snowflake/snowpark/ReplSuite.scala
@@ -1,15 +1,14 @@
package com.snowflake.snowpark

import java.io.{BufferedReader, OutputStreamWriter, StringReader}
import java.io.{BufferedReader, OutputStreamWriter, StringReader, PrintWriter => JPrintWriter}
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths, StandardCopyOption}

import com.snowflake.snowpark.internal.Utils

import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter._
import scala.tools.nsc.util.stringFromStream
import scala.sys.process._
import scala.tools.nsc.interpreter.shell.{ILoop, ShellConfig}

@UDFTest
class ReplSuite extends TestData {
@@ -48,15 +47,15 @@ class ReplSuite extends TestData {
Console.withOut(outputStream) {
val input = new BufferedReader(new StringReader(preLoad + code))
val output = new JPrintWriter(new OutputStreamWriter(outputStream))
val repl = new ILoop(input, output)
val settings = new Settings()
if (inMemory) {
settings.processArgumentString("-Yrepl-class-based")
} else {
settings.processArgumentString("-Yrepl-class-based -Yrepl-outdir repl_classes")
}
settings.classpath.value = sys.props("java.class.path")
repl.process(settings)
val repl = new ILoop(ShellConfig(settings), input, output)
repl.run(settings)
}
}.replaceAll("scala> ", "")
}
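
The ReplSuite rewrite above appears to track the Scala 2.13 REPL embedding API: ILoop now lives in scala.tools.nsc.interpreter.shell, is constructed with a ShellConfig, and is driven with run(settings) instead of process(settings), while java.io.PrintWriter (aliased as JPrintWriter) replaces the old interpreter alias. A condensed sketch of the new wiring, mirroring the diff (the snippet fed to the REPL is illustrative):

    import java.io.{BufferedReader, OutputStreamWriter, PrintWriter => JPrintWriter, StringReader}
    import scala.tools.nsc.Settings
    import scala.tools.nsc.interpreter.shell.{ILoop, ShellConfig}

    object ReplSketch {
      def main(args: Array[String]): Unit = {
        val input  = new BufferedReader(new StringReader("1 + 1\n:quit\n"))
        val output = new JPrintWriter(new OutputStreamWriter(System.out))
        val settings = new Settings()
        settings.processArgumentString("-Yrepl-class-based")
        settings.classpath.value = sys.props("java.class.path")
        val repl = new ILoop(ShellConfig(settings), input, output) // 2.13: ShellConfig is required
        repl.run(settings)                                         // 2.13: run replaces process
      }
    }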
@@ -40,27 +40,31 @@ class ResultAttributesSuite extends SNTestBase {
val attribute = getAttributesWithTypes(tableName, integers)
assert(attribute.length == integers.length)
integers.indices.foreach(index => assert(attribute(index).dataType == LongType))
succeed
}

test("float data type") {
val floats = Seq("float", "float4", "double", "real")
val attribute = getAttributesWithTypes(tableName, floats)
assert(attribute.length == floats.length)
floats.indices.foreach(index => assert(attribute(index).dataType == DoubleType))
succeed
}

test("string data types") {
val strings = Seq("varchar", "char", "character", "string", "text")
val attribute = getAttributesWithTypes(tableName, strings)
assert(attribute.length == strings.length)
strings.indices.foreach(index => assert(attribute(index).dataType == StringType))
succeed
}

test("binary data types") {
val binaries = Seq("binary", "varbinary")
val attribute = getAttributesWithTypes(tableName, binaries)
assert(attribute.length == binaries.length)
binaries.indices.foreach(index => assert(attribute(index).dataType == BinaryType))
succeed
}

test("logical data type") {
@@ -69,6 +73,7 @@
assert(attributes.length == 1)
assert(attributes.head.dataType == BooleanType)
dropTable(tableName)
succeed
}

test("date & time data type") {
@@ -83,6 +88,7 @@
val attribute = getAttributesWithTypes(tableName, dates.map(_._1))
assert(attribute.length == dates.length)
dates.indices.foreach(index => assert(attribute(index).dataType == dates(index)._2))
succeed
}

test("semi-structured data types") {
@@ -106,5 +112,6 @@
index =>
assert(attribute(index).dataType ==
ArrayType(StringType)))
succeed
}
}