diff --git a/common/tags/src/test/java/org/apache/spark/tags/ExtendedRedshiftTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedRedshiftTest.java new file mode 100644 index 0000000000000..4d76aa3adbc81 --- /dev/null +++ b/common/tags/src/test/java/org/apache/spark/tags/ExtendedRedshiftTest.java @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; + +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedRedshiftTest { } diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 6be1c72bc6cfb..d9022e241b37e 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -103,3 +103,4 @@ org.apache.spark.scheduler.ExternalClusterManager org.apache.spark.deploy.yarn.security.ServiceCredentialProvider spark-warehouse structured-streaming/* +install-redshift-jdbc.sh diff --git a/dev/install-redshift-jdbc.sh b/dev/install-redshift-jdbc.sh new file mode 100755 index 0000000000000..8719d0f6c5632 --- /dev/null +++ b/dev/install-redshift-jdbc.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +set -e + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" + +cd /tmp + +VERSION='1.1.7.1007' +FILENAME="RedshiftJDBC4-$VERSION.jar" + +wget "https://s3.amazonaws.com/redshift-downloads/drivers/$FILENAME" + +$SPARK_ROOT_DIR/build/mvn install:install-file \ + -Dfile=$FILENAME \ + -DgroupId=com.amazonaws \ + -DartifactId=redshift.jdbc4 \ + -Dversion=$VERSION \ + -Dpackaging=jar + diff --git a/dev/run-tests.py b/dev/run-tests.py index a9692ab0c1350..1be74fa25aac1 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -111,7 +111,8 @@ def determine_modules_to_test(changed_modules): >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE ['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'sql-kafka-0-8', 'examples', - 'hive-thriftserver', 'pyspark-sql', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] + 'hive-thriftserver', 'pyspark-sql', 'redshift', 'sparkr', 'pyspark-mllib', + 'redshift-integration-tests', 'pyspark-ml'] """ modules_to_test = set() for module in changed_modules: @@ -512,6 +513,8 @@ def main(): test_env = "amplab_jenkins" # add path for Python3 in Jenkins if we're calling from a Jenkins machine os.environ["PATH"] = "/home/anaconda/envs/py3k/bin:" + os.environ.get("PATH") + # Install Redshift JDBC + run_cmd([os.path.join(SPARK_HOME, "dev", "install-redshift-jdbc.sh")]) else: # else we're running locally and can use local settings build_tool = "sbt" diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b7fc30854f13b..1cb01128b2d91 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -168,6 +168,31 @@ def __hash__(self): ] ) +redshift = Module( + name="redshift", + dependencies=[avro, sql], + source_file_regexes=[ + "external/redshift", + ], + sbt_test_goals=[ + "redshift/test", + ], + test_tags=[ + "org.apache.spark.tags.ExtendedRedshiftTest" + ], +) + +redshift_integration_tests = Module( + name="redshift-integration-tests", + dependencies=[redshift], + source_file_regexes=[ + "external/redshift-integration-tests", + ], + sbt_test_goals=[ + "redshift-integration-tests/test", + ], +) + sql_kafka = Module( name="sql-kafka-0-10", dependencies=[sql], diff --git a/external/redshift-integration-tests/pom.xml b/external/redshift-integration-tests/pom.xml new file mode 100644 index 0000000000000..654b2354ee48b --- /dev/null +++ b/external/redshift-integration-tests/pom.xml @@ -0,0 +1,144 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.1.0 + ../../pom.xml + + + com.databricks + spark-redshift-integration-tests_2.11 + + redshift-integration-tests + + jar + Spark Redshift Integration Tests + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + com.databricks + spark-avro_${scala.binary.version} + ${project.version} + provided + + + com.databricks + spark-redshift_${scala.binary.version} + ${project.version} + provided + + + com.databricks + spark-redshift_${scala.binary.version} + ${project.version} + test-jar + test + + + + com.amazonaws + aws-java-sdk-core + 1.9.40 + provided + + + com.amazonaws + aws-java-sdk-s3 + 1.9.40 + provided + + + com.amazonaws + aws-java-sdk-sts + 1.9.40 + provided + + + com.eclipsesource.minimal-json + minimal-json + 0.9.4 + compile + + + org.apache.hadoop + hadoop-client + test + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + test + + + + com.amazonaws + redshift.jdbc4 + 1.1.7.1007 + jar + test + + + org.mockito + mockito-core + test + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/AWSCredentialsInUriIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/AWSCredentialsInUriIntegrationSuite.scala new file mode 100755 index 0000000000000..e36d80a2c4064 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/AWSCredentialsInUriIntegrationSuite.scala @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.net.URI + +import org.apache.spark.SparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * This suite performs basic integration tests where the AWS credentials have been + * encoded into the tempdir URI rather than being set in the Hadoop configuration. + */ +@ExtendedRedshiftTest +class AWSCredentialsInUriIntegrationSuite extends IntegrationSuiteBase { + + override protected val tempDir: String = { + val uri = new URI(AWS_S3_SCRATCH_SPACE + randomSuffix + "/") + new URI( + uri.getScheme, + s"$AWS_ACCESS_KEY_ID:$AWS_SECRET_ACCESS_KEY", + uri.getHost, + uri.getPort, + uri.getPath, + uri.getQuery, + uri.getFragment).toString + } + + + // Override this method so that we do not set the credentials in sc.hadoopConf. + override def beforeAll(): Unit = { + assert(tempDir.contains("AKIA"), "tempdir did not contain AWS credentials") + assert(!AWS_SECRET_ACCESS_KEY.contains("/"), "AWS secret key should not contain slash") + sc = new SparkContext("local", getClass.getSimpleName) + conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None) + } + + test("roundtrip save and load") { + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), + StructType(StructField("foo", IntegerType) :: Nil)) + testRoundtripSaveAndLoad(s"roundtrip_save_and_load_$randomSuffix", df) + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/ColumnMetadataSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/ColumnMetadataSuite.scala new file mode 100755 index 0000000000000..c27d75dcf0d98 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/ColumnMetadataSuite.scala @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.SQLException + +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.types._ +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * End-to-end tests of features which depend on per-column metadata (such as comments, maxlength). + */ +@ExtendedRedshiftTest +class ColumnMetadataSuite extends IntegrationSuiteBase { + + test("configuring maxlength on string columns") { + val tableName = s"configuring_maxlength_on_string_column_$randomSuffix" + try { + val metadata = new MetadataBuilder().putLong("maxlength", 512).build() + val schema = StructType( + StructField("x", StringType, metadata = metadata) :: Nil) + write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 512))), schema)) + .option("dbtable", tableName) + .mode(SaveMode.ErrorIfExists) + .save() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 512))) + // This append should fail due to the string being longer than the maxlength + intercept[SQLException] { + write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 513))), schema)) + .option("dbtable", tableName) + .mode(SaveMode.Append) + .save() + } + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + test("configuring compression on columns") { + val tableName = s"configuring_compression_on_columns_$randomSuffix" + try { + val metadata = new MetadataBuilder().putString("encoding", "LZO").build() + val schema = StructType( + StructField("x", StringType, metadata = metadata) :: Nil) + write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 128))), schema)) + .option("dbtable", tableName) + .mode(SaveMode.ErrorIfExists) + .save() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 128))) + val encodingDF = sqlContext.read + .format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", + s"""(SELECT "column", lower(encoding) FROM pg_table_def WHERE tablename='$tableName')""") + .load() + checkAnswer(encodingDF, Seq(Row("x", "lzo"))) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + test("configuring comments on columns") { + val tableName = s"configuring_comments_on_columns_$randomSuffix" + try { + val metadata = new MetadataBuilder().putString("description", "Hello Column").build() + val schema = StructType( + StructField("x", StringType, metadata = metadata) :: Nil) + write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 128))), schema)) + .option("dbtable", tableName) + .option("description", "Hello Table") + .mode(SaveMode.ErrorIfExists) + .save() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 128))) + val tableDF = sqlContext.read + .format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", s"(SELECT pg_catalog.obj_description('$tableName'::regclass))") + .load() + checkAnswer(tableDF, Seq(Row("Hello Table"))) + val commentQuery = + s""" + |(SELECT c.column_name, pgd.description + |FROM pg_catalog.pg_statio_all_tables st + |INNER JOIN pg_catalog.pg_description pgd + | ON (pgd.objoid=st.relid) + |INNER JOIN information_schema.columns c + | ON (pgd.objsubid=c.ordinal_position AND c.table_name=st.relname) + |WHERE c.table_name='$tableName') + """.stripMargin + val columnDF = sqlContext.read + .format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", commentQuery) + .load() + checkAnswer(columnDF, Seq(Row("x", "Hello Column"))) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/CrossRegionIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/CrossRegionIntegrationSuite.scala new file mode 100755 index 0000000000000..b4fb34af86c62 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/CrossRegionIntegrationSuite.scala @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import com.amazonaws.auth.BasicAWSCredentials +import com.amazonaws.services.s3.AmazonS3Client + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * Integration tests where the Redshift cluster and the S3 bucket are in different AWS regions. + */ +@ExtendedRedshiftTest +class CrossRegionIntegrationSuite extends IntegrationSuiteBase { + + protected val AWS_S3_CROSS_REGION_SCRATCH_SPACE: String = + loadConfigFromEnv("AWS_S3_CROSS_REGION_SCRATCH_SPACE") + require(AWS_S3_CROSS_REGION_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL") + + override protected val tempDir: String = AWS_S3_CROSS_REGION_SCRATCH_SPACE + randomSuffix + "/" + + test("write") { + val bucketRegion = Utils.getRegionForS3Bucket( + tempDir, + new AmazonS3Client(new BasicAWSCredentials(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY))).get + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), + StructType(StructField("foo", IntegerType) :: Nil)) + val tableName = s"roundtrip_save_and_load_$randomSuffix" + try { + write(df) + .option("dbtable", tableName) + .option("extracopyoptions", s"region '$bucketRegion'") + .save() + // Check that the table exists. It appears that creating a table in one connection then + // immediately querying for existence from another connection may result in spurious "table + // doesn't exist" errors; this caused the "save with all empty partitions" test to become + // flaky (see #146). To work around this, add a small sleep and check again: + if (!DefaultJDBCWrapper.tableExists(conn, tableName)) { + Thread.sleep(1000) + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + } + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala new file mode 100755 index 0000000000000..e3078b9457eaf --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * Integration tests for decimal support. For a reference on Redshift's DECIMAL type, see + * http://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html + */ +@ExtendedRedshiftTest +class DecimalIntegrationSuite extends IntegrationSuiteBase { + + private def testReadingDecimals(precision: Int, scale: Int, decimalStrings: Seq[String]): Unit = { + test(s"reading DECIMAL($precision, $scale)") { + val tableName = s"reading_decimal_${precision}_${scale}_$randomSuffix" + val expectedRows = decimalStrings.map { d => + if (d == null) { + Row(null) + } else { + Row(Conversions.createRedshiftDecimalFormat().parse(d).asInstanceOf[java.math.BigDecimal]) + } + } + try { + conn.createStatement().executeUpdate( + s"CREATE TABLE $tableName (x DECIMAL($precision, $scale))") + for (x <- decimalStrings) { + conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES ($x)") + } + conn.commit() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + val loadedDf = read.option("dbtable", tableName).load() + checkAnswer(loadedDf, expectedRows) + checkAnswer(loadedDf.selectExpr("x + 0"), expectedRows) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + } + + testReadingDecimals(19, 0, Seq( + // Max and min values of DECIMAL(19, 0) column according to Redshift docs: + "9223372036854775807", // 2^63 - 1 + "-9223372036854775807", + "0", + "12345678910", + null + )) + + testReadingDecimals(19, 4, Seq( + "922337203685477.5807", + "-922337203685477.5807", + "0", + "1234567.8910", + null + )) + + testReadingDecimals(38, 4, Seq( + "922337203685477.5808", + "9999999999999999999999999999999999.0000", + "-9999999999999999999999999999999999.0000", + "0", + "1234567.8910", + null + )) + + test("Decimal precision is preserved when reading from query (regression test for issue #203)") { + withTempRedshiftTable("issue203") { tableName => + try { + conn.createStatement().executeUpdate(s"CREATE TABLE $tableName (foo BIGINT)") + conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES (91593373)") + conn.commit() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + val df = read + .option("query", s"select foo / 1000000.0 from $tableName limit 1") + .load() + val res: Double = df.collect().toSeq.head.getDecimal(0).doubleValue() + assert(res === (91593373L / 1000000.0) +- 0.01) + assert(df.schema.fields.head.dataType === DecimalType(28, 8)) + } + } + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IAMIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IAMIntegrationSuite.scala new file mode 100755 index 0000000000000..c8c298166f5ef --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IAMIntegrationSuite.scala @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.SQLException + +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * Integration tests for configuring Redshift to access S3 using Amazon IAM roles. + */ +@ExtendedRedshiftTest +class IAMIntegrationSuite extends IntegrationSuiteBase { + + private val IAM_ROLE_ARN: String = loadConfigFromEnv("STS_ROLE_ARN") + + test("roundtrip save and load") { + val tableName = s"iam_roundtrip_save_and_load$randomSuffix" + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("a", IntegerType) :: Nil)) + try { + write(df) + .option("dbtable", tableName) + .option("forward_spark_s3_credentials", "false") + .option("aws_iam_role", IAM_ROLE_ARN) + .mode(SaveMode.ErrorIfExists) + .save() + + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + val loadedDf = read + .option("dbtable", tableName) + .option("forward_spark_s3_credentials", "false") + .option("aws_iam_role", IAM_ROLE_ARN) + .load() + assert(loadedDf.schema.length === 1) + assert(loadedDf.columns === Seq("a")) + checkAnswer(loadedDf, Seq(Row(1))) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + test("load fails if IAM role cannot be assumed") { + val tableName = s"iam_load_fails_if_role_cannot_be_assumed$randomSuffix" + try { + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("a", IntegerType) :: Nil)) + val err = intercept[SQLException] { + write(df) + .option("dbtable", tableName) + .option("forward_spark_s3_credentials", "false") + .option("aws_iam_role", IAM_ROLE_ARN + "-some-bogus-suffix") + .mode(SaveMode.ErrorIfExists) + .save() + } + assert(err.getCause.getMessage.contains("is not authorized to assume IAM Role")) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala new file mode 100644 index 0000000000000..6c4026cab2935 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala @@ -0,0 +1,217 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.net.URI +import java.sql.Connection + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.s3native.NativeS3FileSystem +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHiveContext +import org.apache.spark.sql.types.StructType + +/** + * Base class for writing integration tests which run against a real Redshift cluster. + */ +trait IntegrationSuiteBase + extends QueryTest + with Matchers + with BeforeAndAfterEach { + + protected def loadConfigFromEnv(envVarName: String): String = { + Option(System.getenv(envVarName)).getOrElse { + fail(s"Must set $envVarName environment variable") + } + } + + // The following configurations must be set in order to run these tests. In Travis, these + // environment variables are set using Travis's encrypted environment variables feature: + // http://docs.travis-ci.com/user/environment-variables/#Encrypted-Variables + + // JDBC URL listed in the AWS console (should not contain username and password). + protected val AWS_REDSHIFT_JDBC_URL: String = loadConfigFromEnv("AWS_REDSHIFT_JDBC_URL") + protected val AWS_REDSHIFT_USER: String = loadConfigFromEnv("AWS_REDSHIFT_USER") + protected val AWS_REDSHIFT_PASSWORD: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD") + protected val AWS_ACCESS_KEY_ID: String = loadConfigFromEnv("TEST_AWS_ACCESS_KEY_ID") + protected val AWS_SECRET_ACCESS_KEY: String = loadConfigFromEnv("TEST_AWS_SECRET_ACCESS_KEY") + // Path to a directory in S3 (e.g. 's3n://bucket-name/path/to/scratch/space'). + protected val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE") + require(AWS_S3_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL") + + protected def jdbcUrl: String = { + s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD" + } + + /** + * Random suffix appended appended to table and directory names in order to avoid collisions + * between separate Travis builds. + */ + protected val randomSuffix: String = Math.abs(Random.nextLong()).toString + + protected val tempDir: String = AWS_S3_SCRATCH_SPACE + randomSuffix + "/" + + /** + * Spark Context with Hadoop file overridden to point at our local test data file for this suite, + * no-matter what temp directory was generated and requested. + */ + protected var sc: SparkContext = _ + protected var sqlContext: SQLContext = _ + protected var conn: Connection = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sc = new SparkContext("local", "RedshiftSourceSuite") + // Bypass Hadoop's FileSystem caching mechanism so that we don't cache the credentials: + sc.hadoopConfiguration.setBoolean("fs.s3.impl.disable.cache", true) + sc.hadoopConfiguration.setBoolean("fs.s3n.impl.disable.cache", true) + sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID) + sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY) + conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None) + } + + override def afterAll(): Unit = { + try { + val conf = new Configuration(false) + conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID) + conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY) + // Bypass Hadoop's FileSystem caching mechanism so that we don't cache the credentials: + conf.setBoolean("fs.s3.impl.disable.cache", true) + conf.setBoolean("fs.s3n.impl.disable.cache", true) + conf.set("fs.s3.impl", classOf[NativeS3FileSystem].getCanonicalName) + conf.set("fs.s3n.impl", classOf[NativeS3FileSystem].getCanonicalName) + val fs = FileSystem.get(URI.create(tempDir), conf) + fs.delete(new Path(tempDir), true) + fs.close() + } finally { + try { + conn.close() + } finally { + try { + sc.stop() + } finally { + super.afterAll() + } + } + } + } + + override protected def beforeEach(): Unit = { + super.beforeEach() + sqlContext = new TestHiveContext(sc, loadTestTables = false) + } + + /** + * Create a new DataFrameReader using common options for reading from Redshift. + */ + protected def read: DataFrameReader = { + sqlContext.read + .format("com.databricks.spark.redshift") + .option("url", jdbcUrl) + .option("tempdir", tempDir) + .option("forward_spark_s3_credentials", "true") + } + /** + * Create a new DataFrameWriter using common options for writing to Redshift. + */ + protected def write(df: DataFrame): DataFrameWriter[Row] = { + df.write + .format("com.databricks.spark.redshift") + .option("url", jdbcUrl) + .option("tempdir", tempDir) + .option("forward_spark_s3_credentials", "true") + } + + protected def createTestDataInRedshift(tableName: String): Unit = { + conn.createStatement().executeUpdate( + s""" + |create table $tableName ( + |testbyte int2, + |testbool boolean, + |testdate date, + |testdouble float8, + |testfloat float4, + |testint int4, + |testlong int8, + |testshort int2, + |teststring varchar(256), + |testtimestamp timestamp + |) + """.stripMargin + ) + // scalastyle:off + conn.createStatement().executeUpdate( + s""" + |insert into $tableName values + |(null, null, null, null, null, null, null, null, null, null), + |(0, null, '2015-07-03', 0.0, -1.0, 4141214, 1239012341823719, null, 'f', '2015-07-03 00:00:00.000'), + |(0, false, null, -1234152.12312498, 100000.0, null, 1239012341823719, 24, '___|_123', null), + |(1, false, '2015-07-02', 0.0, 0.0, 42, 1239012341823719, -13, 'asdf', '2015-07-02 00:00:00.000'), + |(1, true, '2015-07-01', 1234152.12312498, 1.0, 42, 1239012341823719, 23, 'Unicode''s樂趣', '2015-07-01 00:00:00.001') + """.stripMargin + ) + // scalastyle:on + conn.commit() + } + + protected def withTempRedshiftTable[T](namePrefix: String)(body: String => T): T = { + val tableName = s"$namePrefix$randomSuffix" + try { + body(tableName) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + /** + * Save the given DataFrame to Redshift, then load the results back into a DataFrame and check + * that the returned DataFrame matches the one that we saved. + * + * @param tableName the table name to use + * @param df the DataFrame to save + * @param expectedSchemaAfterLoad if specified, the expected schema after loading the data back + * from Redshift. This should be used in cases where you expect + * the schema to differ due to reasons like case-sensitivity. + * @param saveMode the [[SaveMode]] to use when writing data back to Redshift + */ + def testRoundtripSaveAndLoad( + tableName: String, + df: DataFrame, + expectedSchemaAfterLoad: Option[StructType] = None, + saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = { + try { + write(df) + .option("dbtable", tableName) + .mode(saveMode) + .save() + // Check that the table exists. It appears that creating a table in one connection then + // immediately querying for existence from another connection may result in spurious "table + // doesn't exist" errors; this caused the "save with all empty partitions" test to become + // flaky (see #146). To work around this, add a small sleep and check again: + if (!DefaultJDBCWrapper.tableExists(conn, tableName)) { + Thread.sleep(1000) + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + } + val loadedDf = read.option("dbtable", tableName).load() + assert(loadedDf.schema === expectedSchemaAfterLoad.getOrElse(df.schema)) + checkAnswer(loadedDf, df.collect()) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala new file mode 100755 index 0000000000000..3513ec48d56f7 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * Basic integration tests with the Postgres JDBC driver. + */ +@ExtendedRedshiftTest +class PostgresDriverIntegrationSuite extends IntegrationSuiteBase { + + override def jdbcUrl: String = { + super.jdbcUrl.replace("jdbc:redshift", "jdbc:postgresql") + } + + test("postgresql driver takes precedence for jdbc:postgresql:// URIs") { + val conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None) + try { + // TODO(josh): this is slightly different than what was done in open-source spark-redshift. + // This is due to conflicting PG driver being pulled in via transitive Spark test deps. + // We should consider removing the postgres driver support entirely in our Databricks internal + // version. + assert(conn.getClass.getName === "org.postgresql.jdbc.PgConnection") + } finally { + conn.close() + } + } + + test("roundtrip save and load") { + conn.setAutoCommit(false) // TODO(josh): Hack needed due to different PG driver version + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), + StructType(StructField("foo", IntegerType) :: Nil)) + testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df) + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala new file mode 100755 index 0000000000000..cde68da81d1a7 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * This suite performs basic integration tests where the Redshift credentials have been + * specified via `spark-redshift`'s configuration rather than as part of the JDBC URL. + */ +@ExtendedRedshiftTest +class RedshiftCredentialsInConfIntegrationSuite extends IntegrationSuiteBase { + + test("roundtrip save and load") { + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), + StructType(StructField("foo", IntegerType) :: Nil)) + val tableName = s"roundtrip_save_and_load_$randomSuffix" + try { + write(df) + .option("url", AWS_REDSHIFT_JDBC_URL) + .option("user", AWS_REDSHIFT_USER) + .option("password", AWS_REDSHIFT_PASSWORD) + .option("dbtable", tableName) + .save() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + val loadedDf = read + .option("url", AWS_REDSHIFT_JDBC_URL) + .option("user", AWS_REDSHIFT_USER) + .option("password", AWS_REDSHIFT_PASSWORD) + .option("dbtable", tableName) + .load() + assert(loadedDf.schema === df.schema) + checkAnswer(loadedDf, df.collect()) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftReadSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftReadSuite.scala new file mode 100755 index 0000000000000..dfd98281043f5 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftReadSuite.scala @@ -0,0 +1,243 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.types.LongType +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * End-to-end tests of functionality which only impacts the read path (e.g. filter pushdown). + */ +@ExtendedRedshiftTest +class RedshiftReadSuite extends IntegrationSuiteBase { + + private val test_table: String = s"read_suite_test_table_$randomSuffix" + + override def beforeAll(): Unit = { + super.beforeAll() + conn.prepareStatement(s"drop table if exists $test_table").executeUpdate() + conn.commit() + createTestDataInRedshift(test_table) + } + + override def afterAll(): Unit = { + try { + conn.prepareStatement(s"drop table if exists $test_table").executeUpdate() + conn.commit() + } finally { + super.afterAll() + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + read.option("dbtable", test_table).load().createOrReplaceTempView("test_table") + } + + test("DefaultSource can load Redshift UNLOAD output to a DataFrame") { + checkAnswer( + sqlContext.sql("select * from test_table"), + TestUtils.expectedData) + } + + test("count() on DataFrame created from a Redshift table") { + checkAnswer( + sqlContext.sql("select count(*) from test_table"), + Seq(Row(TestUtils.expectedData.length)) + ) + } + + test("count() on DataFrame created from a Redshift query") { + val loadedDf = + // scalastyle:off + read.option("query", s"select * from $test_table where teststring = 'Unicode''s樂趣'").load() + // scalastyle:on + checkAnswer( + loadedDf.selectExpr("count(*)"), + Seq(Row(1)) + ) + } + + test("backslashes in queries/subqueries are escaped (regression test for #215)") { + val loadedDf = + read.option("query", s"select replace(teststring, '\\\\', '') as col from $test_table").load() + checkAnswer( + loadedDf.filter("col = 'asdf'"), + Seq(Row("asdf")) + ) + } + + test("Can load output when 'dbtable' is a subquery wrapped in parentheses") { + // scalastyle:off + val query = + s""" + |(select testbyte, testbool + |from $test_table + |where testbool = true + | and teststring = 'Unicode''s樂趣' + | and testdouble = 1234152.12312498 + | and testfloat = 1.0 + | and testint = 42) + """.stripMargin + // scalastyle:on + checkAnswer(read.option("dbtable", query).load(), Seq(Row(1, true))) + } + + test("Can load output when 'query' is specified instead of 'dbtable'") { + // scalastyle:off + val query = + s""" + |select testbyte, testbool + |from $test_table + |where testbool = true + | and teststring = 'Unicode''s樂趣' + | and testdouble = 1234152.12312498 + | and testfloat = 1.0 + | and testint = 42 + """.stripMargin + // scalastyle:on + checkAnswer(read.option("query", query).load(), Seq(Row(1, true))) + } + + test("Can load output of Redshift aggregation queries") { + checkAnswer( + read.option("query", s"select testbool, count(*) from $test_table group by testbool").load(), + Seq(Row(true, 1), Row(false, 2), Row(null, 2))) + } + + test("multiple scans on same table") { + // .rdd() forces the first query to be unloaded from Redshift + val rdd1 = sqlContext.sql("select testint from test_table").rdd + // Similarly, this also forces an unload: + sqlContext.sql("select testdouble from test_table").rdd + // If the unloads were performed into the same directory then this call would fail: the + // second unload from rdd2 would have overwritten the integers with doubles, so we'd get + // a NumberFormatException. + rdd1.count() + } + + test("DefaultSource supports simple column filtering") { + checkAnswer( + sqlContext.sql("select testbyte, testbool from test_table"), + Seq( + Row(null, null), + Row(0.toByte, null), + Row(0.toByte, false), + Row(1.toByte, false), + Row(1.toByte, true))) + } + + test("query with pruned and filtered scans") { + // scalastyle:off + checkAnswer( + sqlContext.sql( + """ + |select testbyte, testbool + |from test_table + |where testbool = true + | and teststring = "Unicode's樂趣" + | and testdouble = 1234152.12312498 + | and testfloat = 1.0 + | and testint = 42 + """.stripMargin), + Seq(Row(1, true))) + // scalastyle:on + } + + test("RedshiftRelation implements Spark 1.6+'s unhandledFilters API") { + assume(org.apache.spark.SPARK_VERSION.take(3) >= "1.6") + val df = sqlContext.sql("select testbool from test_table where testbool = true") + val physicalPlan = df.queryExecution.sparkPlan + physicalPlan.collectFirst { case f: execution.FilterExec => f }.foreach { filter => + fail(s"Filter should have been eliminated:\n${df.queryExecution}") + } + } + + test("filtering based on date constants (regression test for #152)") { + val date = TestUtils.toDate(year = 2015, zeroBasedMonth = 6, date = 3) + val df = sqlContext.sql("select testdate from test_table") + + checkAnswer(df.filter(df("testdate") === date), Seq(Row(date))) + // This query failed in Spark 1.6.0 but not in earlier versions. It looks like 1.6.0 performs + // constant-folding, whereas earlier Spark versions would preserve the cast which prevented + // filter pushdown. + checkAnswer(df.filter("testdate = to_date('2015-07-03')"), Seq(Row(date))) + } + + test("filtering based on timestamp constants (regression test for #152)") { + val timestamp = TestUtils.toTimestamp(2015, zeroBasedMonth = 6, 1, 0, 0, 0, 1) + val df = sqlContext.sql("select testtimestamp from test_table") + + checkAnswer(df.filter(df("testtimestamp") === timestamp), Seq(Row(timestamp))) + // This query failed in Spark 1.6.0 but not in earlier versions. It looks like 1.6.0 performs + // constant-folding, whereas earlier Spark versions would preserve the cast which prevented + // filter pushdown. + checkAnswer(df.filter("testtimestamp = '2015-07-01 00:00:00.001'"), Seq(Row(timestamp))) + } + + test("read special float values (regression test for #261)") { + val tableName = s"roundtrip_special_float_values_$randomSuffix" + try { + conn.createStatement().executeUpdate( + s"CREATE TABLE $tableName (x real)") + conn.createStatement().executeUpdate( + s"INSERT INTO $tableName VALUES ('NaN'), ('Infinity'), ('-Infinity')") + conn.commit() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + // Due to #98, we use Double here instead of float: + checkAnswer( + read.option("dbtable", tableName).load(), + Seq(Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity).map(x => Row.apply(x))) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + test("read special double values (regression test for #261)") { + val tableName = s"roundtrip_special_double_values_$randomSuffix" + try { + conn.createStatement().executeUpdate( + s"CREATE TABLE $tableName (x double precision)") + conn.createStatement().executeUpdate( + s"INSERT INTO $tableName VALUES ('NaN'), ('Infinity'), ('-Infinity')") + conn.commit() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + checkAnswer( + read.option("dbtable", tableName).load(), + Seq(Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity).map(x => Row.apply(x))) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + test("read records containing escaped characters") { + withTempRedshiftTable("records_with_escaped_characters") { tableName => + conn.createStatement().executeUpdate( + s"CREATE TABLE $tableName (x text)") + conn.createStatement().executeUpdate( + s"""INSERT INTO $tableName VALUES ('a\\nb'), ('\\\\'), ('"')""") + conn.commit() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + checkAnswer( + read.option("dbtable", tableName).load(), + Seq("a\nb", "\\", "\"").map(x => Row.apply(x))) + } + } + + test("read result of approximate count(distinct) query (#300)") { + val df = read + .option("query", s"select approximate count(distinct testbool) as c from $test_table") + .load() + assert(df.schema.fields(0).dataType === LongType) + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftWriteSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftWriteSuite.scala new file mode 100755 index 0000000000000..a00fcf4d13823 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/RedshiftWriteSuite.scala @@ -0,0 +1,164 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.SQLException + +import org.apache.spark.sql._ +import org.apache.spark.sql.types._ +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * End-to-end tests of functionality which involves writing to Redshift via the connector. + */ +abstract class BaseRedshiftWriteSuite extends IntegrationSuiteBase { + + protected val tempformat: String + + override protected def write(df: DataFrame): DataFrameWriter[Row] = + super.write(df).option("tempformat", tempformat) + + test("roundtrip save and load") { + // This test can be simplified once #98 is fixed. + val tableName = s"roundtrip_save_and_load_$randomSuffix" + try { + write( + sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema)) + .option("dbtable", tableName) + .mode(SaveMode.ErrorIfExists) + .save() + + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + test("roundtrip save and load with uppercase column names") { + testRoundtripSaveAndLoad( + s"roundtrip_write_and_read_with_uppercase_column_names_$randomSuffix", + sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("A", IntegerType) :: Nil)), + expectedSchemaAfterLoad = Some(StructType(StructField("a", IntegerType) :: Nil))) + } + + test("save with column names that are reserved words") { + testRoundtripSaveAndLoad( + s"save_with_column_names_that_are_reserved_words_$randomSuffix", + sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("table", IntegerType) :: Nil))) + } + + test("save with one empty partition (regression test for #96)") { + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 2), + StructType(StructField("foo", IntegerType) :: Nil)) + assert(df.rdd.glom.collect() === Array(Array.empty[Row], Array(Row(1)))) + testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df) + } + + test("save with all empty partitions (regression test for #96)") { + val df = sqlContext.createDataFrame(sc.parallelize(Seq.empty[Row], 2), + StructType(StructField("foo", IntegerType) :: Nil)) + assert(df.rdd.glom.collect() === Array(Array.empty[Row], Array.empty[Row])) + testRoundtripSaveAndLoad(s"save_with_all_empty_partitions_$randomSuffix", df) + // Now try overwriting that table. Although the new table is empty, it should still overwrite + // the existing table. + val df2 = df.withColumnRenamed("foo", "bar") + testRoundtripSaveAndLoad( + s"save_with_all_empty_partitions_$randomSuffix", df2, saveMode = SaveMode.Overwrite) + } + + test("informative error message when saving a table with string that is longer than max length") { + val tableName = s"error_message_when_string_too_long_$randomSuffix" + try { + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 512))), + StructType(StructField("A", StringType) :: Nil)) + val e = intercept[SQLException] { + write(df) + .option("dbtable", tableName) + .mode(SaveMode.ErrorIfExists) + .save() + } + assert(e.getMessage.contains("while loading data into Redshift")) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + test("full timestamp precision is preserved in loads (regression test for #214)") { + val timestamps = Seq( + TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 1), + TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 10), + TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 100), + TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 1000)) + testRoundtripSaveAndLoad( + s"full_timestamp_precision_is_preserved$randomSuffix", + sqlContext.createDataFrame(sc.parallelize(timestamps.map(Row(_))), + StructType(StructField("ts", TimestampType) :: Nil)) + ) + } +} + +@ExtendedRedshiftTest +class AvroRedshiftWriteSuite extends BaseRedshiftWriteSuite { + override protected val tempformat: String = "AVRO" + + test("informative error message when saving with column names that contain spaces (#84)") { + intercept[IllegalArgumentException] { + testRoundtripSaveAndLoad( + s"error_when_saving_column_name_with_spaces_$randomSuffix", + sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("column name with spaces", IntegerType) :: Nil))) + } + } +} + +@ExtendedRedshiftTest +class CSVRedshiftWriteSuite extends BaseRedshiftWriteSuite { + override protected val tempformat: String = "CSV" + + test("save with column names that contain spaces (#84)") { + testRoundtripSaveAndLoad( + s"save_with_column_names_that_contain_spaces_$randomSuffix", + sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("column name with spaces", IntegerType) :: Nil))) + } +} + +@ExtendedRedshiftTest +class CSVGZIPRedshiftWriteSuite extends IntegrationSuiteBase { + // Note: we purposely don't inherit from BaseRedshiftWriteSuite because we're only interested in + // testing basic functionality of the GZIP code; the rest of the write path should be unaffected + // by compression here. + + override protected def write(df: DataFrame): DataFrameWriter[Row] = + super.write(df).option("tempformat", "CSV GZIP") + + test("roundtrip save and load") { + // This test can be simplified once #98 is fixed. + val tableName = s"roundtrip_save_and_load_$randomSuffix" + try { + write( + sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema)) + .option("dbtable", tableName) + .mode(SaveMode.ErrorIfExists) + .save() + + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/STSIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/STSIntegrationSuite.scala new file mode 100755 index 0000000000000..c86081799e7ca --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/STSIntegrationSuite.scala @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import com.amazonaws.auth.BasicAWSCredentials +import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClient +import com.amazonaws.services.securitytoken.model.AssumeRoleRequest + +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * Integration tests for accessing S3 using Amazon Security Token Service (STS) credentials. + */ +@ExtendedRedshiftTest +class STSIntegrationSuite extends IntegrationSuiteBase { + + private val STS_ROLE_ARN: String = loadConfigFromEnv("STS_ROLE_ARN") + private var STS_ACCESS_KEY_ID: String = _ + private var STS_SECRET_ACCESS_KEY: String = _ + private var STS_SESSION_TOKEN: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val awsCredentials = new BasicAWSCredentials(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) + val stsClient = new AWSSecurityTokenServiceClient(awsCredentials) + val assumeRoleRequest = new AssumeRoleRequest() + assumeRoleRequest.setDurationSeconds(900) // this is the minimum supported duration + assumeRoleRequest.setRoleArn(STS_ROLE_ARN) + assumeRoleRequest.setRoleSessionName(s"spark-$randomSuffix") + val creds = stsClient.assumeRole(assumeRoleRequest).getCredentials + STS_ACCESS_KEY_ID = creds.getAccessKeyId + STS_SECRET_ACCESS_KEY = creds.getSecretAccessKey + STS_SESSION_TOKEN = creds.getSessionToken + } + + test("roundtrip save and load") { + val tableName = s"roundtrip_save_and_load$randomSuffix" + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("a", IntegerType) :: Nil)) + try { + write(df) + .option("dbtable", tableName) + .option("forward_spark_s3_credentials", "false") + .option("temporary_aws_access_key_id", STS_ACCESS_KEY_ID) + .option("temporary_aws_secret_access_key", STS_SECRET_ACCESS_KEY) + .option("temporary_aws_session_token", STS_SESSION_TOKEN) + .mode(SaveMode.ErrorIfExists) + .save() + + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + val loadedDf = read + .option("dbtable", tableName) + .option("forward_spark_s3_credentials", "false") + .option("temporary_aws_access_key_id", STS_ACCESS_KEY_ID) + .option("temporary_aws_secret_access_key", STS_SECRET_ACCESS_KEY) + .option("temporary_aws_session_token", STS_SESSION_TOKEN) + .load() + assert(loadedDf.schema.length === 1) + assert(loadedDf.columns === Seq("a")) + checkAnswer(loadedDf, Seq(Row(1))) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } +} diff --git a/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/SaveModeIntegrationSuite.scala b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/SaveModeIntegrationSuite.scala new file mode 100755 index 0000000000000..0cdc8ba893171 --- /dev/null +++ b/external/redshift-integration-tests/src/test/scala/com/databricks/spark/redshift/SaveModeIntegrationSuite.scala @@ -0,0 +1,122 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.tags.ExtendedRedshiftTest + +/** + * End-to-end tests of [[SaveMode]] behavior. + */ +@ExtendedRedshiftTest +class SaveModeIntegrationSuite extends IntegrationSuiteBase { + test("SaveMode.Overwrite with schema-qualified table name (#97)") { + withTempRedshiftTable("overwrite_schema_qualified_table_name") { tableName => + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("a", IntegerType) :: Nil)) + // Ensure that the table exists: + write(df) + .option("dbtable", tableName) + .mode(SaveMode.ErrorIfExists) + .save() + assert(DefaultJDBCWrapper.tableExists(conn, s"PUBLIC.$tableName")) + // Try overwriting that table while using the schema-qualified table name: + write(df) + .option("dbtable", s"PUBLIC.$tableName") + .mode(SaveMode.Overwrite) + .save() + } + } + + test("SaveMode.Overwrite with non-existent table") { + testRoundtripSaveAndLoad( + s"overwrite_non_existent_table$randomSuffix", + sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("a", IntegerType) :: Nil)), + saveMode = SaveMode.Overwrite) + } + + test("SaveMode.Overwrite with existing table") { + withTempRedshiftTable("overwrite_existing_table") { tableName => + // Create a table to overwrite + write(sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("a", IntegerType) :: Nil))) + .option("dbtable", tableName) + .mode(SaveMode.ErrorIfExists) + .save() + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + + val overwritingDf = + sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema) + write(overwritingDf) + .option("dbtable", tableName) + .mode(SaveMode.Overwrite) + .save() + + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData) + } + } + + // TODO:test overwrite that fails. + + test("Append SaveMode doesn't destroy existing data") { + withTempRedshiftTable("append_doesnt_destroy_existing_data") { tableName => + createTestDataInRedshift(tableName) + val extraData = Seq( + Row(2.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L, + 24.toShort, "___|_123", null)) + + write(sqlContext.createDataFrame(sc.parallelize(extraData), TestUtils.testSchema)) + .option("dbtable", tableName) + .mode(SaveMode.Append) + .saveAsTable(tableName) + + checkAnswer( + sqlContext.sql(s"select * from $tableName"), + TestUtils.expectedData ++ extraData) + } + } + + test("Respect SaveMode.ErrorIfExists when table exists") { + withTempRedshiftTable("respect_savemode_error_if_exists") { tableName => + val rdd = sc.parallelize(TestUtils.expectedData) + val df = sqlContext.createDataFrame(rdd, TestUtils.testSchema) + createTestDataInRedshift(tableName) // to ensure that the table already exists + + // Check that SaveMode.ErrorIfExists throws an exception + val e = intercept[Exception] { + write(df) + .option("dbtable", tableName) + .mode(SaveMode.ErrorIfExists) + .saveAsTable(tableName) + } + assert(e.getMessage.contains("exists")) + } + } + + test("Do nothing when table exists if SaveMode = Ignore") { + withTempRedshiftTable("do_nothing_when_savemode_ignore") { tableName => + val rdd = sc.parallelize(TestUtils.expectedData.drop(1)) + val df = sqlContext.createDataFrame(rdd, TestUtils.testSchema) + createTestDataInRedshift(tableName) // to ensure that the table already exists + write(df) + .option("dbtable", tableName) + .mode(SaveMode.Ignore) + .saveAsTable(tableName) + + // Check that SaveMode.Ignore does nothing + checkAnswer( + sqlContext.sql(s"select * from $tableName"), + TestUtils.expectedData) + } + } +} diff --git a/external/redshift/pom.xml b/external/redshift/pom.xml new file mode 100644 index 0000000000000..6ea5e4318f5d3 --- /dev/null +++ b/external/redshift/pom.xml @@ -0,0 +1,145 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.1.0 + ../../pom.xml + + + com.databricks + spark-redshift_2.11 + + redshift + + jar + Spark Redshift + http://spark.apache.org/ + + + + hadoop-2.7 + + + org.apache.hadoop + hadoop-aws + ${hadoop.version} + test + + + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + com.databricks + spark-avro_${scala.binary.version} + ${project.version} + provided + + + + com.amazonaws + aws-java-sdk-core + 1.9.40 + provided + + + com.amazonaws + aws-java-sdk-s3 + 1.9.40 + provided + + + com.amazonaws + aws-java-sdk-sts + 1.9.40 + provided + + + com.eclipsesource.minimal-json + minimal-json + 0.9.4 + compile + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + test-jar + test + + + + com.amazonaws + redshift.jdbc4 + 1.1.7.1007 + jar + test + + + org.mockito + mockito-core + test + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/AWSCredentialsUtils.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/AWSCredentialsUtils.scala new file mode 100755 index 0000000000000..0fb6a307ee644 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/AWSCredentialsUtils.scala @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.net.URI + +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, AWSSessionCredentials, BasicAWSCredentials, DefaultAWSCredentialsProviderChain} +import com.databricks.spark.redshift.Parameters.MergedParameters +import org.apache.hadoop.conf.Configuration + +private[redshift] object AWSCredentialsUtils { + + /** + * Generates a credentials string for use in Redshift COPY and UNLOAD statements. + * Favors a configured `aws_iam_role` if available in the parameters. + */ + def getRedshiftCredentialsString( + params: MergedParameters, + sparkAwsCredentials: AWSCredentials): String = { + + def awsCredsToString(credentials: AWSCredentials): String = { + credentials match { + case creds: AWSSessionCredentials => + s"aws_access_key_id=${creds.getAWSAccessKeyId};" + + s"aws_secret_access_key=${creds.getAWSSecretKey};token=${creds.getSessionToken}" + case creds => + s"aws_access_key_id=${creds.getAWSAccessKeyId};" + + s"aws_secret_access_key=${creds.getAWSSecretKey}" + } + } + if (params.iamRole.isDefined) { + s"aws_iam_role=${params.iamRole.get}" + } else if (params.temporaryAWSCredentials.isDefined) { + awsCredsToString(params.temporaryAWSCredentials.get.getCredentials) + } else if (params.forwardSparkS3Credentials) { + awsCredsToString(sparkAwsCredentials) + } else { + throw new IllegalStateException("No Redshift S3 authentication mechanism was specified") + } + } + + def staticCredentialsProvider(credentials: AWSCredentials): AWSCredentialsProvider = { + new AWSCredentialsProvider { + override def getCredentials: AWSCredentials = credentials + override def refresh(): Unit = {} + } + } + + def load(params: MergedParameters, hadoopConfiguration: Configuration): AWSCredentialsProvider = { + params.temporaryAWSCredentials.getOrElse(loadFromURI(params.rootTempDir, hadoopConfiguration)) + } + + private def loadFromURI( + tempPath: String, + hadoopConfiguration: Configuration): AWSCredentialsProvider = { + // scalastyle:off + // A good reference on Hadoop's configuration loading / precedence is + // https://github.com/apache/hadoop/blob/trunk/hadoop-tools/hadoop-aws/src/site/markdown/tools/hadoop-aws/index.md + // scalastyle:on + val uri = new URI(tempPath) + val uriScheme = uri.getScheme + + uriScheme match { + case "s3" | "s3n" | "s3a" => + // This matches what S3A does, with one exception: we don't support anonymous credentials. + // First, try to parse from URI: + Option(uri.getUserInfo).flatMap { userInfo => + if (userInfo.contains(":")) { + val Array(accessKey, secretKey) = userInfo.split(":") + Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey))) + } else { + None + } + }.orElse { + // Next, try to read from configuration + val accessKeyConfig = if (uriScheme == "s3a") "access.key" else "awsAccessKeyId" + val secretKeyConfig = if (uriScheme == "s3a") "secret.key" else "awsSecretAccessKey" + + val accessKey = hadoopConfiguration.get(s"fs.$uriScheme.$accessKeyConfig", null) + val secretKey = hadoopConfiguration.get(s"fs.$uriScheme.$secretKeyConfig", null) + if (accessKey != null && secretKey != null) { + Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey))) + } else { + None + } + }.getOrElse { + // Finally, fall back on the instance profile provider + new DefaultAWSCredentialsProviderChain() + } + case other => + throw new IllegalArgumentException(s"Unrecognized scheme $other; expected s3, s3n, or s3a") + } + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/Conversions.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/Conversions.scala new file mode 100755 index 0000000000000..a83411db48c85 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/Conversions.scala @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.Timestamp +import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat} +import java.util.Locale + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.types._ + +/** + * Data type conversions for Redshift unloaded data + */ +private[redshift] object Conversions { + + /** + * Parse a boolean using Redshift's UNLOAD bool syntax + */ + private def parseBoolean(s: String): Boolean = { + if (s == "t") true + else if (s == "f") false + else throw new IllegalArgumentException(s"Expected 't' or 'f' but got '$s'") + } + + /** + * Formatter for writing decimals unloaded from Redshift. + * + * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this + * DecimalFormat across threads. + */ + def createRedshiftDecimalFormat(): DecimalFormat = { + val format = new DecimalFormat() + format.setParseBigDecimal(true) + format.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US)) + format + } + + /** + * Formatter for parsing strings exported from Redshift DATE columns. + * + * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this + * SimpleDateFormat across threads. + */ + def createRedshiftDateFormat(): SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd") + + /** + * Formatter for formatting timestamps for insertion into Redshift TIMESTAMP columns. + * + * This formatter should not be used to parse timestamps returned from Redshift UNLOAD commands; + * instead, use [[Timestamp.valueOf()]]. + * + * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this + * SimpleDateFormat across threads. + */ + def createRedshiftTimestampFormat(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS") + } + + /** + * Return a function that will convert arrays of strings conforming to the given schema to Rows. + * + * Note that instances of this function are NOT thread-safe. + */ + def createRowConverter(schema: StructType): Array[String] => InternalRow = { + val dateFormat = createRedshiftDateFormat() + val decimalFormat = createRedshiftDecimalFormat() + val conversionFunctions: Array[String => Any] = schema.fields.map { field => + field.dataType match { + case ByteType => (data: String) => java.lang.Byte.parseByte(data) + case BooleanType => (data: String) => parseBoolean(data) + case DateType => (data: String) => new java.sql.Date(dateFormat.parse(data).getTime) + case DoubleType => (data: String) => data match { + case "nan" => Double.NaN + case "inf" => Double.PositiveInfinity + case "-inf" => Double.NegativeInfinity + case _ => java.lang.Double.parseDouble(data) + } + case FloatType => (data: String) => data match { + case "nan" => Float.NaN + case "inf" => Float.PositiveInfinity + case "-inf" => Float.NegativeInfinity + case _ => java.lang.Float.parseFloat(data) + } + case dt: DecimalType => + (data: String) => decimalFormat.parse(data).asInstanceOf[java.math.BigDecimal] + case IntegerType => (data: String) => java.lang.Integer.parseInt(data) + case LongType => (data: String) => java.lang.Long.parseLong(data) + case ShortType => (data: String) => java.lang.Short.parseShort(data) + case StringType => (data: String) => data + case TimestampType => (data: String) => Timestamp.valueOf(data) + case _ => (data: String) => data + } + } + // As a performance optimization, re-use the same mutable row / array: + val converted: Array[Any] = Array.fill(schema.length)(null) + val externalRow = new GenericRow(converted) + val encoder = RowEncoder(schema) + (inputRow: Array[String]) => { + var i = 0 + while (i < schema.length) { + val data = inputRow(i) + converted(i) = if (data == null || data.isEmpty) null else conversionFunctions(i)(data) + i += 1 + } + encoder.toRow(externalRow) + } + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala new file mode 100755 index 0000000000000..258255676a860 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala @@ -0,0 +1,108 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import com.amazonaws.auth.AWSCredentialsProvider +import com.amazonaws.services.s3.AmazonS3Client +import org.slf4j.LoggerFactory + +import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider} +import org.apache.spark.sql.types.StructType + +/** + * Redshift Source implementation for Spark SQL + */ +class DefaultSource( + jdbcWrapper: JDBCWrapper, + s3ClientFactory: AWSCredentialsProvider => AmazonS3Client) + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider { + + private val log = LoggerFactory.getLogger(getClass) + + /** + * Default constructor required by Data Source API + */ + def this() = this(DefaultJDBCWrapper, awsCredentials => new AmazonS3Client(awsCredentials)) + + /** + * Create a new RedshiftRelation instance using parameters from Spark SQL DDL. Resolves the schema + * using JDBC connection over provided URL, which must contain credentials. + */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val params = Parameters.mergeParameters(parameters) + RedshiftRelation(jdbcWrapper, s3ClientFactory, params, None)(sqlContext) + } + + /** + * Load a RedshiftRelation using user-provided schema, so no inference over JDBC will be used. + */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + val params = Parameters.mergeParameters(parameters) + RedshiftRelation(jdbcWrapper, s3ClientFactory, params, Some(schema))(sqlContext) + } + + /** + * Creates a Relation instance by first writing the contents of the given DataFrame to Redshift + */ + override def createRelation( + sqlContext: SQLContext, + saveMode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val params = Parameters.mergeParameters(parameters) + val table = params.table.getOrElse { + throw new IllegalArgumentException( + "For save operations you must specify a Redshift table name with the 'dbtable' parameter") + } + + def tableExists: Boolean = { + val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) + try { + jdbcWrapper.tableExists(conn, table.toString) + } finally { + conn.close() + } + } + + val (doSave, dropExisting) = saveMode match { + case SaveMode.Append => (true, false) + case SaveMode.Overwrite => (true, true) + case SaveMode.ErrorIfExists => + if (tableExists) { + sys.error(s"Table $table already exists! (SaveMode is set to ErrorIfExists)") + } else { + (true, false) + } + case SaveMode.Ignore => + if (tableExists) { + log.info(s"Table $table already exists -- ignoring save request.") + (false, false) + } else { + (true, false) + } + } + + if (doSave) { + val updatedParams = parameters.updated("overwrite", dropExisting.toString) + new RedshiftWriter(jdbcWrapper, s3ClientFactory).saveToRedshift( + sqlContext, data, saveMode, Parameters.mergeParameters(updatedParams)) + } + + createRelation(sqlContext, parameters) + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/FilterPushdown.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/FilterPushdown.scala new file mode 100755 index 0000000000000..786b3f6257f6e --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/FilterPushdown.scala @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +/** + * Helper methods for pushing filters into Redshift queries. + */ +private[redshift] object FilterPushdown { + /** + * Build a SQL WHERE clause for the given filters. If a filter cannot be pushed down then no + * condition will be added to the WHERE clause. If none of the filters can be pushed down then + * an empty string will be returned. + * + * @param schema the schema of the table being queried + * @param filters an array of filters, the conjunction of which is the filter condition for the + * scan. + */ + def buildWhereClause(schema: StructType, filters: Seq[Filter]): String = { + val filterExpressions = filters.flatMap(f => buildFilterExpression(schema, f)).mkString(" AND ") + if (filterExpressions.isEmpty) "" else "WHERE " + filterExpressions + } + + /** + * Attempt to convert the given filter into a SQL expression. Returns None if the expression + * could not be converted. + */ + def buildFilterExpression(schema: StructType, filter: Filter): Option[String] = { + def buildComparison(attr: String, value: Any, comparisonOp: String): Option[String] = { + getTypeForAttribute(schema, attr).map { dataType => + val sqlEscapedValue: String = dataType match { + case StringType => s"\\'${value.toString.replace("'", "\\'\\'")}\\'" + case DateType => s"\\'${value.asInstanceOf[Date]}\\'" + case TimestampType => s"\\'${value.asInstanceOf[Timestamp]}\\'" + case _ => value.toString + } + s""""$attr" $comparisonOp $sqlEscapedValue""" + } + } + + filter match { + case EqualTo(attr, value) => buildComparison(attr, value, "=") + case LessThan(attr, value) => buildComparison(attr, value, "<") + case GreaterThan(attr, value) => buildComparison(attr, value, ">") + case LessThanOrEqual(attr, value) => buildComparison(attr, value, "<=") + case GreaterThanOrEqual(attr, value) => buildComparison(attr, value, ">=") + case IsNotNull(attr) => + getTypeForAttribute(schema, attr).map(dataType => s""""$attr" IS NOT NULL""") + case IsNull(attr) => + getTypeForAttribute(schema, attr).map(dataType => s""""$attr" IS NULL""") + case _ => None + } + } + + /** + * Use the given schema to look up the attribute's data type. Returns None if the attribute could + * not be resolved. + */ + private def getTypeForAttribute(schema: StructType, attribute: String): Option[DataType] = { + if (schema.fieldNames.contains(attribute)) { + Some(schema(attribute).dataType) + } else { + None + } + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/Parameters.scala new file mode 100755 index 0000000000000..8789b5c8ae2d6 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/Parameters.scala @@ -0,0 +1,282 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import com.amazonaws.auth.{AWSCredentialsProvider, BasicSessionCredentials} + +/** + * All user-specifiable parameters for spark-redshift, along with their validation rules and + * defaults. + */ +private[redshift] object Parameters { + + val DEFAULT_PARAMETERS: Map[String, String] = Map( + // Notes: + // * tempdir, dbtable and url have no default and they *must* be provided + // * sortkeyspec has no default, but is optional + // * distkey has no default, but is optional unless using diststyle KEY + // * jdbcdriver has no default, but is optional + + "forward_spark_s3_credentials" -> "false", + "tempformat" -> "AVRO", + "csvnullstring" -> "@NULL@", + "overwrite" -> "false", + "diststyle" -> "EVEN", + "usestagingtable" -> "true", + "preactions" -> ";", + "postactions" -> ";" + ) + + val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP") + + /** + * Merge user parameters with the defaults, preferring user parameters if specified + */ + def mergeParameters(userParameters: Map[String, String]): MergedParameters = { + if (!userParameters.contains("tempdir")) { + throw new IllegalArgumentException("'tempdir' is required for all Redshift loads and saves") + } + if (userParameters.contains("tempformat") && + !VALID_TEMP_FORMATS.contains(userParameters("tempformat").toUpperCase)) { + throw new IllegalArgumentException( + s"""Invalid temp format: ${userParameters("tempformat")}; """ + + s"valid formats are: ${VALID_TEMP_FORMATS.mkString(", ")}") + } + if (!userParameters.contains("url")) { + throw new IllegalArgumentException("A JDBC URL must be provided with 'url' parameter") + } + if (!userParameters.contains("dbtable") && !userParameters.contains("query")) { + throw new IllegalArgumentException( + "You must specify a Redshift table name with the 'dbtable' parameter or a query with the " + + "'query' parameter.") + } + if (userParameters.contains("dbtable") && userParameters.contains("query")) { + throw new IllegalArgumentException( + "You cannot specify both the 'dbtable' and 'query' parameters at the same time.") + } + val credsInURL = userParameters.get("url") + .filter(url => url.contains("user=") || url.contains("password=")) + if (userParameters.contains("user") || userParameters.contains("password")) { + if (credsInURL.isDefined) { + throw new IllegalArgumentException( + "You cannot specify credentials in both the URL and as user/password options") + } + } else if (credsInURL.isEmpty) { + throw new IllegalArgumentException( + "You must specify credentials in either the URL or as user/password options") + } + + MergedParameters(DEFAULT_PARAMETERS ++ userParameters) + } + + /** + * Adds validators and accessors to string map + */ + case class MergedParameters(parameters: Map[String, String]) { + + require(temporaryAWSCredentials.isDefined || iamRole.isDefined || forwardSparkS3Credentials, + "You must specify a method for authenticating Redshift's connection to S3 (aws_iam_role," + + " forward_spark_s3_credentials, or temporary_aws_*. For a discussion of the differences" + + " between these options, please see the README.") + + require(Seq( + temporaryAWSCredentials.isDefined, + iamRole.isDefined, + forwardSparkS3Credentials).count(_ == true) == 1, + "The aws_iam_role, forward_spark_s3_credentials, and temporary_aws_*. options are " + + "mutually-exclusive; please specify only one.") + + /** + * A root directory to be used for intermediate data exchange, expected to be on S3, or + * somewhere that can be written to and read from by Redshift. Make sure that AWS credentials + * are available for S3. + */ + def rootTempDir: String = parameters("tempdir") + + /** + * The format in which to save temporary files in S3. Defaults to "AVRO"; the other allowed + * values are "CSV" and "CSV GZIP" for CSV and gzipped CSV, respectively. + */ + def tempFormat: String = parameters("tempformat").toUpperCase + + /** + * The String value to write for nulls when using CSV. + * This should be a value which does not appear in your actual data. + */ + def nullString: String = parameters("csvnullstring") + + /** + * Creates a per-query subdirectory in the [[rootTempDir]], with a random UUID. + */ + def createPerQueryTempDir(): String = Utils.makeTempPath(rootTempDir) + + /** + * The Redshift table to be used as the target when loading or writing data. + */ + def table: Option[TableName] = parameters.get("dbtable").map(_.trim).flatMap { dbtable => + // We technically allow queries to be passed using `dbtable` as long as they are wrapped + // in parentheses. Valid SQL identifiers may contain parentheses but cannot begin with them, + // so there is no ambiguity in ignoring subqeries here and leaving their handling up to + // the `query` function defined below. + if (dbtable.startsWith("(") && dbtable.endsWith(")")) { + None + } else { + Some(TableName.parseFromEscaped(dbtable)) + } + } + + /** + * The Redshift query to be used as the target when loading data. + */ + def query: Option[String] = parameters.get("query").orElse { + parameters.get("dbtable") + .map(_.trim) + .filter(t => t.startsWith("(") && t.endsWith(")")) + .map(t => t.drop(1).dropRight(1)) + } + + /** + * User and password to be used to authenticate to Redshift + */ + def credentials: Option[(String, String)] = { + for ( + user <- parameters.get("user"); + password <- parameters.get("password") + ) yield (user, password) + } + + /** + * A JDBC URL, of the format: + * + * jdbc:subprotocol://host:port/database?user=username&password=password + * + * Where: + * - subprotocol can be postgresql or redshift, depending on which JDBC driver you have loaded. + * Note however that one Redshift-compatible driver must be on the classpath and match this + * URL. + * - host and port should point to the Redshift master node, so security groups and/or VPC will + * need to be configured to allow access from the Spark driver + * - database identifies a Redshift database name + * - user and password are credentials to access the database, which must be embedded in this + * URL for JDBC + */ + def jdbcUrl: String = parameters("url") + + /** + * The JDBC driver class name. This is used to make sure the driver is registered before + * connecting over JDBC. + */ + def jdbcDriver: Option[String] = parameters.get("jdbcdriver") + + /** + * Set the Redshift table distribution style, which can be one of: EVEN, KEY or ALL. If you set + * it to KEY, you'll also need to use the distkey parameter to set the distribution key. + * + * Default is EVEN. + */ + def distStyle: Option[String] = parameters.get("diststyle") + + /** + * The name of a column in the table to use as the distribution key when using DISTSTYLE KEY. + * Not set by default, as default DISTSTYLE is EVEN. + */ + def distKey: Option[String] = parameters.get("distkey") + + /** + * A full Redshift SORTKEY specification. For full information, see latest Redshift docs: + * http://docs.aws.amazon.com/redshift/latest/dg/r_CREATE_TABLE_NEW.html + * + * Examples: + * SORTKEY (my_sort_column) + * COMPOUND SORTKEY (sort_col1, sort_col2) + * INTERLEAVED SORTKEY (sort_col1, sort_col2) + * + * Not set by default - table will be unsorted. + * + * Note: appending data to a table with a sort key only makes sense if you know that the data + * being added will be after the data already in the table according to the sort order. Redshift + * does not support random inserts according to sort order, so performance will degrade if you + * try this. + */ + def sortKeySpec: Option[String] = parameters.get("sortkeyspec") + + /** + * DEPRECATED: see PR #157. + * + * When true, data is always loaded into a new temporary table when performing an overwrite. + * This is to ensure that the whole load process succeeds before dropping any data from + * Redshift, which can be useful if, in the event of failures, stale data is better than no data + * for your systems. + * + * Defaults to true. + */ + def useStagingTable: Boolean = parameters("usestagingtable").toBoolean + + /** + * Extra options to append to the Redshift COPY command (e.g. "MAXERROR 100"). + */ + def extraCopyOptions: String = parameters.get("extracopyoptions").getOrElse("") + + /** + * Description of the table, set using the SQL COMMENT command. + */ + def description: Option[String] = parameters.get("description") + + /** + * List of semi-colon separated SQL statements to run before write operations. + * This can be useful for running DELETE operations to clean up data + * + * If the action string contains %s, the table name will be substituted in, in case a staging + * table is being used. + * + * Defaults to empty. + */ + def preActions: Array[String] = parameters("preactions").split(";") + + /** + * List of semi-colon separated SQL statements to run after successful write operations. + * This can be useful for running GRANT operations to make your new tables readable to other + * users and groups. + * + * If the action string contains %s, the table name will be substituted in, in case a staging + * table is being used. + * + * Defaults to empty. + */ + def postActions: Array[String] = parameters("postactions").split(";") + + /** + * The IAM role that Redshift should assume for COPY/UNLOAD operations. + */ + def iamRole: Option[String] = parameters.get("aws_iam_role") + + /** + * If true then this library will automatically discover the credentials that Spark is + * using to connect to S3 and will forward those credentials to Redshift over JDBC. + */ + def forwardSparkS3Credentials: Boolean = parameters("forward_spark_s3_credentials").toBoolean + + /** + * Temporary AWS credentials which are passed to Redshift. These only need to be supplied by + * the user when Hadoop is configured to authenticate to S3 via IAM roles assigned to EC2 + * instances. + */ + def temporaryAWSCredentials: Option[AWSCredentialsProvider] = { + for ( + accessKey <- parameters.get("temporary_aws_access_key_id"); + secretAccessKey <- parameters.get("temporary_aws_secret_access_key"); + sessionToken <- parameters.get("temporary_aws_session_token") + ) yield { + AWSCredentialsUtils.staticCredentialsProvider( + new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken)) + } + } + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RecordReaderIterator.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RecordReaderIterator.scala new file mode 100755 index 0000000000000..acfb9140fde53 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RecordReaderIterator.scala @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.io.Closeable + +import org.apache.hadoop.mapreduce.RecordReader + +/** + * An adaptor from a Hadoop [[RecordReader]] to an [[Iterator]] over the values returned. + * + * This is copied from Apache Spark and is inlined here to avoid depending on Spark internals + * in this external library. + */ +private[redshift] class RecordReaderIterator[T]( + private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable { + private[this] var havePair = false + private[this] var finished = false + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !rowReader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } + havePair = !finished + } + !finished + } + + override def next(): T = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + rowReader.getCurrentValue + } + + override def close(): Unit = { + if (rowReader != null) { + try { + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues + // when reading compressed input. + rowReader.close() + } finally { + rowReader = null + } + } + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftFileFormat.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftFileFormat.scala new file mode 100755 index 0000000000000..a45daadcf01f2 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftFileFormat.scala @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +/** + * Internal data source used for reading Redshift UNLOAD files. + * + * This is not intended for public consumption / use outside of this package and therefore + * no API stability is guaranteed. + */ +private[redshift] class RedshiftFileFormat extends FileFormat { + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + // Schema is provided by caller. + None + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new UnsupportedOperationException(s"prepareWrite is not supported for $this") + } + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + // Our custom InputFormat handles split records properly + true + } + + override def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + + require(partitionSchema.isEmpty) + require(filters.isEmpty) + require(dataSchema == requiredSchema) + + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + + val fileSplit = new FileSplit( + new Path(new URI(file.filePath)), + file.start, + file.length, + // TODO: Implement Locality + Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val reader = new RedshiftRecordReader + reader.initialize(fileSplit, hadoopAttemptContext) + val iter = new RecordReaderIterator[Array[String]](reader) + // Ensure that the record reader is closed upon task completion. It will ordinarily + // be closed once it is completely iterated, but this is necessary to guard against + // resource leaks in case the task fails or is interrupted. + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + val converter = Conversions.createRowConverter(requiredSchema) + iter.map(converter) + } + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftInputFormat.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftInputFormat.scala new file mode 100755 index 0000000000000..94829c028e520 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftInputFormat.scala @@ -0,0 +1,252 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.io.{BufferedInputStream, IOException} +import java.lang.{Long => JavaLong} +import java.nio.charset.Charset + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.io.compress.CompressionCodecFactory +import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} + +/** + * Input format for text records saved with in-record delimiter and newline characters escaped. + * + * For example, a record containing two fields: `"a\n"` and `"|b\\"` saved with delimiter `|` + * should be the following: + * {{{ + * a\\\n|\\|b\\\\\n + * }}}, + * where the in-record `|`, `\r`, `\n`, and `\\` characters are escaped by `\\`. + * Users can configure the delimiter via [[RedshiftInputFormat$#KEY_DELIMITER]]. + * Its default value [[RedshiftInputFormat$#DEFAULT_DELIMITER]] is set to match Redshift's UNLOAD + * with the ESCAPE option: + * {{{ + * UNLOAD ('select_statement') + * TO 's3://object_path_prefix' + * ESCAPE + * }}} + * + * @see org.apache.spark.SparkContext#newAPIHadoopFile + */ +class RedshiftInputFormat extends FileInputFormat[JavaLong, Array[String]] { + + override def createRecordReader( + split: InputSplit, + context: TaskAttemptContext): RecordReader[JavaLong, Array[String]] = { + new RedshiftRecordReader + } +} + +object RedshiftInputFormat { + + /** configuration key for delimiter */ + val KEY_DELIMITER = "redshift.delimiter" + /** default delimiter */ + val DEFAULT_DELIMITER = '|' + + /** Gets the delimiter char from conf or the default. */ + private[redshift] def getDelimiterOrDefault(conf: Configuration): Char = { + val c = conf.get(KEY_DELIMITER, DEFAULT_DELIMITER.toString) + if (c.length != 1) { + throw new IllegalArgumentException(s"Expect delimiter be a single character but got '$c'.") + } else { + c.charAt(0) + } + } +} + +private[redshift] class RedshiftRecordReader extends RecordReader[JavaLong, Array[String]] { + + private var reader: BufferedInputStream = _ + + private var key: JavaLong = _ + private var value: Array[String] = _ + + private var start: Long = _ + private var end: Long = _ + private var cur: Long = _ + + private var eof: Boolean = false + + private var delimiter: Byte = _ + @inline private[this] final val escapeChar: Byte = '\\' + @inline private[this] final val lineFeed: Byte = '\n' + @inline private[this] final val carriageReturn: Byte = '\r' + + @inline private[this] final val defaultBufferSize = 1024 * 1024 + + private[this] val chars = ArrayBuffer.empty[Byte] + + override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = { + val split = inputSplit.asInstanceOf[FileSplit] + val file = split.getPath + val conf: Configuration = context.getConfiguration + delimiter = RedshiftInputFormat.getDelimiterOrDefault(conf).asInstanceOf[Byte] + require(delimiter != escapeChar, + s"The delimiter and the escape char cannot be the same but found $delimiter.") + require(delimiter != lineFeed, "The delimiter cannot be the lineFeed character.") + require(delimiter != carriageReturn, "The delimiter cannot be the carriage return.") + val compressionCodecs = new CompressionCodecFactory(conf) + val codec = compressionCodecs.getCodec(file) + if (codec != null) { + throw new IOException(s"Do not support compressed files but found $file.") + } + val fs = file.getFileSystem(conf) + val size = fs.getFileStatus(file).getLen + start = findNext(fs, file, size, split.getStart) + end = findNext(fs, file, size, split.getStart + split.getLength) + cur = start + val in = fs.open(file) + if (cur > 0L) { + in.seek(cur - 1L) + in.read() + } + reader = new BufferedInputStream(in, defaultBufferSize) + } + + override def getProgress: Float = { + if (start >= end) { + 1.0f + } else { + math.min((cur - start).toFloat / (end - start), 1.0f) + } + } + + override def nextKeyValue(): Boolean = { + if (cur < end && !eof) { + key = cur + value = nextValue() + true + } else { + key = null + value = null + false + } + } + + override def getCurrentValue: Array[String] = value + + override def getCurrentKey: JavaLong = key + + override def close(): Unit = { + if (reader != null) { + reader.close() + } + } + + /** + * Finds the start of the next record. + * Because we don't know whether the first char is escaped or not, we need to first find a + * position that is not escaped. + * + * @param fs file system + * @param file file path + * @param size file size + * @param offset start offset + * @return the start position of the next record + */ + private def findNext(fs: FileSystem, file: Path, size: Long, offset: Long): Long = { + if (offset == 0L) { + return 0L + } else if (offset >= size) { + return size + } + val in = fs.open(file) + var pos = offset + in.seek(pos) + val bis = new BufferedInputStream(in, defaultBufferSize) + // Find the first unescaped char. + var escaped = true + var thisEof = false + while (escaped && !thisEof) { + val v = bis.read() + if (v < 0) { + thisEof = true + } else { + pos += 1 + if (v != escapeChar) { + escaped = false + } + } + } + // Find the next unescaped line feed. + var endOfRecord = false + while ((escaped || !endOfRecord) && !thisEof) { + val v = bis.read() + if (v < 0) { + thisEof = true + } else { + pos += 1 + if (v == escapeChar) { + escaped = true + } else { + if (!escaped) { + endOfRecord = v == lineFeed + } else { + escaped = false + } + } + } + } + in.close() + pos + } + + private def nextValue(): Array[String] = { + val fields = ArrayBuffer.empty[String] + var escaped = false + var endOfRecord = false + while (!endOfRecord && !eof) { + var endOfField = false + chars.clear() + while (!endOfField && !endOfRecord && !eof) { + val v = reader.read() + if (v < 0) { + eof = true + } else { + cur += 1L + val c = v.asInstanceOf[Byte] + if (escaped) { + if (c != escapeChar && c != delimiter && c != lineFeed && c != carriageReturn) { + throw new IllegalStateException( + s"Found `$c` (ASCII $v) after $escapeChar.") + } + chars.append(c) + escaped = false + } else { + if (c == escapeChar) { + escaped = true + } else if (c == delimiter) { + endOfField = true + } else if (c == lineFeed) { + endOfRecord = true + } else { + // also copy carriage return + chars.append(c) + } + } + } + } + // TODO: charset? + fields.append(new String(chars.toArray, Charset.forName("UTF-8"))) + } + if (escaped) { + throw new IllegalStateException(s"Found hanging escape char.") + } + fields.toArray + } +} + diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala new file mode 100755 index 0000000000000..c110176859d9e --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala @@ -0,0 +1,347 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.util.Properties +import java.util.concurrent.{Executors, ThreadFactory} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConverters._ +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.util.Try +import scala.util.control.NonFatal + +import org.slf4j.LoggerFactory + +import org.apache.spark.SPARK_VERSION +import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry +import org.apache.spark.sql.types._ + +/** + * Shim which exposes some JDBC helper functions. Most of this code is copied from Spark SQL, with + * minor modifications for Redshift-specific features and limitations. + */ +private[redshift] class JDBCWrapper { + + private val log = LoggerFactory.getLogger(getClass) + + private val ec: ExecutionContext = { + val threadFactory = new ThreadFactory { + private[this] val count = new AtomicInteger() + override def newThread(r: Runnable) = { + val thread = new Thread(r) + thread.setName(s"spark-redshift-JDBCWrapper-${count.incrementAndGet}") + thread.setDaemon(true) + thread + } + } + ExecutionContext.fromExecutorService(Executors.newCachedThreadPool(threadFactory)) + } + + /** + * Given a JDBC subprotocol, returns the name of the appropriate driver class to use. + * + * If the user has explicitly specified a driver class in their configuration then that class will + * be used. Otherwise, we will attempt to load the correct driver class based on + * the JDBC subprotocol. + * + * @param jdbcSubprotocol 'redshift' or 'postgresql' + * @param userProvidedDriverClass an optional user-provided explicit driver class name + * @return the driver class + */ + private def getDriverClass( + jdbcSubprotocol: String, + userProvidedDriverClass: Option[String]): String = { + userProvidedDriverClass.getOrElse { + jdbcSubprotocol match { + case "redshift" => + try { + Utils.classForName("com.amazon.redshift.jdbc42.Driver").getName + } catch { + case _: ClassNotFoundException => + try { + Utils.classForName("com.amazon.redshift.jdbc41.Driver").getName + } catch { + case _: ClassNotFoundException => + try { + Utils.classForName("com.amazon.redshift.jdbc4.Driver").getName + } catch { + case e: ClassNotFoundException => + throw new ClassNotFoundException( + "Could not load an Amazon Redshift JDBC driver; see the README for " + + "instructions on downloading and configuring the official Amazon driver.", + e + ) + } + } + } + case "postgresql" => "org.postgresql.Driver" + case other => throw new IllegalArgumentException(s"Unsupported JDBC protocol: '$other'") + } + } + } + + /** + * Execute the given SQL statement while supporting interruption. + * If InterruptedException is caught, then the statement will be cancelled if it is running. + * + * @return true if the first result is a ResultSet + * object; false if the first result is an update + * count or there is no result + */ + def executeInterruptibly(statement: PreparedStatement): Boolean = { + executeInterruptibly(statement, _.execute()) + } + + /** + * Execute the given SQL statement while supporting interruption. + * If InterruptedException is caught, then the statement will be cancelled if it is running. + * + * @return a ResultSet object that contains the data produced by the + * query; never null + */ + def executeQueryInterruptibly(statement: PreparedStatement): ResultSet = { + executeInterruptibly(statement, _.executeQuery()) + } + + private def executeInterruptibly[T]( + statement: PreparedStatement, + op: PreparedStatement => T): T = { + try { + val future = Future[T](op(statement))(ec) + try { + // scalastyle:off awaitresult + Await.result(future, Duration.Inf) + // scalastyle:on awaitresult + } catch { + case e: SQLException => + // Wrap and re-throw so that this thread's stacktrace appears to the user. + throw new SQLException("Exception thrown in awaitResult: ", e) + case NonFatal(t) => + // Wrap and re-throw so that this thread's stacktrace appears to the user. + throw new Exception("Exception thrown in awaitResult: ", t) + } + } catch { + case e: InterruptedException => + try { + statement.cancel() + throw e + } catch { + case s: SQLException => + log.error("Exception occurred while cancelling query", s) + throw e + } + } + } + + /** + * Takes a (schema, table) specification and returns the table's Catalyst + * schema. + * + * @param conn A JDBC connection to the database. + * @param table The table name of the desired table. This may also be a + * SQL query wrapped in parentheses. + * + * @return A StructType giving the table's Catalyst schema. + * @throws SQLException if the table specification is garbage. + * @throws SQLException if the table contains an unsupported type. + */ + def resolveTable(conn: Connection, table: String): StructType = { + // It's important to leave the `LIMIT 1` clause in order to limit the work of the query in case + // the underlying JDBC driver implementation implements PreparedStatement.getMetaData() by + // executing the query. It looks like the standard Redshift and Postgres JDBC drivers don't do + // this but we leave the LIMIT condition here as a safety-net to guard against perf regressions. + val ps = conn.prepareStatement(s"SELECT * FROM $table LIMIT 1") + try { + val rsmd = executeInterruptibly(ps, _.getMetaData) + val ncols = rsmd.getColumnCount + val fields = new Array[StructField](ncols) + var i = 0 + while (i < ncols) { + val columnName = rsmd.getColumnLabel(i + 1) + val dataType = rsmd.getColumnType(i + 1) + val fieldSize = rsmd.getPrecision(i + 1) + val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) + val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + val columnType = getCatalystType(dataType, fieldSize, fieldScale, isSigned) + fields(i) = StructField(columnName, columnType, nullable) + i = i + 1 + } + new StructType(fields) + } finally { + ps.close() + } + } + + /** + * Given a driver string and a JDBC url, load the specified driver and return a DB connection. + * + * @param userProvidedDriverClass the class name of the JDBC driver for the given url. If this + * is None then `spark-redshift` will attempt to automatically + * discover the appropriate driver class. + * @param url the JDBC url to connect to. + */ + def getConnector( + userProvidedDriverClass: Option[String], + url: String, + credentials: Option[(String, String)]) : Connection = { + val subprotocol = url.stripPrefix("jdbc:").split(":")(0) + val driverClass: String = getDriverClass(subprotocol, userProvidedDriverClass) + DriverRegistry.register(driverClass) + val driverWrapperClass: Class[_] = if (SPARK_VERSION.startsWith("1.4")) { + Utils.classForName("org.apache.spark.sql.jdbc.package$DriverWrapper") + } else { // Spark 1.5.0+ + Utils.classForName("org.apache.spark.sql.execution.datasources.jdbc.DriverWrapper") + } + def getWrapped(d: Driver): Driver = { + require(driverWrapperClass.isAssignableFrom(d.getClass)) + driverWrapperClass.getDeclaredMethod("wrapped").invoke(d).asInstanceOf[Driver] + } + // Note that we purposely don't call DriverManager.getConnection() here: we want to ensure + // that an explicitly-specified user-provided driver class can take precedence over the default + // class, but DriverManager.getConnection() might return a according to a different precedence. + // At the same time, we don't want to create a driver-per-connection, so we use the + // DriverManager's driver instances to handle that singleton logic for us. + val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { + case d if driverWrapperClass.isAssignableFrom(d.getClass) + && getWrapped(d).getClass.getCanonicalName == driverClass => d + case d if d.getClass.getCanonicalName == driverClass => d + }.getOrElse { + throw new IllegalArgumentException(s"Did not find registered driver with class $driverClass") + } + val properties = new Properties() + credentials.foreach { case(user, password) => + properties.setProperty("user", user) + properties.setProperty("password", password) + } + driver.connect(url, properties) + } + + /** + * Compute the SQL schema string for the given Spark SQL Schema. + */ + def schemaString(schema: StructType): String = { + val sb = new StringBuilder() + schema.fields.foreach { field => { + val name = field.name + val typ: String = if (field.metadata.contains("redshift_type")) { + field.metadata.getString("redshift_type") + } else { + field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "SMALLINT" // Redshift does not support the BYTE type. + case BooleanType => "BOOLEAN" + case StringType => + if (field.metadata.contains("maxlength")) { + s"VARCHAR(${field.metadata.getLong("maxlength")})" + } else { + "TEXT" + } + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + } + } + + val nullable = if (field.nullable) "" else "NOT NULL" + val encoding = if (field.metadata.contains("encoding")) { + s"ENCODE ${field.metadata.getString("encoding")}" + } else { + "" + } + sb.append(s""", "${name.replace("\"", "\\\"")}" $typ $nullable $encoding""".trim) + }} + if (sb.length < 2) "" else sb.substring(2) + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. + Try { + val stmt = conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1") + executeInterruptibly(stmt, _.getMetaData).getColumnCount + }.isSuccess + } + + /** + * Maps a JDBC type to a Catalyst type. + * + * @param sqlType - A field of java.sql.Types + * @return The Catalyst type corresponding to sqlType. + */ + private def getCatalystType( + sqlType: Int, + precision: Int, + scale: Int, + signed: Boolean): DataType = { + // TODO: cleanup types which are irrelevant for Redshift. + val answer = sqlType match { + // scalastyle:off + case java.sql.Types.ARRAY => null + case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) } + case java.sql.Types.BINARY => BinaryType + case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks + case java.sql.Types.BLOB => BinaryType + case java.sql.Types.BOOLEAN => BooleanType + case java.sql.Types.CHAR => StringType + case java.sql.Types.CLOB => StringType + case java.sql.Types.DATALINK => null + case java.sql.Types.DATE => DateType + case java.sql.Types.DECIMAL + if precision != 0 || scale != 0 => DecimalType(precision, scale) + case java.sql.Types.DECIMAL => DecimalType(38, 18) // Spark 1.5.0 default + case java.sql.Types.DISTINCT => null + case java.sql.Types.DOUBLE => DoubleType + case java.sql.Types.FLOAT => FloatType + case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType } + case java.sql.Types.JAVA_OBJECT => null + case java.sql.Types.LONGNVARCHAR => StringType + case java.sql.Types.LONGVARBINARY => BinaryType + case java.sql.Types.LONGVARCHAR => StringType + case java.sql.Types.NCHAR => StringType + case java.sql.Types.NCLOB => StringType + case java.sql.Types.NULL => null + case java.sql.Types.NUMERIC + if precision != 0 || scale != 0 => DecimalType(precision, scale) + case java.sql.Types.NUMERIC => DecimalType(38, 18) // Spark 1.5.0 default + case java.sql.Types.NVARCHAR => StringType + case java.sql.Types.OTHER => null + case java.sql.Types.REAL => DoubleType + case java.sql.Types.REF => StringType + case java.sql.Types.ROWID => LongType + case java.sql.Types.SMALLINT => IntegerType + case java.sql.Types.SQLXML => StringType + case java.sql.Types.STRUCT => StringType + case java.sql.Types.TIME => TimestampType + case java.sql.Types.TIMESTAMP => TimestampType + case java.sql.Types.TINYINT => IntegerType + case java.sql.Types.VARBINARY => BinaryType + case java.sql.Types.VARCHAR => StringType + case _ => null + // scalastyle:on + } + + if (answer == null) throw new SQLException("Unsupported type " + sqlType) + answer + } +} + +private[redshift] object DefaultJDBCWrapper extends JDBCWrapper diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala new file mode 100755 index 0000000000000..fd24546d452b5 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala @@ -0,0 +1,195 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.io.InputStreamReader +import java.net.URI + +import scala.collection.JavaConverters._ + +import com.amazonaws.auth.AWSCredentialsProvider +import com.amazonaws.services.s3.AmazonS3Client +import com.databricks.spark.redshift.Parameters.MergedParameters +import com.eclipsesource.json.Json +import org.slf4j.LoggerFactory + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +/** + * Data Source API implementation for Amazon Redshift database tables + */ +private[redshift] case class RedshiftRelation( + jdbcWrapper: JDBCWrapper, + s3ClientFactory: AWSCredentialsProvider => AmazonS3Client, + params: MergedParameters, + userSchema: Option[StructType]) + (@transient val sqlContext: SQLContext) + extends BaseRelation + with PrunedFilteredScan + with InsertableRelation { + + private val log = LoggerFactory.getLogger(getClass) + + if (sqlContext != null) { + Utils.assertThatFileSystemIsNotS3BlockFileSystem( + new URI(params.rootTempDir), sqlContext.sparkContext.hadoopConfiguration) + } + + private val tableNameOrSubquery = + params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get + + override lazy val schema: StructType = { + userSchema.getOrElse { + val tableNameOrSubquery = + params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get + val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) + try { + jdbcWrapper.resolveTable(conn, tableNameOrSubquery) + } finally { + conn.close() + } + } + } + + override def toString: String = s"RedshiftRelation($tableNameOrSubquery)" + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + val saveMode = if (overwrite) { + SaveMode.Overwrite + } else { + SaveMode.Append + } + val writer = new RedshiftWriter(jdbcWrapper, s3ClientFactory) + writer.saveToRedshift(sqlContext, data, saveMode, params) + } + + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { + filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined) + } + + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { + val creds = AWSCredentialsUtils.load(params, sqlContext.sparkContext.hadoopConfiguration) + for ( + redshiftRegion <- Utils.getRegionForRedshiftCluster(params.jdbcUrl); + s3Region <- Utils.getRegionForS3Bucket(params.rootTempDir, s3ClientFactory(creds)) + ) { + if (redshiftRegion != s3Region) { + // We don't currently support `extraunloadoptions`, so even if Amazon _did_ add a `region` + // option for this we wouldn't be able to pass in the new option. However, we choose to + // err on the side of caution and don't throw an exception because we don't want to break + // existing workloads in case the region detection logic is wrong. + log.error("The Redshift cluster and S3 bucket are in different regions " + + s"($redshiftRegion and $s3Region, respectively). Redshift's UNLOAD command requires " + + s"that the Redshift cluster and Amazon S3 bucket be located in the same region, so " + + s"this read will fail.") + } + } + Utils.checkThatBucketHasObjectLifecycleConfiguration(params.rootTempDir, s3ClientFactory(creds)) + if (requiredColumns.isEmpty) { + // In the special case where no columns were requested, issue a `count(*)` against Redshift + // rather than unloading data. + val whereClause = FilterPushdown.buildWhereClause(schema, filters) + val countQuery = s"SELECT count(*) FROM $tableNameOrSubquery $whereClause" + log.info(countQuery) + val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) + try { + val results = jdbcWrapper.executeQueryInterruptibly(conn.prepareStatement(countQuery)) + if (results.next()) { + val numRows = results.getLong(1) + val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt + val emptyRow = RowEncoder(StructType(Seq.empty)).toRow(Row(Seq.empty)) + sqlContext.sparkContext + .parallelize(1L to numRows, parallelism) + .map(_ => emptyRow) + .asInstanceOf[RDD[Row]] + } else { + throw new IllegalStateException("Could not read count from Redshift") + } + } finally { + conn.close() + } + } else { + // Unload data from Redshift into a temporary directory in S3: + val tempDir = params.createPerQueryTempDir() + val unloadSql = buildUnloadStmt(requiredColumns, filters, tempDir, creds) + log.info(unloadSql) + val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) + try { + jdbcWrapper.executeInterruptibly(conn.prepareStatement(unloadSql)) + } finally { + conn.close() + } + // Read the MANIFEST file to get the list of S3 part files that were written by Redshift. + // We need to use a manifest in order to guard against S3's eventually-consistent listings. + val filesToRead: Seq[String] = { + val cleanedTempDirUri = + Utils.fixS3Url(Utils.removeCredentialsFromURI(URI.create(tempDir)).toString) + val s3URI = Utils.createS3URI(cleanedTempDirUri) + val s3Client = s3ClientFactory(creds) + val is = s3Client.getObject(s3URI.getBucket, s3URI.getKey + "manifest").getObjectContent + val s3Files = try { + val entries = Json.parse(new InputStreamReader(is)).asObject().get("entries").asArray() + entries.iterator().asScala.map(_.asObject().get("url").asString()).toSeq + } finally { + is.close() + } + // The filenames in the manifest are of the form s3://bucket/key, without credentials. + // If the S3 credentials were originally specified in the tempdir's URI, then we need to + // reintroduce them here + s3Files.map { file => + tempDir.stripSuffix("/") + '/' + file.stripPrefix(cleanedTempDirUri).stripPrefix("/") + } + } + + val prunedSchema = pruneSchema(schema, requiredColumns) + + sqlContext.read + .format(classOf[RedshiftFileFormat].getName) + .schema(prunedSchema) + .load(filesToRead: _*) + .queryExecution.executedPlan.execute().asInstanceOf[RDD[Row]] + } + } + + override def needConversion: Boolean = false + + private def buildUnloadStmt( + requiredColumns: Array[String], + filters: Array[Filter], + tempDir: String, + creds: AWSCredentialsProvider): String = { + assert(!requiredColumns.isEmpty) + // Always quote column names: + val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ") + val whereClause = FilterPushdown.buildWhereClause(schema, filters) + val credsString: String = + AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials) + val query = { + // Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape + // any backslashes and single quotes that appear in the query itself + val escapedTableNameOrSubqury = tableNameOrSubquery.replace("\\", "\\\\").replace("'", "\\'") + s"SELECT $columnList FROM $escapedTableNameOrSubqury $whereClause" + } + // We need to remove S3 credentials from the unload path URI because they will conflict with + // the credentials passed via `credsString`. + val fixedUrl = Utils.fixS3Url(Utils.removeCredentialsFromURI(new URI(tempDir)).toString) + + s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE MANIFEST" + } + + private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { + val fieldMap = Map(schema.fields.map(x => x.name -> x): _*) + new StructType(columns.map(name => fieldMap(name))) + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala new file mode 100755 index 0000000000000..2b1d691782bb2 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala @@ -0,0 +1,425 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.net.URI +import java.sql.{Connection, Date, SQLException, Timestamp} + +import scala.collection.mutable +import scala.util.control.NonFatal + +import com.amazonaws.auth.AWSCredentialsProvider +import com.amazonaws.services.s3.AmazonS3Client +import com.databricks.spark.redshift.Parameters.MergedParameters +import org.apache.hadoop.fs.{FileSystem, Path} +import org.slf4j.LoggerFactory + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} +import org.apache.spark.sql.types._ + +/** + * Functions to write data to Redshift. + * + * At a high level, writing data back to Redshift involves the following steps: + * + * - Use the spark-avro library to save the DataFrame to S3 using Avro serialization. Prior to + * saving the data, certain data type conversions are applied in order to work around + * limitations in Avro's data type support and Redshift's case-insensitive identifier handling. + * + * While writing the Avro files, we use accumulators to keep track of which partitions were + * non-empty. After the write operation completes, we use this to construct a list of non-empty + * Avro partition files. + * + * - If there is data to be written (i.e. not all partitions were empty), then use the list of + * non-empty Avro files to construct a JSON manifest file to tell Redshift to load those files. + * This manifest is written to S3 alongside the Avro files themselves. We need to use an + * explicit manifest, as opposed to simply passing the name of the directory containing the + * Avro files, in order to work around a bug related to parsing of empty Avro files (see #96). + * + * - Start a new JDBC transaction and disable auto-commit. Depending on the SaveMode, issue + * DELETE TABLE or CREATE TABLE commands, then use the COPY command to instruct Redshift to load + * the Avro data into the appropriate table. + */ +private[redshift] class RedshiftWriter( + jdbcWrapper: JDBCWrapper, + s3ClientFactory: AWSCredentialsProvider => AmazonS3Client) { + + private val log = LoggerFactory.getLogger(getClass) + + /** + * Generate CREATE TABLE statement for Redshift + */ + // Visible for testing. + private[redshift] def createTableSql(data: DataFrame, params: MergedParameters): String = { + val schemaSql = jdbcWrapper.schemaString(data.schema) + val distStyleDef = params.distStyle match { + case Some(style) => s"DISTSTYLE $style" + case None => "" + } + val distKeyDef = params.distKey match { + case Some(key) => s"DISTKEY ($key)" + case None => "" + } + val sortKeyDef = params.sortKeySpec.getOrElse("") + val table = params.table.get + + s"CREATE TABLE IF NOT EXISTS $table ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef" + } + + /** + * Generate the COPY SQL command + */ + private def copySql( + sqlContext: SQLContext, + params: MergedParameters, + creds: AWSCredentialsProvider, + manifestUrl: String): String = { + val credsString: String = + AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials) + val fixedUrl = Utils.fixS3Url(manifestUrl) + val format = params.tempFormat match { + case "AVRO" => "AVRO 'auto'" + case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'" + } + s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " + + s"${format} manifest ${params.extraCopyOptions}" + } + + /** + * Generate COMMENT SQL statements for the table and columns. + */ + private[redshift] def commentActions(tableComment: Option[String], schema: StructType): + List[String] = { + tableComment.toList.map(desc => s"COMMENT ON TABLE %s IS '${desc.replace("'", "''")}'") ++ + schema.fields + .withFilter(f => f.metadata.contains("description")) + .map(f => s"""COMMENT ON COLUMN %s."${f.name.replace("\"", "\\\"")}"""" + + s" IS '${f.metadata.getString("description").replace("'", "''")}'") + } + + /** + * Perform the Redshift load by issuing a COPY statement. + */ + private def doRedshiftLoad( + conn: Connection, + data: DataFrame, + params: MergedParameters, + creds: AWSCredentialsProvider, + manifestUrl: Option[String]): Unit = { + + // If the table doesn't exist, we need to create it first, using JDBC to infer column types + val createStatement = createTableSql(data, params) + log.info(createStatement) + jdbcWrapper.executeInterruptibly(conn.prepareStatement(createStatement)) + + val preActions = commentActions(params.description, data.schema) ++ params.preActions + + // Execute preActions + preActions.foreach { action => + val actionSql = if (action.contains("%s")) action.format(params.table.get) else action + log.info("Executing preAction: " + actionSql) + jdbcWrapper.executeInterruptibly(conn.prepareStatement(actionSql)) + } + + manifestUrl.foreach { manifestUrl => + // Load the temporary data into the new file + val copyStatement = copySql(data.sqlContext, params, creds, manifestUrl) + log.info(copyStatement) + try { + jdbcWrapper.executeInterruptibly(conn.prepareStatement(copyStatement)) + } catch { + case e: SQLException => + log.error("SQLException thrown while running COPY query; will attempt to retrieve " + + "more information by querying the STL_LOAD_ERRORS table", e) + // Try to query Redshift's STL_LOAD_ERRORS table to figure out why the load failed. + // See http://docs.aws.amazon.com/redshift/latest/dg/r_STL_LOAD_ERRORS.html for details. + conn.rollback() + val errorLookupQuery = + """ + | SELECT * + | FROM stl_load_errors + | WHERE query = pg_last_query_id() + """.stripMargin + val detailedException: Option[SQLException] = try { + val results = + jdbcWrapper.executeQueryInterruptibly(conn.prepareStatement(errorLookupQuery)) + if (results.next()) { + val errCode = results.getInt("err_code") + val errReason = results.getString("err_reason").trim + val columnLength: String = + Option(results.getString("col_length")) + .map(_.trim) + .filter(_.nonEmpty) + .map(n => s"($n)") + .getOrElse("") + val exceptionMessage = + s""" + |Error (code $errCode) while loading data into Redshift: "$errReason" + |Table name: ${params.table.get} + |Column name: ${results.getString("colname").trim} + |Column type: ${results.getString("type").trim}$columnLength + |Raw line: ${results.getString("raw_line")} + |Raw field value: ${results.getString("raw_field_value")} + """.stripMargin + Some(new SQLException(exceptionMessage, e)) + } else { + None + } + } catch { + case NonFatal(e2) => + log.error("Error occurred while querying STL_LOAD_ERRORS", e2) + None + } + throw detailedException.getOrElse(e) + } + } + + // Execute postActions + params.postActions.foreach { action => + val actionSql = if (action.contains("%s")) action.format(params.table.get) else action + log.info("Executing postAction: " + actionSql) + jdbcWrapper.executeInterruptibly(conn.prepareStatement(actionSql)) + } + } + + /** + * Serialize temporary data to S3, ready for Redshift COPY, and create a manifest file which can + * be used to instruct Redshift to load the non-empty temporary data partitions. + * + * @return the URL of the manifest file in S3, in `s3://path/to/file/manifest.json` format, if + * at least one record was written, and None otherwise. + */ + private def unloadData( + sqlContext: SQLContext, + data: DataFrame, + tempDir: String, + tempFormat: String, + nullString: String): Option[String] = { + // spark-avro does not support Date types. In addition, it converts Timestamps into longs + // (milliseconds since the Unix epoch). Redshift is capable of loading timestamps in + // 'epochmillisecs' format but there's no equivalent format for dates. To work around this, we + // choose to write out both dates and timestamps as strings. + // For additional background and discussion, see #39. + + // Convert the rows so that timestamps and dates become formatted strings. + // Formatters are not thread-safe, and thus these functions are not thread-safe. + // However, each task gets its own deserialized copy, making this safe. + val conversionFunctions: Array[Any => Any] = data.schema.fields.map { field => + field.dataType match { + case DateType => + val dateFormat = Conversions.createRedshiftDateFormat() + (v: Any) => { + if (v == null) null else dateFormat.format(v.asInstanceOf[Date]) + } + case TimestampType => + val timestampFormat = Conversions.createRedshiftTimestampFormat() + (v: Any) => { + if (v == null) null else timestampFormat.format(v.asInstanceOf[Timestamp]) + } + case _ => (v: Any) => v + } + } + + // Use Spark accumulators to determine which partitions were non-empty. + val nonEmptyPartitions = + sqlContext.sparkContext.accumulableCollection(mutable.HashSet.empty[Int]) + + val convertedRows: RDD[Row] = data.rdd.mapPartitions { iter: Iterator[Row] => + if (iter.hasNext) { + nonEmptyPartitions += TaskContext.get.partitionId() + } + iter.map { row => + val convertedValues: Array[Any] = new Array(conversionFunctions.length) + var i = 0 + while (i < conversionFunctions.length) { + convertedValues(i) = conversionFunctions(i)(row(i)) + i += 1 + } + Row.fromSeq(convertedValues) + } + } + + // Convert all column names to lowercase, which is necessary for Redshift to be able to load + // those columns (see #51). + val schemaWithLowercaseColumnNames: StructType = + StructType(data.schema.map(f => f.copy(name = f.name.toLowerCase))) + + if (schemaWithLowercaseColumnNames.map(_.name).toSet.size != data.schema.size) { + throw new IllegalArgumentException( + "Cannot save table to Redshift because two or more column names would be identical" + + " after conversion to lowercase: " + data.schema.map(_.name).mkString(", ")) + } + + // Update the schema so that Avro writes date and timestamp columns as formatted timestamp + // strings. This is necessary for Redshift to be able to load these columns (see #39). + val convertedSchema: StructType = StructType( + schemaWithLowercaseColumnNames.map { + case StructField(name, DateType, nullable, meta) => + StructField(name, StringType, nullable, meta) + case StructField(name, TimestampType, nullable, meta) => + StructField(name, StringType, nullable, meta) + case other => other + } + ) + + val writer = sqlContext.createDataFrame(convertedRows, convertedSchema).write + (tempFormat match { + case "AVRO" => + writer.format("com.databricks.spark.avro") + case "CSV" => + writer.format("csv") + .option("escape", "\"") + .option("nullValue", nullString) + case "CSV GZIP" => + writer.format("csv") + .option("escape", "\"") + .option("nullValue", nullString) + .option("compression", "gzip") + }).save(tempDir) + + if (nonEmptyPartitions.value.isEmpty) { + None + } else { + // See https://docs.aws.amazon.com/redshift/latest/dg/loading-data-files-using-manifest.html + // for a description of the manifest file format. The URLs in this manifest must be absolute + // and complete. + + // The partition filenames are of the form part-r-XXXXX-UUID.fileExtension. + val fs = FileSystem.get(URI.create(tempDir), sqlContext.sparkContext.hadoopConfiguration) + val partitionIdRegex = "^part-(?:r-)?(\\d+)[^\\d+].*$".r + val filesToLoad: Seq[String] = { + val nonEmptyPartitionIds = nonEmptyPartitions.value.toSet + fs.listStatus(new Path(tempDir)).map(_.getPath.getName).collect { + case file @ partitionIdRegex(id) if nonEmptyPartitionIds.contains(id.toInt) => file + } + } + // It's possible that tempDir contains AWS access keys. We shouldn't save those credentials to + // S3, so let's first sanitize `tempdir` and make sure that it uses the s3:// scheme: + val sanitizedTempDir = Utils.fixS3Url( + Utils.removeCredentialsFromURI(URI.create(tempDir)).toString).stripSuffix("/") + val manifestEntries = filesToLoad.map { file => + s"""{"url":"$sanitizedTempDir/$file", "mandatory":true}""" + } + val manifest = s"""{"entries": [${manifestEntries.mkString(",\n")}]}""" + val manifestPath = sanitizedTempDir + "/manifest.json" + val fsDataOut = fs.create(new Path(manifestPath)) + try { + fsDataOut.write(manifest.getBytes("utf-8")) + } finally { + fsDataOut.close() + } + Some(manifestPath) + } + } + + /** + * Write a DataFrame to a Redshift table, using S3 and Avro or CSV serialization + */ + def saveToRedshift( + sqlContext: SQLContext, + data: DataFrame, + saveMode: SaveMode, + params: MergedParameters) : Unit = { + if (params.table.isEmpty) { + throw new IllegalArgumentException( + "For save operations you must specify a Redshift table name with the 'dbtable' parameter") + } + + if (!params.useStagingTable) { + log.warn("Setting useStagingTable=false is deprecated; instead, we recommend that you " + + "drop the target table yourself. For more details on this deprecation, see" + + "https://github.com/databricks/spark-redshift/pull/157") + } + + val creds: AWSCredentialsProvider = + AWSCredentialsUtils.load(params, sqlContext.sparkContext.hadoopConfiguration) + + for ( + redshiftRegion <- Utils.getRegionForRedshiftCluster(params.jdbcUrl); + s3Region <- Utils.getRegionForS3Bucket(params.rootTempDir, s3ClientFactory(creds)) + ) { + val regionIsSetInExtraCopyOptions = + params.extraCopyOptions.contains(s3Region) && params.extraCopyOptions.contains("region") + if (redshiftRegion != s3Region && !regionIsSetInExtraCopyOptions) { + log.error("The Redshift cluster and S3 bucket are in different regions " + + s"($redshiftRegion and $s3Region, respectively). In order to perform this cross-region " + + s"""write, you must add "region '$s3Region'" to the extracopyoptions parameter. """ + + "For more details on cross-region usage, see the README.") + } + } + + // When using the Avro tempformat, log an informative error message in case any column names + // are unsupported by Avro's schema validation: + if (params.tempFormat == "AVRO") { + for (fieldName <- data.schema.fieldNames) { + // The following logic is based on Avro's Schema.validateName() method: + val firstChar = fieldName.charAt(0) + val isValid = (firstChar.isLetter || firstChar == '_') && fieldName.tail.forall { c => + c.isLetterOrDigit || c == '_' + } + if (!isValid) { + throw new IllegalArgumentException( + s"The field name '$fieldName' is not supported when using the Avro tempformat. " + + "Try using the CSV tempformat instead. For more details, see " + + "https://github.com/databricks/spark-redshift/issues/84") + } + } + } + + Utils.assertThatFileSystemIsNotS3BlockFileSystem( + new URI(params.rootTempDir), sqlContext.sparkContext.hadoopConfiguration) + + Utils.checkThatBucketHasObjectLifecycleConfiguration(params.rootTempDir, s3ClientFactory(creds)) + + // Save the table's rows to S3: + val manifestUrl = unloadData( + sqlContext, + data, + tempDir = params.createPerQueryTempDir(), + tempFormat = params.tempFormat, + nullString = params.nullString) + val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) + conn.setAutoCommit(false) + try { + val table: TableName = params.table.get + if (saveMode == SaveMode.Overwrite) { + // Overwrites must drop the table in case there has been a schema update + jdbcWrapper.executeInterruptibly(conn.prepareStatement(s"DROP TABLE IF EXISTS $table;")) + if (!params.useStagingTable) { + // If we're not using a staging table, commit now so that Redshift doesn't have to + // maintain a snapshot of the old table during the COPY; this sacrifices atomicity for + // performance. + conn.commit() + } + } + log.info(s"Loading new Redshift data to: $table") + doRedshiftLoad(conn, data, params, creds, manifestUrl) + conn.commit() + } catch { + case NonFatal(e) => + try { + log.error("Exception thrown during Redshift load; will roll back transaction", e) + conn.rollback() + } catch { + case NonFatal(e2) => + log.error("Exception while rolling back transaction", e2) + } + throw e + } finally { + conn.close() + } + } +} + +object DefaultRedshiftWriter extends RedshiftWriter( + DefaultJDBCWrapper, + awsCredentials => new AmazonS3Client(awsCredentials)) diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/SerializableConfiguration.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/SerializableConfiguration.scala new file mode 100755 index 0000000000000..74e3d351e4e31 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/SerializableConfiguration.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.io._ + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.hadoop.conf.Configuration +import org.slf4j.LoggerFactory + +class SerializableConfiguration(@transient var value: Configuration) + extends Serializable with KryoSerializable { + @transient private[redshift] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + value.write(dos) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + value = new Configuration(false) + value.readFields(new DataInputStream(in)) + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/TableName.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/TableName.scala new file mode 100755 index 0000000000000..77ef3b6602b22 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/TableName.scala @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import scala.collection.mutable.ArrayBuffer + +/** + * Wrapper class for representing the name of a Redshift table. + */ +private[redshift] case class TableName(unescapedSchemaName: String, unescapedTableName: String) { + private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"' + def escapedSchemaName: String = quote(unescapedSchemaName) + def escapedTableName: String = quote(unescapedTableName) + override def toString: String = s"$escapedSchemaName.$escapedTableName" +} + +private[redshift] object TableName { + /** + * Parses a table name which is assumed to have been escaped according to Redshift's rules for + * delimited identifiers. + */ + def parseFromEscaped(str: String): TableName = { + def dropOuterQuotes(s: String) = + if (s.startsWith("\"") && s.endsWith("\"")) s.drop(1).dropRight(1) else s + def unescapeQuotes(s: String) = s.replace("\"\"", "\"") + def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s)) + splitByDots(str) match { + case Seq(tableName) => TableName("PUBLIC", unescape(tableName)) + case Seq(schemaName, tableName) => TableName(unescape(schemaName), unescape(tableName)) + case other => throw new IllegalArgumentException(s"Could not parse table name from '$str'") + } + } + + /** + * Split by dots (.) while obeying our identifier quoting rules in order to allow dots to appear + * inside of quoted identifiers. + */ + private def splitByDots(str: String): Seq[String] = { + val parts: ArrayBuffer[String] = ArrayBuffer.empty + val sb = new StringBuilder + var inQuotes: Boolean = false + for (c <- str) c match { + case '"' => + // Note that double quotes are escaped by pairs of double quotes (""), so we don't need + // any extra code to handle them; we'll be back in inQuotes=true after seeing the pair. + sb.append('"') + inQuotes = !inQuotes + case '.' => + if (!inQuotes) { + parts.append(sb.toString()) + sb.clear() + } else { + sb.append('.') + } + case other => + sb.append(other) + } + if (sb.nonEmpty) { + parts.append(sb.toString()) + } + parts + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/Utils.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/Utils.scala new file mode 100755 index 0000000000000..ae5e4c9c9b78f --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/Utils.scala @@ -0,0 +1,200 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.net.URI +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.amazonaws.services.s3.{AmazonS3Client, AmazonS3URI} +import com.amazonaws.services.s3.model.BucketLifecycleConfiguration +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.slf4j.LoggerFactory + +/** + * Various arbitrary helper functions + */ +private[redshift] object Utils { + + private val log = LoggerFactory.getLogger(getClass) + + def classForName(className: String): Class[_] = { + val classLoader = + Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader) + // scalastyle:off + Class.forName(className, true, classLoader) + // scalastyle:on + } + + /** + * Joins prefix URL a to path suffix b, and appends a trailing /, in order to create + * a temp directory path for S3. + */ + def joinUrls(a: String, b: String): String = { + a.stripSuffix("/") + "/" + b.stripPrefix("/").stripSuffix("/") + "/" + } + + /** + * Redshift COPY and UNLOAD commands don't support s3n or s3a, but users may wish to use them + * for data loads. This function converts the URL back to the s3:// format. + */ + def fixS3Url(url: String): String = { + url.replaceAll("s3[an]://", "s3://") + } + + /** + * Factory method to create new S3URI in order to handle various library incompatibilities with + * older AWS Java Libraries + */ + def createS3URI(url: String): AmazonS3URI = { + try { + // try to instantiate AmazonS3URI with url + new AmazonS3URI(url) + } catch { + case e: IllegalArgumentException if e.getMessage. + startsWith("Invalid S3 URI: hostname does not appear to be a valid S3 endpoint") => { + new AmazonS3URI(addEndpointToUrl(url)) + } + } + } + + /** + * Since older AWS Java Libraries do not handle S3 urls that have just the bucket name + * as the host, add the endpoint to the host + */ + def addEndpointToUrl(url: String, domain: String = "s3.amazonaws.com"): String = { + val uri = new URI(url) + val hostWithEndpoint = uri.getHost + "." + domain + new URI(uri.getScheme, + uri.getUserInfo, + hostWithEndpoint, + uri.getPort, + uri.getPath, + uri.getQuery, + uri.getFragment).toString + } + + /** + * Returns a copy of the given URI with the user credentials removed. + */ + def removeCredentialsFromURI(uri: URI): URI = { + new URI( + uri.getScheme, + null, // no user info + uri.getHost, + uri.getPort, + uri.getPath, + uri.getQuery, + uri.getFragment) + } + + // Visible for testing + private[redshift] var lastTempPathGenerated: String = null + + /** + * Creates a randomly named temp directory path for intermediate data + */ + def makeTempPath(tempRoot: String): String = { + lastTempPathGenerated = Utils.joinUrls(tempRoot, UUID.randomUUID().toString) + lastTempPathGenerated + } + + /** + * Checks whether the S3 bucket for the given UI has an object lifecycle configuration to + * ensure cleanup of temporary files. If no applicable configuration is found, this method logs + * a helpful warning for the user. + */ + def checkThatBucketHasObjectLifecycleConfiguration( + tempDir: String, + s3Client: AmazonS3Client): Unit = { + try { + val s3URI = createS3URI(Utils.fixS3Url(tempDir)) + val bucket = s3URI.getBucket + assert(bucket != null, "Could not get bucket from S3 URI") + val key = Option(s3URI.getKey).getOrElse("") + val hasMatchingBucketLifecycleRule: Boolean = { + val rules = Option(s3Client.getBucketLifecycleConfiguration(bucket)) + .map(_.getRules.asScala) + .getOrElse(Seq.empty) + rules.exists { rule => + // Note: this only checks that there is an active rule which matches the temp directory; + // it does not actually check that the rule will delete the files. This check is still + // better than nothing, though, and we can always improve it later. + rule.getStatus == BucketLifecycleConfiguration.ENABLED && key.startsWith(rule.getPrefix) + } + } + if (!hasMatchingBucketLifecycleRule) { + log.warn(s"The S3 bucket $bucket does not have an object lifecycle configuration to " + + "ensure cleanup of temporary files. Consider configuring `tempdir` to point to a " + + "bucket with an object lifecycle policy that automatically deletes files after an " + + "expiration period. For more information, see " + + "https://docs.aws.amazon.com/AmazonS3/latest/dev/object-lifecycle-mgmt.html") + } + } catch { + case NonFatal(e) => + log.warn("An error occurred while trying to read the S3 bucket lifecycle configuration", e) + } + } + + /** + * Given a URI, verify that the Hadoop FileSystem for that URI is not the S3 block FileSystem. + * `spark-redshift` cannot use this FileSystem because the files written to it will not be + * readable by Redshift (and vice versa). + */ + def assertThatFileSystemIsNotS3BlockFileSystem(uri: URI, hadoopConfig: Configuration): Unit = { + val fs = FileSystem.get(uri, hadoopConfig) + // Note that we do not want to use isInstanceOf here, since we're only interested in detecting + // exact matches. We compare the class names as strings in order to avoid introducing a binary + // dependency on classes which belong to the `hadoop-aws` JAR, as that artifact is not present + // in some environments (such as EMR). See #92 for details. + if (fs.getClass.getCanonicalName == "org.apache.hadoop.fs.s3.S3FileSystem") { + throw new IllegalArgumentException( + "spark-redshift does not support the S3 Block FileSystem. Please reconfigure `tempdir` to" + + "use a s3n:// or s3a:// scheme.") + } + } + + /** + * Attempts to retrieve the region of the S3 bucket. + */ + def getRegionForS3Bucket(tempDir: String, s3Client: AmazonS3Client): Option[String] = { + try { + val s3URI = createS3URI(Utils.fixS3Url(tempDir)) + val bucket = s3URI.getBucket + assert(bucket != null, "Could not get bucket from S3 URI") + val region = s3Client.getBucketLocation(bucket) match { + // Map "US Standard" to us-east-1 + case null | "US" => "us-east-1" + case other => other + } + Some(region) + } catch { + case NonFatal(e) => + log.warn("An error occurred while trying to determine the S3 bucket's region", e) + None + } + } + + /** + * Attempts to determine the region of a Redshift cluster based on its URL. It may not be possible + * to determine the region in some cases, such as when the Redshift cluster is placed behind a + * proxy. + */ + def getRegionForRedshiftCluster(url: String): Option[String] = { + val regionRegex = """.*\.([^.]+)\.redshift\.amazonaws\.com.*""".r + url match { + case regionRegex(region) => Some(region) + case _ => None + } + } +} diff --git a/external/redshift/src/main/scala/com/databricks/spark/redshift/package.scala b/external/redshift/src/main/scala/com/databricks/spark/redshift/package.scala new file mode 100755 index 0000000000000..233141efa23a7 --- /dev/null +++ b/external/redshift/src/main/scala/com/databricks/spark/redshift/package.scala @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark + +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +package object redshift { + + /** + * Wrapper of SQLContext that provide `redshiftFile` method. + */ + implicit class RedshiftContext(sqlContext: SQLContext) { + + /** + * Read a file unloaded from Redshift into a DataFrame. + * @param path input path + * @return a DataFrame with all string columns + */ + def redshiftFile(path: String, columns: Seq[String]): DataFrame = { + val sc = sqlContext.sparkContext + val rdd = sc.newAPIHadoopFile(path, classOf[RedshiftInputFormat], + classOf[java.lang.Long], classOf[Array[String]], sc.hadoopConfiguration) + // TODO: allow setting NULL string. + val nullable = rdd.values.map(_.map(f => if (f.isEmpty) null else f)).map(x => Row(x: _*)) + val schema = StructType(columns.map(c => StructField(c, StringType, nullable = true))) + sqlContext.createDataFrame(nullable, schema) + } + + /** + * Reads a table unload from Redshift with its schema. + */ + def redshiftFile(path: String, schema: StructType): DataFrame = { + val casts = schema.fields.map { field => + col(field.name).cast(field.dataType).as(field.name) + } + redshiftFile(path, schema.fieldNames).select(casts: _*) + } + } +} diff --git a/external/redshift/src/test/java/com/databricks/spark/redshift/S3NInMemoryFileSystem.java b/external/redshift/src/test/java/com/databricks/spark/redshift/S3NInMemoryFileSystem.java new file mode 100755 index 0000000000000..b1a6a08a0decb --- /dev/null +++ b/external/redshift/src/test/java/com/databricks/spark/redshift/S3NInMemoryFileSystem.java @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package org.apache.hadoop.fs.s3native; + +import org.apache.hadoop.fs.s3native.NativeS3FileSystem; +import org.apache.hadoop.fs.s3native.InMemoryNativeFileSystemStore; + +/** + * A helper implementation of {@link NativeS3FileSystem} + * without actually connecting to S3 for unit testing. + */ +public class S3NInMemoryFileSystem extends NativeS3FileSystem { + public S3NInMemoryFileSystem() { + super(new InMemoryNativeFileSystemStore()); + } +} diff --git a/external/redshift/src/test/java/org/apache/hadoop/fs/s3native/InMemoryNativeFileSystemStore.java b/external/redshift/src/test/java/org/apache/hadoop/fs/s3native/InMemoryNativeFileSystemStore.java new file mode 100644 index 0000000000000..ac572aad40361 --- /dev/null +++ b/external/redshift/src/test/java/org/apache/hadoop/fs/s3native/InMemoryNativeFileSystemStore.java @@ -0,0 +1,206 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.fs.s3native; + +import static org.apache.hadoop.fs.s3native.NativeS3FileSystem.PATH_DELIMITER; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.Map.Entry; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.util.Time; + +/** + *

+ * A stub implementation of {@link NativeFileSystemStore} for testing + * {@link NativeS3FileSystem} without actually connecting to S3. + *

+ */ +public class InMemoryNativeFileSystemStore implements NativeFileSystemStore { + + private Configuration conf; + + private SortedMap metadataMap = + new TreeMap(); + private SortedMap dataMap = new TreeMap(); + + @Override + public void initialize(URI uri, Configuration conf) throws IOException { + this.conf = conf; + } + + @Override + public void storeEmptyFile(String key) throws IOException { + metadataMap.put(key, new FileMetadata(key, 0, Time.now())); + dataMap.put(key, new byte[0]); + } + + @Override + public void storeFile(String key, File file, byte[] md5Hash) + throws IOException { + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + byte[] buf = new byte[8192]; + int numRead; + BufferedInputStream in = null; + try { + in = new BufferedInputStream(new FileInputStream(file)); + while ((numRead = in.read(buf)) >= 0) { + out.write(buf, 0, numRead); + } + } finally { + if (in != null) { + in.close(); + } + } + metadataMap.put(key, + new FileMetadata(key, file.length(), Time.now())); + dataMap.put(key, out.toByteArray()); + } + + @Override + public InputStream retrieve(String key) throws IOException { + return retrieve(key, 0); + } + + @Override + public InputStream retrieve(String key, long byteRangeStart) + throws IOException { + + byte[] data = dataMap.get(key); + File file = createTempFile(); + BufferedOutputStream out = null; + try { + out = new BufferedOutputStream(new FileOutputStream(file)); + out.write(data, (int) byteRangeStart, + data.length - (int) byteRangeStart); + } finally { + if (out != null) { + out.close(); + } + } + return new FileInputStream(file); + } + + private File createTempFile() throws IOException { + File dir = new File(conf.get("fs.s3.buffer.dir")); + if (!dir.exists() && !dir.mkdirs()) { + throw new IOException("Cannot create S3 buffer directory: " + dir); + } + File result = File.createTempFile("test-", ".tmp", dir); + result.deleteOnExit(); + return result; + } + + @Override + public FileMetadata retrieveMetadata(String key) throws IOException { + return metadataMap.get(key); + } + + @Override + public PartialListing list(String prefix, int maxListingLength) + throws IOException { + return list(prefix, maxListingLength, null, false); + } + + @Override + public PartialListing list(String prefix, int maxListingLength, + String priorLastKey, boolean recursive) throws IOException { + + return list(prefix, recursive ? null : PATH_DELIMITER, maxListingLength, priorLastKey); + } + + private PartialListing list(String prefix, String delimiter, + int maxListingLength, String priorLastKey) throws IOException { + + if (prefix.length() > 0 && !prefix.endsWith(PATH_DELIMITER)) { + prefix += PATH_DELIMITER; + } + + List metadata = new ArrayList(); + SortedSet commonPrefixes = new TreeSet(); + for (String key : dataMap.keySet()) { + if (key.startsWith(prefix)) { + if (delimiter == null) { + metadata.add(retrieveMetadata(key)); + } else { + int delimIndex = key.indexOf(delimiter, prefix.length()); + if (delimIndex == -1) { + metadata.add(retrieveMetadata(key)); + } else { + String commonPrefix = key.substring(0, delimIndex); + commonPrefixes.add(commonPrefix); + } + } + } + if (metadata.size() + commonPrefixes.size() == maxListingLength) { + new PartialListing(key, metadata.toArray(new FileMetadata[0]), + commonPrefixes.toArray(new String[0])); + } + } + return new PartialListing(null, metadata.toArray(new FileMetadata[0]), + commonPrefixes.toArray(new String[0])); + } + + @Override + public void delete(String key) throws IOException { + metadataMap.remove(key); + dataMap.remove(key); + } + + @Override + public void copy(String srcKey, String dstKey) throws IOException { + metadataMap.put(dstKey, metadataMap.get(srcKey)); + dataMap.put(dstKey, dataMap.get(srcKey)); + } + + @Override + public void purge(String prefix) throws IOException { + Iterator> i = + metadataMap.entrySet().iterator(); + while (i.hasNext()) { + Entry entry = i.next(); + if (entry.getKey().startsWith(prefix)) { + dataMap.remove(entry.getKey()); + i.remove(); + } + } + } + + @Override + public void dump() throws IOException { + System.out.println(metadataMap.values()); + System.out.println(dataMap.keySet()); + } +} diff --git a/external/redshift/src/test/resources/hive-site.xml b/external/redshift/src/test/resources/hive-site.xml new file mode 100755 index 0000000000000..1a06bec091434 --- /dev/null +++ b/external/redshift/src/test/resources/hive-site.xml @@ -0,0 +1,22 @@ + + + + + + + fs.permissions.umask-mode + 022 + Setting a value for fs.permissions.umask-mode to work around issue in HIVE-6962. + It has no impact in Hadoop 1.x line on HDFS operations. + + + diff --git a/external/redshift/src/test/resources/redshift_unload_data.txt b/external/redshift/src/test/resources/redshift_unload_data.txt new file mode 100755 index 0000000000000..6b47543ba9b97 --- /dev/null +++ b/external/redshift/src/test/resources/redshift_unload_data.txt @@ -0,0 +1,5 @@ +1|t|2015-07-01|1234152.12312498|1.0|42|1239012341823719|23|Unicode's樂趣|2015-07-01 00:00:00.001 +1|f|2015-07-02|0|0.0|42|1239012341823719|-13|asdf|2015-07-02 00:00:00.0 +0||2015-07-03|0.0|-1.0|4141214|1239012341823719||f|2015-07-03 00:00:00 +0|f||-1234152.12312498|100000.0||1239012341823719|24|___\|_123| +||||||||| diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/AWSCredentialsUtilsSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/AWSCredentialsUtilsSuite.scala new file mode 100755 index 0000000000000..91f8d78d8f840 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/AWSCredentialsUtilsSuite.scala @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import scala.language.implicitConversions + +import com.amazonaws.auth.{AWSSessionCredentials, BasicAWSCredentials, BasicSessionCredentials} +import com.databricks.spark.redshift.Parameters.MergedParameters +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.SparkFunSuite + +class AWSCredentialsUtilsSuite extends SparkFunSuite { + + val baseParams = Map( + "tempdir" -> "s3://foo/bar", + "dbtable" -> "test_schema.test_table", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password") + + private implicit def string2Params(tempdir: String): MergedParameters = { + Parameters.mergeParameters(baseParams ++ Map( + "tempdir" -> tempdir, + "forward_spark_s3_credentials" -> "true")) + } + + test("credentialsString with regular keys") { + val creds = new BasicAWSCredentials("ACCESSKEYID", "SECRET/KEY/WITH/SLASHES") + val params = + Parameters.mergeParameters(baseParams ++ Map("forward_spark_s3_credentials" -> "true")) + assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, creds) === + "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY/WITH/SLASHES") + } + + test("credentialsString with STS temporary keys") { + val params = Parameters.mergeParameters(baseParams ++ Map( + "temporary_aws_access_key_id" -> "ACCESSKEYID", + "temporary_aws_secret_access_key" -> "SECRET/KEY", + "temporary_aws_session_token" -> "SESSION/Token")) + assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) === + "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY;token=SESSION/Token") + } + + test("Configured IAM roles should take precedence") { + val creds = new BasicSessionCredentials("ACCESSKEYID", "SECRET/KEY", "SESSION/Token") + val iamRole = "arn:aws:iam::123456789000:role/redshift_iam_role" + val params = Parameters.mergeParameters(baseParams ++ Map("aws_iam_role" -> iamRole)) + assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) === + s"aws_iam_role=$iamRole") + } + + test("AWSCredentials.load() STS temporary keys should take precedence") { + val conf = new Configuration(false) + conf.set("fs.s3.awsAccessKeyId", "CONFID") + conf.set("fs.s3.awsSecretAccessKey", "CONFKEY") + + val params = Parameters.mergeParameters(baseParams ++ Map( + "tempdir" -> "s3://URIID:URIKEY@bucket/path", + "temporary_aws_access_key_id" -> "key_id", + "temporary_aws_secret_access_key" -> "secret", + "temporary_aws_session_token" -> "token" + )) + + val creds = AWSCredentialsUtils.load(params, conf).getCredentials + assert(creds.isInstanceOf[AWSSessionCredentials]) + assert(creds.getAWSAccessKeyId === "key_id") + assert(creds.getAWSSecretKey === "secret") + assert(creds.asInstanceOf[AWSSessionCredentials].getSessionToken === "token") + } + + test("AWSCredentials.load() credentials precedence for s3:// URIs") { + val conf = new Configuration(false) + conf.set("fs.s3.awsAccessKeyId", "CONFID") + conf.set("fs.s3.awsSecretAccessKey", "CONFKEY") + + { + val creds = AWSCredentialsUtils.load("s3://URIID:URIKEY@bucket/path", conf).getCredentials + assert(creds.getAWSAccessKeyId === "URIID") + assert(creds.getAWSSecretKey === "URIKEY") + } + + { + val creds = AWSCredentialsUtils.load("s3://bucket/path", conf).getCredentials + assert(creds.getAWSAccessKeyId === "CONFID") + assert(creds.getAWSSecretKey === "CONFKEY") + } + + } + + test("AWSCredentials.load() credentials precedence for s3n:// URIs") { + val conf = new Configuration(false) + conf.set("fs.s3n.awsAccessKeyId", "CONFID") + conf.set("fs.s3n.awsSecretAccessKey", "CONFKEY") + + { + val creds = AWSCredentialsUtils.load("s3n://URIID:URIKEY@bucket/path", conf).getCredentials + assert(creds.getAWSAccessKeyId === "URIID") + assert(creds.getAWSSecretKey === "URIKEY") + } + + { + val creds = AWSCredentialsUtils.load("s3n://bucket/path", conf).getCredentials + assert(creds.getAWSAccessKeyId === "CONFID") + assert(creds.getAWSSecretKey === "CONFKEY") + } + + } + + test("AWSCredentials.load() credentials precedence for s3a:// URIs") { + val conf = new Configuration(false) + conf.set("fs.s3a.access.key", "CONFID") + conf.set("fs.s3a.secret.key", "CONFKEY") + + { + val creds = AWSCredentialsUtils.load("s3a://URIID:URIKEY@bucket/path", conf).getCredentials + assert(creds.getAWSAccessKeyId === "URIID") + assert(creds.getAWSSecretKey === "URIKEY") + } + + { + val creds = AWSCredentialsUtils.load("s3a://bucket/path", conf).getCredentials + assert(creds.getAWSAccessKeyId === "CONFID") + assert(creds.getAWSSecretKey === "CONFKEY") + } + + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala new file mode 100755 index 0000000000000..41817311676c9 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.Timestamp +import java.util.Locale + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +/** + * Unit test for data type conversions + */ +class ConversionsSuite extends SparkFunSuite { + + private def createRowConverter(schema: StructType) = { + Conversions.createRowConverter(schema).andThen(RowEncoder(schema).resolveAndBind().fromRow) + } + + test("Data should be correctly converted") { + val convertRow = createRowConverter(TestUtils.testSchema) + val doubleMin = Double.MinValue.toString + val longMax = Long.MaxValue.toString + // scalastyle:off + val unicodeString = "Unicode是樂趣" + // scalastyle:on + + val timestampWithMillis = "2014-03-01 00:00:01.123" + + val expectedDateMillis = TestUtils.toMillis(2015, 6, 1, 0, 0, 0) + val expectedTimestampMillis = TestUtils.toMillis(2014, 2, 1, 0, 0, 1, 123) + + val convertedRow = convertRow( + Array("1", "t", "2015-07-01", doubleMin, "1.0", "42", + longMax, "23", unicodeString, timestampWithMillis)) + + val expectedRow = Row(1.asInstanceOf[Byte], true, new Timestamp(expectedDateMillis), + Double.MinValue, 1.0f, 42, Long.MaxValue, 23.toShort, unicodeString, + new Timestamp(expectedTimestampMillis)) + + assert(convertedRow == expectedRow) + } + + test("Row conversion handles null values") { + val convertRow = createRowConverter(TestUtils.testSchema) + val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String] + assert(convertRow(emptyRow) === Row(emptyRow: _*)) + } + + test("Booleans are correctly converted") { + val convertRow = createRowConverter(StructType(Seq(StructField("a", BooleanType)))) + assert(convertRow(Array("t")) === Row(true)) + assert(convertRow(Array("f")) === Row(false)) + assert(convertRow(Array(null)) === Row(null)) + intercept[IllegalArgumentException] { + convertRow(Array("not-a-boolean")) + } + } + + test("timestamp conversion handles millisecond-level precision (regression test for #214)") { + val schema = StructType(Seq(StructField("a", TimestampType))) + val convertRow = createRowConverter(schema) + Seq( + "2014-03-01 00:00:01" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 1000), + "2014-03-01 00:00:01.000" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 1000), + "2014-03-01 00:00:00.1" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 100), + "2014-03-01 00:00:00.10" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 100), + "2014-03-01 00:00:00.100" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 100), + "2014-03-01 00:00:00.01" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 10), + "2014-03-01 00:00:00.010" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 10), + "2014-03-01 00:00:00.001" -> TestUtils.toMillis(2014, 2, 1, 0, 0, 0, millis = 1) + ).foreach { case (timestampString, expectedTime) => + withClue(s"timestamp string is '$timestampString'") { + val convertedRow = convertRow(Array(timestampString)) + val convertedTimestamp = convertedRow.get(0).asInstanceOf[Timestamp] + assert(convertedTimestamp === new Timestamp(expectedTime)) + } + } + } + + test("RedshiftDecimalFormat is locale-insensitive (regression test for #243)") { + for (locale <- Seq(Locale.US, Locale.GERMAN, Locale.UK)) { + withClue(s"locale = $locale") { + TestUtils.withDefaultLocale(locale) { + val decimalFormat = Conversions.createRedshiftDecimalFormat() + val parsed = decimalFormat.parse("151.20").asInstanceOf[java.math.BigDecimal] + assert(parsed.doubleValue() === 151.20) + } + } + } + } + + test("Row conversion properly handles NaN and Inf float values (regression test for #261)") { + val convertRow = createRowConverter(StructType(Seq(StructField("a", FloatType)))) + assert(java.lang.Float.isNaN(convertRow(Array("nan")).getFloat(0))) + assert(convertRow(Array("inf")) === Row(Float.PositiveInfinity)) + assert(convertRow(Array("-inf")) === Row(Float.NegativeInfinity)) + } + + test("Row conversion properly handles NaN and Inf double values (regression test for #261)") { + val convertRow = createRowConverter(StructType(Seq(StructField("a", DoubleType)))) + assert(java.lang.Double.isNaN(convertRow(Array("nan")).getDouble(0))) + assert(convertRow(Array("inf")) === Row(Double.PositiveInfinity)) + assert(convertRow(Array("-inf")) === Row(Double.NegativeInfinity)) + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapredOutputCommitter.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapredOutputCommitter.scala new file mode 100755 index 0000000000000..a24dd001920ec --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapredOutputCommitter.scala @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred._ + +class DirectMapredOutputCommitter extends OutputCommitter { + override def setupJob(jobContext: JobContext): Unit = { } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { } + + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = { + // We return true here to guard against implementations that do not handle false correctly. + // The meaning of returning false is not entirely clear, so it's possible to be interpreted + // as an error. Returning true just means that commitTask() will be called, which is a no-op. + true + } + + override def commitTask(taskContext: TaskAttemptContext): Unit = { } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { } + + /** + * Creates a _SUCCESS file to indicate the entire job was successful. + * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option. + */ + override def commitJob(context: JobContext): Unit = { + val conf = context.getJobConf + if (shouldCreateSuccessFile(conf)) { + val outputPath = FileOutputFormat.getOutputPath(conf) + if (outputPath != null) { + val fileSys = outputPath.getFileSystem(conf) + val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSys.create(filePath).close() + } + } + } + + /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */ + private def shouldCreateSuccessFile(conf: JobConf): Boolean = { + conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true) + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapreduceOutputCommitter.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapreduceOutputCommitter.scala new file mode 100755 index 0000000000000..922706d30ad3b --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/DirectMapreduceOutputCommitter.scala @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} + +class DirectMapreduceOutputCommitter extends OutputCommitter { + override def setupJob(jobContext: JobContext): Unit = { } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { } + + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = { + // We return true here to guard against implementations that do not handle false correctly. + // The meaning of returning false is not entirely clear, so it's possible to be interpreted + // as an error. Returning true just means that commitTask() will be called, which is a no-op. + true + } + + override def commitTask(taskContext: TaskAttemptContext): Unit = { } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { } + + /** + * Creates a _SUCCESS file to indicate the entire job was successful. + * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option. + */ + override def commitJob(context: JobContext): Unit = { + val conf = context.getConfiguration + if (shouldCreateSuccessFile(conf)) { + val outputPath = FileOutputFormat.getOutputPath(context) + if (outputPath != null) { + val fileSys = outputPath.getFileSystem(conf) + val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSys.create(filePath).close() + } + } + } + + /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */ + private def shouldCreateSuccessFile(conf: Configuration): Boolean = { + conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true) + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/FilterPushdownSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/FilterPushdownSuite.scala new file mode 100755 index 0000000000000..d3e09d5e8a733 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/FilterPushdownSuite.scala @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import com.databricks.spark.redshift.FilterPushdown._ + +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.SparkFunSuite + +class FilterPushdownSuite extends SparkFunSuite { + test("buildWhereClause with empty list of filters") { + assert(buildWhereClause(StructType(Nil), Seq.empty) === "") + } + + test("buildWhereClause with no filters that can be pushed down") { + assert(buildWhereClause(StructType(Nil), Seq(NewFilter, NewFilter)) === "") + } + + test("buildWhereClause with with some filters that cannot be pushed down") { + val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_int", 1), NewFilter)) + assert(whereClause === """WHERE "test_int" = 1""") + } + + test("buildWhereClause with string literals that contain Unicode characters") { + // scalastyle:off + val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_string", "Unicode's樂趣"))) + // Here, the apostrophe in the string needs to be replaced with two single quotes, '', but we + // also need to escape those quotes with backslashes because this WHERE clause is going to + // eventually be embedded inside of a single-quoted string that's embedded inside of a larger + // Redshift query. + assert(whereClause === """WHERE "test_string" = \'Unicode\'\'s樂趣\'""") + // scalastyle:on + } + + test("buildWhereClause with multiple filters") { + val filters = Seq( + EqualTo("test_bool", true), + // scalastyle:off + EqualTo("test_string", "Unicode是樂趣"), + // scalastyle:on + GreaterThan("test_double", 1000.0), + LessThan("test_double", Double.MaxValue), + GreaterThanOrEqual("test_float", 1.0f), + LessThanOrEqual("test_int", 43), + IsNotNull("test_int"), + IsNull("test_int")) + val whereClause = buildWhereClause(testSchema, filters) + // scalastyle:off + val expectedWhereClause = + """ + |WHERE "test_bool" = true + |AND "test_string" = \'Unicode是樂趣\' + |AND "test_double" > 1000.0 + |AND "test_double" < 1.7976931348623157E308 + |AND "test_float" >= 1.0 + |AND "test_int" <= 43 + |AND "test_int" IS NOT NULL + |AND "test_int" IS NULL + """.stripMargin.lines.mkString(" ").trim + // scalastyle:on + assert(whereClause === expectedWhereClause) + } + + private val testSchema: StructType = StructType(Seq( + StructField("test_byte", ByteType), + StructField("test_bool", BooleanType), + StructField("test_date", DateType), + StructField("test_double", DoubleType), + StructField("test_float", FloatType), + StructField("test_int", IntegerType), + StructField("test_long", LongType), + StructField("test_short", ShortType), + StructField("test_string", StringType), + StructField("test_timestamp", TimestampType))) + + /** A new filter subclasss which our pushdown logic does not know how to handle */ + private case object NewFilter extends Filter { + override def references: Array[String] = Array.empty + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/MockRedshift.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/MockRedshift.scala new file mode 100755 index 0000000000000..d665839fae5fc --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/MockRedshift.scala @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.{Connection, PreparedStatement, ResultSet, SQLException} + +import scala.collection.mutable +import scala.util.matching.Regex + +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.Assertions._ + +import org.apache.spark.sql.types.StructType + +/** + * Helper class for mocking Redshift / JDBC in unit tests. + */ +class MockRedshift( + jdbcUrl: String, + existingTablesAndSchemas: Map[String, StructType], + jdbcQueriesThatShouldFail: Seq[Regex] = Seq.empty) { + + private[this] val queriesIssued: mutable.Buffer[String] = mutable.Buffer.empty + def getQueriesIssuedAgainstRedshift: Seq[String] = queriesIssued.toSeq + + private[this] val jdbcConnections: mutable.Buffer[Connection] = mutable.Buffer.empty + + val jdbcWrapper: JDBCWrapper = spy(new JDBCWrapper) + + private def createMockConnection(): Connection = { + val conn = mock(classOf[Connection], RETURNS_SMART_NULLS) + jdbcConnections.append(conn) + when(conn.prepareStatement(anyString())).thenAnswer(new Answer[PreparedStatement] { + override def answer(invocation: InvocationOnMock): PreparedStatement = { + val query = invocation.getArguments()(0).asInstanceOf[String] + queriesIssued.append(query) + val mockStatement = mock(classOf[PreparedStatement], RETURNS_SMART_NULLS) + if (jdbcQueriesThatShouldFail.forall(_.findFirstMatchIn(query).isEmpty)) { + when(mockStatement.execute()).thenReturn(true) + when(mockStatement.executeQuery()).thenReturn( + mock(classOf[ResultSet], RETURNS_SMART_NULLS)) + } else { + when(mockStatement.execute()).thenThrow(new SQLException(s"Error executing $query")) + when(mockStatement.executeQuery()).thenThrow(new SQLException(s"Error executing $query")) + } + mockStatement + } + }) + conn + } + + doAnswer(new Answer[Connection] { + override def answer(invocation: InvocationOnMock): Connection = createMockConnection() + }).when(jdbcWrapper) + .getConnector(any[Option[String]](), same(jdbcUrl), any[Option[(String, String)]]()) + + doAnswer(new Answer[Boolean] { + override def answer(invocation: InvocationOnMock): Boolean = { + existingTablesAndSchemas.contains(invocation.getArguments()(1).asInstanceOf[String]) + } + }).when(jdbcWrapper).tableExists(any[Connection], anyString()) + + doAnswer(new Answer[StructType] { + override def answer(invocation: InvocationOnMock): StructType = { + existingTablesAndSchemas(invocation.getArguments()(1).asInstanceOf[String]) + } + }).when(jdbcWrapper).resolveTable(any[Connection], anyString()) + + def verifyThatConnectionsWereClosed(): Unit = { + jdbcConnections.foreach { conn => + verify(conn).close() + } + } + + def verifyThatRollbackWasCalled(): Unit = { + jdbcConnections.foreach { conn => + verify(conn, atLeastOnce()).rollback() + } + } + + def verifyThatCommitWasNotCalled(): Unit = { + jdbcConnections.foreach { conn => + verify(conn, never()).commit() + } + } + + def verifyThatExpectedQueriesWereIssued(expectedQueries: Seq[Regex]): Unit = { + expectedQueries.zip(queriesIssued).foreach { case (expected, actual) => + if (expected.findFirstMatchIn(actual).isEmpty) { + fail( + s""" + |Actual and expected JDBC queries did not match: + |Expected: $expected + |Actual: $actual + """.stripMargin) + } + } + if (expectedQueries.length > queriesIssued.length) { + val missingQueries = expectedQueries.drop(queriesIssued.length) + fail(s"Missing ${missingQueries.length} expected JDBC queries:" + + s"\n${missingQueries.mkString("\n")}") + } else if (queriesIssued.length > expectedQueries.length) { + val extraQueries = queriesIssued.drop(expectedQueries.length) + fail(s"Got ${extraQueries.length} unexpected JDBC queries:\n${extraQueries.mkString("\n")}") + } + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala new file mode 100755 index 0000000000000..dea37a5a31d82 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala @@ -0,0 +1,145 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +/** + * Check validation of parameter config + */ +class ParametersSuite extends SparkFunSuite with Matchers { + + test("Minimal valid parameter map is accepted") { + val params = Map( + "tempdir" -> "s3://foo/bar", + "dbtable" -> "test_schema.test_table", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password", + "forward_spark_s3_credentials" -> "true") + + val mergedParams = Parameters.mergeParameters(params) + + mergedParams.rootTempDir should startWith (params("tempdir")) + mergedParams.createPerQueryTempDir() should startWith (params("tempdir")) + mergedParams.jdbcUrl shouldBe params("url") + mergedParams.table shouldBe Some(TableName("test_schema", "test_table")) + assert(mergedParams.forwardSparkS3Credentials) + + // Check that the defaults have been added + (Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials").foreach { + case (key, value) => mergedParams.parameters(key) shouldBe value + } + } + + test("createPerQueryTempDir() returns distinct temp paths") { + val params = Map( + "forward_spark_s3_credentials" -> "true", + "tempdir" -> "s3://foo/bar", + "dbtable" -> "test_table", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password") + + val mergedParams = Parameters.mergeParameters(params) + + mergedParams.createPerQueryTempDir() should not equal mergedParams.createPerQueryTempDir() + } + + test("Errors are thrown when mandatory parameters are not provided") { + def checkMerge(params: Map[String, String], err: String): Unit = { + val e = intercept[IllegalArgumentException] { + Parameters.mergeParameters(params) + } + assert(e.getMessage.contains(err)) + } + val testURL = "jdbc:redshift://foo/bar?user=user&password=password" + checkMerge(Map("dbtable" -> "test_table", "url" -> testURL), "tempdir") + checkMerge(Map("tempdir" -> "s3://foo/bar", "url" -> testURL), "Redshift table name") + checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar"), "JDBC URL") + checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar", "url" -> testURL), + "method for authenticating") + } + + test("Must specify either 'dbtable' or 'query' parameter, but not both") { + intercept[IllegalArgumentException] { + Parameters.mergeParameters(Map( + "forward_spark_s3_credentials" -> "true", + "tempdir" -> "s3://foo/bar", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) + }.getMessage should (include ("dbtable") and include ("query")) + + intercept[IllegalArgumentException] { + Parameters.mergeParameters(Map( + "forward_spark_s3_credentials" -> "true", + "tempdir" -> "s3://foo/bar", + "dbtable" -> "test_table", + "query" -> "select * from test_table", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) + }.getMessage should (include ("dbtable") and include ("query") and include("both")) + + Parameters.mergeParameters(Map( + "forward_spark_s3_credentials" -> "true", + "tempdir" -> "s3://foo/bar", + "query" -> "select * from test_table", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) + } + + test("Must specify credentials in either URL or 'user' and 'password' parameters, but not both") { + intercept[IllegalArgumentException] { + Parameters.mergeParameters(Map( + "forward_spark_s3_credentials" -> "true", + "tempdir" -> "s3://foo/bar", + "query" -> "select * from test_table", + "url" -> "jdbc:redshift://foo/bar")) + }.getMessage should (include ("credentials")) + + intercept[IllegalArgumentException] { + Parameters.mergeParameters(Map( + "forward_spark_s3_credentials" -> "true", + "tempdir" -> "s3://foo/bar", + "query" -> "select * from test_table", + "user" -> "user", + "password" -> "password", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) + }.getMessage should (include ("credentials") and include("both")) + + Parameters.mergeParameters(Map( + "forward_spark_s3_credentials" -> "true", + "tempdir" -> "s3://foo/bar", + "query" -> "select * from test_table", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) + } + + test("tempformat option is case-insensitive") { + val params = Map( + "forward_spark_s3_credentials" -> "true", + "tempdir" -> "s3://foo/bar", + "dbtable" -> "test_schema.test_table", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password") + + Parameters.mergeParameters(params + ("tempformat" -> "csv")) + Parameters.mergeParameters(params + ("tempformat" -> "CSV")) + + intercept[IllegalArgumentException] { + Parameters.mergeParameters(params + ("tempformat" -> "invalid-temp-format")) + } + } + + test("can only specify one Redshift to S3 authentication mechanism") { + val e = intercept[IllegalArgumentException] { + Parameters.mergeParameters(Map( + "tempdir" -> "s3://foo/bar", + "dbtable" -> "test_schema.test_table", + "url" -> "jdbc:redshift://foo/bar?user=user&password=password", + "forward_spark_s3_credentials" -> "true", + "aws_iam_role" -> "role")) + } + assert(e.getMessage.contains("mutually-exclusive")) + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/QueryTest.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/QueryTest.scala new file mode 100755 index 0000000000000..60533741e61ef --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/QueryTest.scala @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.plans.logical + +/** + * Copy of Spark SQL's `QueryTest` trait. + */ +trait QueryTest extends SparkFunSuite { + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { + val isSorted = df.queryExecution.logical.collect { case s: logical.Sort => s }.nonEmpty + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (!isSorted) converted.sortBy(_.toString()) else converted + } + val sparkAnswer = try df.collect().toSeq catch { + case e: Exception => + val errorMessage = + s""" + |Exception thrown while executing query: + |${df.queryExecution} + |== Exception == + |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + fail(errorMessage) + } + + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + |Results do not match for query: + |${df.queryExecution} + |== Results == + |${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + fail(errorMessage) + } + } + + private def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { + val maxLeftSize = left.map(_.length).max + val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") + val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") + + leftPadded.zip(rightPadded).map { + case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r + } + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala new file mode 100755 index 0000000000000..aaf496a82cd11 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.language.implicitConversions + +import com.databricks.spark.redshift.RedshiftInputFormat._ +import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types._ + +class RedshiftInputFormatSuite extends SparkFunSuite { + + import RedshiftInputFormatSuite._ + + private var sc: SparkContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sc = new SparkContext("local", this.getClass.getName) + } + + override def afterAll(): Unit = { + sc.stop() + super.afterAll() + } + + private def writeToFile(contents: String, file: File): Unit = { + val bytes = contents.getBytes + val out = new DataOutputStream(new FileOutputStream(file)) + out.write(bytes, 0, bytes.length) + out.close() + } + + private def escape(records: Set[Seq[String]], delimiter: Char): String = { + require(delimiter != '\\' && delimiter != '\n') + records.map { r => + r.map { f => + f.replace("\\", "\\\\") + .replace("\n", "\\\n") + .replace(delimiter, "\\" + delimiter) + }.mkString(delimiter) + }.mkString("", "\n", "\n") + } + + private final val KEY_BLOCK_SIZE = "fs.local.block.size" + + private final val TAB = '\t' + + private val records = Set( + Seq("a\n", DEFAULT_DELIMITER + "b\\"), + Seq("c", TAB + "d"), + Seq("\ne", "\\\\f")) + + private def withTempDir(func: File => Unit): Unit = { + val dir = Files.createTempDir() + dir.deleteOnExit() + func(dir) + } + + test("default delimiter") { + withTempDir { dir => + val escaped = escape(records, DEFAULT_DELIMITER) + writeToFile(escaped, new File(dir, "part-00000")) + + val conf = new Configuration + conf.setLong(KEY_BLOCK_SIZE, 4) + + val rdd = sc.newAPIHadoopFile(dir.toString, classOf[RedshiftInputFormat], + classOf[java.lang.Long], classOf[Array[String]], conf) + + // TODO: Check this assertion - fails on Travis only, no idea what, or what it's for + // assert(rdd.partitions.size > records.size) // so there exist at least one empty partition + + val actual = rdd.values.map(_.toSeq).collect() + assert(actual.size === records.size) + assert(actual.toSet === records) + } + } + + test("customized delimiter") { + withTempDir { dir => + val escaped = escape(records, TAB) + writeToFile(escaped, new File(dir, "part-00000")) + + val conf = new Configuration + conf.setLong(KEY_BLOCK_SIZE, 4) + conf.set(KEY_DELIMITER, TAB) + + val rdd = sc.newAPIHadoopFile(dir.toString, classOf[RedshiftInputFormat], + classOf[java.lang.Long], classOf[Array[String]], conf) + + // TODO: Check this assertion - fails on Travis only, no idea what, or what it's for + // assert(rdd.partitions.size > records.size) // so there exist at least one empty partitions + + val actual = rdd.values.map(_.toSeq).collect() + assert(actual.size === records.size) + assert(actual.toSet === records) + } + } + + test("schema parser") { + withTempDir { dir => + val testRecords = Set( + Seq("a\n", "TX", 1, 1.0, 1000L, 200000000000L), + Seq("b", "CA", 2, 2.0, 2000L, 1231412314L)) + val escaped = escape(testRecords.map(_.map(_.toString)), DEFAULT_DELIMITER) + writeToFile(escaped, new File(dir, "part-00000")) + + val sqlContext = new SQLContext(sc) + val expectedSchema = StructType(Seq( + StructField("name", StringType, nullable = true), + StructField("state", StringType, nullable = true), + StructField("id", IntegerType, nullable = true), + StructField("score", DoubleType, nullable = true), + StructField("big_score", LongType, nullable = true), + StructField("some_long", LongType, nullable = true))) + + val df = sqlContext.redshiftFile(dir.toString, expectedSchema) + assert(df.schema === expectedSchema) + + val parsed = df.rdd.map { + case Row( + name: String, state: String, id: Int, score: Double, bigScore: Long, someLong: Long + ) => Seq(name, state, id, score, bigScore, someLong) + }.collect().toSet + + assert(parsed === testRecords) + } + } +} + +object RedshiftInputFormatSuite { + implicit def charToString(c: Char): String = c.toString +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala new file mode 100755 index 0000000000000..e039ad29ac59d --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala @@ -0,0 +1,576 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.io.{ByteArrayInputStream, OutputStreamWriter} +import java.net.URI + +import com.amazonaws.services.s3.AmazonS3Client +import com.amazonaws.services.s3.model.{BucketLifecycleConfiguration, S3Object, S3ObjectInputStream} +import com.amazonaws.services.s3.model.BucketLifecycleConfiguration.Rule +import com.databricks.spark.redshift.Parameters.MergedParameters +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.s3native.S3NInMemoryFileSystem +import org.apache.http.client.methods.HttpRequestBase +import org.mockito.Matchers._ +import org.mockito.Mockito +import org.mockito.Mockito.when +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +/** + * Tests main DataFrame loading and writing functionality + */ +class RedshiftSourceSuite + extends QueryTest + with Matchers + with BeforeAndAfterAll + with BeforeAndAfterEach { + + /** + * Spark Context with Hadoop file overridden to point at our local test data file for this suite, + * no matter what temp directory was generated and requested. + */ + private var sc: SparkContext = _ + + private var testSqlContext: SQLContext = _ + + private var expectedDataDF: DataFrame = _ + + private var mockS3Client: AmazonS3Client = _ + + private var s3FileSystem: FileSystem = _ + + private val s3TempDir: String = "s3n://test-bucket/temp-dir/" + + private var unloadedData: String = "" + + // Parameters common to most tests. Some parameters are overridden in specific tests. + private def defaultParams: Map[String, String] = Map( + "url" -> "jdbc:redshift://foo/bar?user=user&password=password", + "tempdir" -> s3TempDir, + "dbtable" -> "test_table", + "forward_spark_s3_credentials" -> "true") + + override def beforeAll(): Unit = { + super.beforeAll() + sc = new SparkContext("local", "RedshiftSourceSuite") + sc.hadoopConfiguration.set("fs.s3n.impl", classOf[S3NInMemoryFileSystem].getName) + // We need to use a DirectOutputCommitter to work around an issue which occurs with renames + // while using the mocked S3 filesystem. + sc.hadoopConfiguration.set("spark.sql.sources.outputCommitterClass", + classOf[DirectMapreduceOutputCommitter].getName) + sc.hadoopConfiguration.set("mapred.output.committer.class", + classOf[DirectMapredOutputCommitter].getName) + sc.hadoopConfiguration.set("fs.s3.awsAccessKeyId", "test1") + sc.hadoopConfiguration.set("fs.s3.awsSecretAccessKey", "test2") + sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", "test1") + sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", "test2") + } + + override def beforeEach(): Unit = { + super.beforeEach() + s3FileSystem = FileSystem.get(new URI(s3TempDir), sc.hadoopConfiguration) + testSqlContext = new SQLContext(sc) + expectedDataDF = + testSqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema) + // Configure a mock S3 client so that we don't hit errors when trying to access AWS in tests. + mockS3Client = Mockito.mock(classOf[AmazonS3Client], Mockito.RETURNS_SMART_NULLS) + when(mockS3Client.getBucketLifecycleConfiguration(anyString())).thenReturn( + new BucketLifecycleConfiguration().withRules( + new Rule().withPrefix("").withStatus(BucketLifecycleConfiguration.ENABLED) + )) + val mockManifest = Mockito.mock(classOf[S3Object], Mockito.RETURNS_SMART_NULLS) + when(mockManifest.getObjectContent).thenAnswer { + new Answer[S3ObjectInputStream] { + override def answer(invocationOnMock: InvocationOnMock): S3ObjectInputStream = { + val manifest = + s""" + | { + | "entries": [ + | { "url": "${Utils.fixS3Url(Utils.lastTempPathGenerated)}/part-00000" } + | ] + | } + """.stripMargin + // Write the data to the output file specified in the manifest: + val out = s3FileSystem.create(new Path(s"${Utils.lastTempPathGenerated}/part-00000")) + val ow = new OutputStreamWriter(out.getWrappedStream) + ow.write(unloadedData) + ow.close() + out.close() + val is = new ByteArrayInputStream(manifest.getBytes("UTF-8")) + new S3ObjectInputStream( + is, + Mockito.mock(classOf[HttpRequestBase], Mockito.RETURNS_SMART_NULLS)) + } + } + } + when(mockS3Client.getObject(anyString(), endsWith("manifest"))).thenReturn(mockManifest) + } + + override def afterEach(): Unit = { + super.afterEach() + testSqlContext = null + expectedDataDF = null + mockS3Client = null + FileSystem.closeAll() + } + + override def afterAll(): Unit = { + sc.stop() + super.afterAll() + } + + test("DefaultSource can load Redshift UNLOAD output to a DataFrame") { + // scalastyle:off + unloadedData = + """ + |1|t|2015-07-01|1234152.12312498|1.0|42|1239012341823719|23|Unicode's樂趣|2015-07-01 00:00:00.001 + |1|f|2015-07-02|0|0.0|42|1239012341823719|-13|asdf|2015-07-02 00:00:00.0 + |0||2015-07-03|0.0|-1.0|4141214|1239012341823719||f|2015-07-03 00:00:00 + |0|f||-1234152.12312498|100000.0||1239012341823719|24|___\|_123| + |||||||||| + """.stripMargin.trim + // scalastyle:on + val expectedQuery = ( + "UNLOAD \\('SELECT \"testbyte\", \"testbool\", \"testdate\", \"testdouble\"," + + " \"testfloat\", \"testint\", \"testlong\", \"testshort\", \"teststring\", " + + "\"testtimestamp\" " + + "FROM \"PUBLIC\".\"test_table\" '\\) " + + "TO '.*' " + + "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " + + "ESCAPE").r + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema)) + + // Assert that we've loaded and converted all data in the test file + val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client) + val relation = source.createRelation(testSqlContext, defaultParams) + val df = testSqlContext.baseRelationToDataFrame(relation) + checkAnswer(df, TestUtils.expectedData) + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedQuery)) + } + + test("Can load output of Redshift queries") { + // scalastyle:off + val expectedJDBCQuery = + """ + |UNLOAD \('SELECT "testbyte", "testbool" FROM + | \(select testbyte, testbool + | from test_table + | where teststring = \\'\\\\\\\\Unicode\\'\\'s樂趣\\'\) '\) + """.stripMargin.lines.map(_.trim).mkString(" ").trim.r + val query = + """select testbyte, testbool from test_table where teststring = '\\Unicode''s樂趣'""" + unloadedData = "1|t" + // scalastyle:on + val querySchema = + StructType(Seq(StructField("testbyte", ByteType), StructField("testbool", BooleanType))) + + val expectedValues = Array(Row(1.toByte, true)) + + // Test with dbtable parameter that wraps the query in parens: + { + val params = defaultParams + ("dbtable" -> s"($query)") + val mockRedshift = + new MockRedshift(defaultParams("url"), Map(params("dbtable") -> querySchema)) + val relation = new DefaultSource( + mockRedshift.jdbcWrapper, _ => mockS3Client).createRelation(testSqlContext, params) + assert(testSqlContext.baseRelationToDataFrame(relation).collect() === expectedValues) + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedJDBCQuery)) + } + + // Test with query parameter + { + val params = defaultParams - "dbtable" + ("query" -> query) + val mockRedshift = new MockRedshift(defaultParams("url"), Map(s"($query)" -> querySchema)) + val relation = new DefaultSource( + mockRedshift.jdbcWrapper, _ => mockS3Client).createRelation(testSqlContext, params) + assert(testSqlContext.baseRelationToDataFrame(relation).collect() === expectedValues) + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedJDBCQuery)) + } + } + + test("DefaultSource supports simple column filtering") { + // scalastyle:off + unloadedData = + """ + |1|t + |1|f + |0| + |0|f + || + """.stripMargin.trim + // scalastyle:on + val expectedQuery = ( + "UNLOAD \\('SELECT \"testbyte\", \"testbool\" FROM \"PUBLIC\".\"test_table\" '\\) " + + "TO '.*' " + + "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " + + "ESCAPE").r + val mockRedshift = + new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema)) + // Construct the source with a custom schema + val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client) + val relation = source.createRelation(testSqlContext, defaultParams, TestUtils.testSchema) + val resultSchema = + StructType(Seq(StructField("testbyte", ByteType), StructField("testbool", BooleanType))) + + val rdd = relation.asInstanceOf[PrunedFilteredScan] + .buildScan(Array("testbyte", "testbool"), Array.empty[Filter]) + .mapPartitions { iter => + val fromRow = RowEncoder(resultSchema).resolveAndBind().fromRow _ + iter.asInstanceOf[Iterator[InternalRow]].map(fromRow) + } + val prunedExpectedValues = Array( + Row(1.toByte, true), + Row(1.toByte, false), + Row(0.toByte, null), + Row(0.toByte, false), + Row(null, null)) + assert(rdd.collect() === prunedExpectedValues) + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedQuery)) + } + + test("DefaultSource supports user schema, pruned and filtered scans") { + // scalastyle:off + unloadedData = "1|t" + val expectedQuery = ( + "UNLOAD \\('SELECT \"testbyte\", \"testbool\" " + + "FROM \"PUBLIC\".\"test_table\" " + + "WHERE \"testbool\" = true " + + "AND \"teststring\" = \\\\'Unicode\\\\'\\\\'s樂趣\\\\' " + + "AND \"testdouble\" > 1000.0 " + + "AND \"testdouble\" < 1.7976931348623157E308 " + + "AND \"testfloat\" >= 1.0 " + + "AND \"testint\" <= 43'\\) " + + "TO '.*' " + + "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " + + "ESCAPE").r + // scalastyle:on + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema)) + + // Construct the source with a custom schema + val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client) + val relation = source.createRelation(testSqlContext, defaultParams, TestUtils.testSchema) + val resultSchema = + StructType(Seq(StructField("testbyte", ByteType), StructField("testbool", BooleanType))) + + // Define a simple filter to only include a subset of rows + val filters: Array[Filter] = Array( + EqualTo("testbool", true), + // scalastyle:off + EqualTo("teststring", "Unicode's樂趣"), + // scalastyle:on + GreaterThan("testdouble", 1000.0), + LessThan("testdouble", Double.MaxValue), + GreaterThanOrEqual("testfloat", 1.0f), + LessThanOrEqual("testint", 43)) + val rdd = relation.asInstanceOf[PrunedFilteredScan] + .buildScan(Array("testbyte", "testbool"), filters) + .mapPartitions { iter => + val fromRow = RowEncoder(resultSchema).resolveAndBind().fromRow _ + iter.asInstanceOf[Iterator[InternalRow]].map(fromRow) + } + + assert(rdd.collect() === Array(Row(1, true))) + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(Seq(expectedQuery)) + } + + test("DefaultSource supports preactions options to run queries before running COPY command") { + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema)) + val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client) + val params = defaultParams ++ Map( + "preactions" -> + """ + | DELETE FROM %s WHERE id < 100; + | DELETE FROM %s WHERE id > 100; + | DELETE FROM %s WHERE id = -1; + """.stripMargin.trim, + "usestagingtable" -> "true") + + val expectedCommands = Seq( + "DROP TABLE IF EXISTS \"PUBLIC\".\"test_table.*\"".r, + "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table.*\"".r, + "DELETE FROM \"PUBLIC\".\"test_table.*\" WHERE id < 100".r, + "DELETE FROM \"PUBLIC\".\"test_table.*\" WHERE id > 100".r, + "DELETE FROM \"PUBLIC\".\"test_table.*\" WHERE id = -1".r, + "COPY \"PUBLIC\".\"test_table.*\"".r) + + source.createRelation(testSqlContext, SaveMode.Overwrite, params, expectedDataDF) + mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands) + mockRedshift.verifyThatConnectionsWereClosed() + } + + test("DefaultSource serializes data as Avro, then sends Redshift COPY command") { + val params = defaultParams ++ Map( + "postactions" -> "GRANT SELECT ON %s TO jeremy", + "diststyle" -> "KEY", + "distkey" -> "testint") + + val expectedCommands = Seq( + "DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table.*\"".r, + ("CREATE TABLE IF NOT EXISTS \"PUBLIC\"\\.\"test_table.*" + + " DISTSTYLE KEY DISTKEY \\(testint\\).*").r, + "COPY \"PUBLIC\"\\.\"test_table.*\"".r, + "GRANT SELECT ON \"PUBLIC\"\\.\"test_table\" TO jeremy".r) + + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema)) + + val relation = RedshiftRelation( + mockRedshift.jdbcWrapper, + _ => mockS3Client, + Parameters.mergeParameters(params), + userSchema = None)(testSqlContext) + relation.asInstanceOf[InsertableRelation].insert(expectedDataDF, overwrite = true) + + // Make sure we wrote the data out ready for Redshift load, in the expected formats. + // The data should have been written to a random subdirectory of `tempdir`. Since we clear + // `tempdir` between every unit test, there should only be one directory here. + assert(s3FileSystem.listStatus(new Path(s3TempDir)).length === 1) + val dirWithAvroFiles = s3FileSystem.listStatus(new Path(s3TempDir)).head.getPath.toUri.toString + val written = testSqlContext.read.format("com.databricks.spark.avro").load(dirWithAvroFiles) + checkAnswer(written, TestUtils.expectedDataWithConvertedTimesAndDates) + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands) + } + + test("Cannot write table with column names that become ambiguous under case insensitivity") { + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema)) + + val schema = StructType(Seq(StructField("a", IntegerType), StructField("A", IntegerType))) + val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema) + val writer = new RedshiftWriter(mockRedshift.jdbcWrapper, _ => mockS3Client) + + intercept[IllegalArgumentException] { + writer.saveToRedshift( + testSqlContext, df, SaveMode.Append, Parameters.mergeParameters(defaultParams)) + } + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatCommitWasNotCalled() + mockRedshift.verifyThatRollbackWasCalled() + mockRedshift.verifyThatExpectedQueriesWereIssued(Seq.empty) + } + + test("Failed copies are handled gracefully when using a staging table") { + val params = defaultParams ++ Map("usestagingtable" -> "true") + + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema), + jdbcQueriesThatShouldFail = Seq("COPY \"PUBLIC\".\"test_table.*\"".r)) + + val expectedCommands = Seq( + "DROP TABLE IF EXISTS \"PUBLIC\".\"test_table.*\"".r, + "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table.*\"".r, + "COPY \"PUBLIC\".\"test_table.*\"".r, + ".*FROM stl_load_errors.*".r + ) + + val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client) + intercept[Exception] { + source.createRelation(testSqlContext, SaveMode.Overwrite, params, expectedDataDF) + } + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatCommitWasNotCalled() + mockRedshift.verifyThatRollbackWasCalled() + mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands) + } + + test("Append SaveMode doesn't destroy existing data") { + val expectedCommands = + Seq("CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r, + "COPY \"PUBLIC\".\"test_table\" .*".r) + + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null)) + + val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client) + source.createRelation(testSqlContext, SaveMode.Append, defaultParams, expectedDataDF) + + // This test is "appending" to an empty table, so we expect all our test data to be + // the only content in the returned data frame. + // The data should have been written to a random subdirectory of `tempdir`. Since we clear + // `tempdir` between every unit test, there should only be one directory here. + assert(s3FileSystem.listStatus(new Path(s3TempDir)).length === 1) + val dirWithAvroFiles = s3FileSystem.listStatus(new Path(s3TempDir)).head.getPath.toUri.toString + val written = testSqlContext.read.format("com.databricks.spark.avro").load(dirWithAvroFiles) + checkAnswer(written, TestUtils.expectedDataWithConvertedTimesAndDates) + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands) + } + + test("configuring maxlength on string columns") { + val longStrMetadata = new MetadataBuilder().putLong("maxlength", 512).build() + val shortStrMetadata = new MetadataBuilder().putLong("maxlength", 10).build() + val schema = StructType( + StructField("long_str", StringType, metadata = longStrMetadata) :: + StructField("short_str", StringType, metadata = shortStrMetadata) :: + StructField("default_str", StringType) :: + Nil) + val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema) + val createTableCommand = + DefaultRedshiftWriter.createTableSql(df, MergedParameters.apply(defaultParams)).trim + val expectedCreateTableCommand = + """CREATE TABLE IF NOT EXISTS "PUBLIC"."test_table" ("long_str" VARCHAR(512),""" + + """ "short_str" VARCHAR(10), "default_str" TEXT)""" + assert(createTableCommand === expectedCreateTableCommand) + } + + test("configuring encoding on columns") { + val lzoMetadata = new MetadataBuilder().putString("encoding", "LZO").build() + val runlengthMetadata = new MetadataBuilder().putString("encoding", "RUNLENGTH").build() + val schema = StructType( + StructField("lzo_str", StringType, metadata = lzoMetadata) :: + StructField("runlength_str", StringType, metadata = runlengthMetadata) :: + StructField("default_str", StringType) :: + Nil) + val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema) + val createTableCommand = + DefaultRedshiftWriter.createTableSql(df, MergedParameters.apply(defaultParams)).trim + val expectedCreateTableCommand = + """CREATE TABLE IF NOT EXISTS "PUBLIC"."test_table" ("lzo_str" TEXT ENCODE LZO,""" + + """ "runlength_str" TEXT ENCODE RUNLENGTH, "default_str" TEXT)""" + assert(createTableCommand === expectedCreateTableCommand) + } + + test("configuring descriptions on columns") { + val descriptionMetadata1 = new MetadataBuilder().putString("description", "Test1").build() + val descriptionMetadata2 = new MetadataBuilder().putString("description", "Test'2").build() + val schema = StructType( + StructField("first_str", StringType, metadata = descriptionMetadata1) :: + StructField("second_str", StringType, metadata = descriptionMetadata2) :: + StructField("default_str", StringType) :: + Nil) + val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema) + val commentCommands = + DefaultRedshiftWriter.commentActions(Some("Test"), schema) + val expectedCommentCommands = List( + "COMMENT ON TABLE %s IS 'Test'", + "COMMENT ON COLUMN %s.\"first_str\" IS 'Test1'", + "COMMENT ON COLUMN %s.\"second_str\" IS 'Test''2'") + assert(commentCommands === expectedCommentCommands) + } + + test("configuring redshift_type on columns") { + val bpcharMetadata = new MetadataBuilder().putString("redshift_type", "BPCHAR(2)").build() + val nvarcharMetadata = new MetadataBuilder().putString("redshift_type", "NVARCHAR(123)").build() + + val schema = StructType( + StructField("bpchar_str", StringType, metadata = bpcharMetadata) :: + StructField("bpchar_str", StringType, metadata = nvarcharMetadata) :: + StructField("default_str", StringType) :: + Nil) + + val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema) + val createTableCommand = + DefaultRedshiftWriter.createTableSql(df, MergedParameters.apply(defaultParams)).trim + val expectedCreateTableCommand = + """CREATE TABLE IF NOT EXISTS "PUBLIC"."test_table" ("bpchar_str" BPCHAR(2),""" + + """ "bpchar_str" NVARCHAR(123), "default_str" TEXT)""" + assert(createTableCommand === expectedCreateTableCommand) + } + + test("Respect SaveMode.ErrorIfExists when table exists") { + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null)) + val errIfExistsSource = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client) + intercept[Exception] { + errIfExistsSource.createRelation( + testSqlContext, SaveMode.ErrorIfExists, defaultParams, expectedDataDF) + } + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(Seq.empty) + } + + test("Do nothing when table exists if SaveMode = Ignore") { + val mockRedshift = new MockRedshift( + defaultParams("url"), + Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null)) + val ignoreSource = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client) + ignoreSource.createRelation(testSqlContext, SaveMode.Ignore, defaultParams, expectedDataDF) + mockRedshift.verifyThatConnectionsWereClosed() + mockRedshift.verifyThatExpectedQueriesWereIssued(Seq.empty) + } + + test("Cannot save when 'query' parameter is specified instead of 'dbtable'") { + val invalidParams = Map( + "url" -> "jdbc:redshift://foo/bar?user=user&password=password", + "tempdir" -> s3TempDir, + "query" -> "select * from test_table", + "forward_spark_s3_credentials" -> "true") + + val e1 = intercept[IllegalArgumentException] { + expectedDataDF.write.format("com.databricks.spark.redshift").options(invalidParams).save() + } + assert(e1.getMessage.contains("dbtable")) + } + + test("Public Scala API rejects invalid parameter maps") { + val invalidParams = Map("dbtable" -> "foo") // missing tempdir and url + + val e1 = intercept[IllegalArgumentException] { + expectedDataDF.write.format("com.databricks.spark.redshift").options(invalidParams).save() + } + assert(e1.getMessage.contains("tempdir")) + + val e2 = intercept[IllegalArgumentException] { + expectedDataDF.write.format("com.databricks.spark.redshift").options(invalidParams).save() + } + assert(e2.getMessage.contains("tempdir")) + } + + test("DefaultSource has default constructor, required by Data Source API") { + new DefaultSource() + } + + test("Saves throw error message if S3 Block FileSystem would be used") { + val params = defaultParams + ("tempdir" -> defaultParams("tempdir").replace("s3n", "s3")) + val e = intercept[IllegalArgumentException] { + expectedDataDF.write + .format("com.databricks.spark.redshift") + .mode("append") + .options(params) + .save() + } + assert(e.getMessage.contains("Block FileSystem")) + } + + test("Loads throw error message if S3 Block FileSystem would be used") { + val params = defaultParams + ("tempdir" -> defaultParams("tempdir").replace("s3n", "s3")) + val e = intercept[IllegalArgumentException] { + testSqlContext.read.format("com.databricks.spark.redshift").options(params).load() + } + assert(e.getMessage.contains("Block FileSystem")) + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/SerializableConfigurationSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/SerializableConfigurationSuite.scala new file mode 100755 index 0000000000000..bef296d5e5de0 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/SerializableConfigurationSuite.scala @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableConfigurationSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + val conf = new SerializableConfiguration(new Configuration()) + + val serialized = serializer.serialize(conf) + + serializer.deserialize[Any](serialized) match { + case c: SerializableConfiguration => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + case other => fail( + s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } + +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala new file mode 100755 index 0000000000000..6ffd0a1232624 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import org.apache.spark.SparkFunSuite + +class TableNameSuite extends SparkFunSuite { + test("TableName.parseFromEscaped") { + assert(TableName.parseFromEscaped("foo.bar") === TableName("foo", "bar")) + assert(TableName.parseFromEscaped("foo") === TableName("PUBLIC", "foo")) + assert(TableName.parseFromEscaped("\"foo\"") === TableName("PUBLIC", "foo")) + assert(TableName.parseFromEscaped("\"\"\"foo\"\"\".bar") === TableName("\"foo\"", "bar")) + // Dots (.) can also appear inside of valid identifiers. + assert(TableName.parseFromEscaped("\"foo.bar\".baz") === TableName("foo.bar", "baz")) + assert(TableName.parseFromEscaped("\"foo\"\".bar\".baz") === TableName("foo\".bar", "baz")) + } + + test("TableName.toString") { + assert(TableName("foo", "bar").toString === """"foo"."bar"""") + assert(TableName("PUBLIC", "bar").toString === """"PUBLIC"."bar"""") + assert(TableName("\"foo\"", "bar").toString === "\"\"\"foo\"\"\".\"bar\"") + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/TestUtils.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/TestUtils.scala new file mode 100755 index 0000000000000..53f837b5e0df8 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/TestUtils.scala @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.sql.{Date, Timestamp} +import java.util.{Calendar, Locale} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +/** + * Helpers for Redshift tests that require common mocking + */ +object TestUtils { + + /** + * Simple schema that includes all data types we support + */ + val testSchema: StructType = { + // These column names need to be lowercase; see #51 + StructType(Seq( + StructField("testbyte", ByteType), + StructField("testbool", BooleanType), + StructField("testdate", DateType), + StructField("testdouble", DoubleType), + StructField("testfloat", FloatType), + StructField("testint", IntegerType), + StructField("testlong", LongType), + StructField("testshort", ShortType), + StructField("teststring", StringType), + StructField("testtimestamp", TimestampType))) + } + + // scalastyle:off + /** + * Expected parsed output corresponding to the output of testData. + */ + val expectedData: Seq[Row] = Seq( + Row(1.toByte, true, TestUtils.toDate(2015, 6, 1), 1234152.12312498, + 1.0f, 42, 1239012341823719L, 23.toShort, "Unicode's樂趣", + TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1)), + Row(1.toByte, false, TestUtils.toDate(2015, 6, 2), 0.0, 0.0f, 42, + 1239012341823719L, -13.toShort, "asdf", TestUtils.toTimestamp(2015, 6, 2, 0, 0, 0, 0)), + Row(0.toByte, null, TestUtils.toDate(2015, 6, 3), 0.0, -1.0f, 4141214, + 1239012341823719L, null, "f", TestUtils.toTimestamp(2015, 6, 3, 0, 0, 0)), + Row(0.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L, 24.toShort, + "___|_123", null), + Row(List.fill(10)(null): _*)) + // scalastyle:on + + /** + * The same as `expectedData`, but with dates and timestamps converted into string format. + * See #39 for context. + */ + val expectedDataWithConvertedTimesAndDates: Seq[Row] = expectedData.map { row => + Row.fromSeq(row.toSeq.map { + case t: Timestamp => Conversions.createRedshiftTimestampFormat().format(t) + case d: Date => Conversions.createRedshiftDateFormat().format(d) + case other => other + }) + } + + /** + * Convert date components to a millisecond timestamp + */ + def toMillis( + year: Int, + zeroBasedMonth: Int, + date: Int, + hour: Int, + minutes: Int, + seconds: Int, + millis: Int = 0): Long = { + val calendar = Calendar.getInstance() + calendar.set(year, zeroBasedMonth, date, hour, minutes, seconds) + calendar.set(Calendar.MILLISECOND, millis) + calendar.getTime.getTime + } + + /** + * Convert date components to a SQL Timestamp + */ + def toTimestamp( + year: Int, + zeroBasedMonth: Int, + date: Int, + hour: Int, + minutes: Int, + seconds: Int, + millis: Int = 0): Timestamp = { + new Timestamp(toMillis(year, zeroBasedMonth, date, hour, minutes, seconds, millis)) + } + + /** + * Convert date components to a SQL [[Date]]. + */ + def toDate(year: Int, zeroBasedMonth: Int, date: Int): Date = { + new Date(toTimestamp(year, zeroBasedMonth, date, 0, 0, 0).getTime) + } + + def withDefaultLocale[T](newDefaultLocale: Locale)(block: => T): T = { + val originalDefaultLocale = Locale.getDefault + try { + Locale.setDefault(newDefaultLocale) + block + } finally { + Locale.setDefault(originalDefaultLocale) + } + } +} diff --git a/external/redshift/src/test/scala/com/databricks/spark/redshift/UtilsSuite.scala b/external/redshift/src/test/scala/com/databricks/spark/redshift/UtilsSuite.scala new file mode 100755 index 0000000000000..20a9180fe3405 --- /dev/null +++ b/external/redshift/src/test/scala/com/databricks/spark/redshift/UtilsSuite.scala @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2016 Databricks, Inc. + * + * Portions of this software incorporate or are derived from software contained within Apache Spark, + * and this modified software differs from the Apache Spark software provided under the Apache + * License, Version 2.0, a copy of which you may obtain at + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +package com.databricks.spark.redshift + +import java.net.URI + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for helper functions + */ +class UtilsSuite extends SparkFunSuite with Matchers { + + test("joinUrls preserves protocol information") { + Utils.joinUrls("s3n://foo/bar/", "/baz") shouldBe "s3n://foo/bar/baz/" + Utils.joinUrls("s3n://foo/bar/", "/baz/") shouldBe "s3n://foo/bar/baz/" + Utils.joinUrls("s3n://foo/bar/", "baz/") shouldBe "s3n://foo/bar/baz/" + Utils.joinUrls("s3n://foo/bar/", "baz") shouldBe "s3n://foo/bar/baz/" + Utils.joinUrls("s3n://foo/bar", "baz") shouldBe "s3n://foo/bar/baz/" + } + + test("joinUrls preserves credentials") { + assert( + Utils.joinUrls("s3n://ACCESSKEY:SECRETKEY@bucket/tempdir", "subdir") === + "s3n://ACCESSKEY:SECRETKEY@bucket/tempdir/subdir/") + } + + test("fixUrl produces Redshift-compatible equivalents") { + Utils.fixS3Url("s3a://foo/bar/12345") shouldBe "s3://foo/bar/12345" + Utils.fixS3Url("s3n://foo/bar/baz") shouldBe "s3://foo/bar/baz" + } + + test("addEndpointToUrl produces urls with endpoints added to host") { + Utils.addEndpointToUrl("s3a://foo/bar/12345") shouldBe "s3a://foo.s3.amazonaws.com/bar/12345" + Utils.addEndpointToUrl("s3n://foo/bar/baz") shouldBe "s3n://foo.s3.amazonaws.com/bar/baz" + } + + test("temp paths are random subdirectories of root") { + val root = "s3n://temp/" + val firstTempPath = Utils.makeTempPath(root) + + Utils.makeTempPath(root) should (startWith (root) and endWith ("/") + and not equal root and not equal firstTempPath) + } + + test("removeCredentialsFromURI removes AWS access keys") { + def removeCreds(uri: String): String = { + Utils.removeCredentialsFromURI(URI.create(uri)).toString + } + assert(removeCreds("s3n://bucket/path/to/temp/dir") === "s3n://bucket/path/to/temp/dir") + assert( + removeCreds("s3n://ACCESSKEY:SECRETKEY@bucket/path/to/temp/dir") === + "s3n://bucket/path/to/temp/dir") + } + + test("getRegionForRedshiftCluster") { + val redshiftUrl = + "jdbc:redshift://example.secret.us-west-2.redshift.amazonaws.com:5439/database" + assert(Utils.getRegionForRedshiftCluster("mycluster.example.com") === None) + assert(Utils.getRegionForRedshiftCluster(redshiftUrl) === Some("us-west-2")) + } +} diff --git a/pom.xml b/pom.xml index d269a4fcbeeb7..9e19ae588ef9c 100644 --- a/pom.xml +++ b/pom.xml @@ -108,6 +108,8 @@ repl launcher external/avro + external/redshift + external/redshift-integration-tests external/kafka-0-8 external/kafka-0-8-assembly external/kafka-0-10 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 21e84021105a0..8dfba47145413 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -39,8 +39,9 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, sqlKafka08, avro) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "sql-kafka-0-8", "avro" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, sqlKafka08, avro, redshift, redshiftIntegrationTests) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "sql-kafka-0-8", "avro", + "redshift", "redshift-integration-tests" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq( @@ -353,7 +354,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010, sqlKafka08, avro + unsafe, tags, sqlKafka010, sqlKafka08, avro, redshift, redshiftIntegrationTests ).contains(x) } @@ -717,9 +718,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010, sqlKafka08, avro), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010, sqlKafka08, avro, redshift), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010, sqlKafka08, avro), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010, sqlKafka08, avro, redshift), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value)