diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0e38f224ac81d..642a12c1edf6c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
 
 private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
   override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
   def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
     Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
   }
-}
+  trait ConnectionFactory extends Serializable {
+    @throws[Exception]
+    def getConnection: Connection
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+   * For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
+   *   The default maps a ResultSet to an array of Object.
+   */
+  def create[T](
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int,
+      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
+
+    val jdbcRDD = new JdbcRDD[T](
+      sc.sc,
+      () => connectionFactory.getConnection,
+      sql,
+      lowerBound,
+      upperBound,
+      numPartitions,
+      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
+
+    new JavaRDD[T](jdbcRDD)(fakeClassTag)
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+   * converted into an `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   */
+  def create(
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int): JavaRDD[Array[Object]] = {
+
+    val mapRow = new JFunction[ResultSet, Array[Object]] {
+      override def call(resultSet: ResultSet): Array[Object] = {
+        resultSetToObjectArray(resultSet)
+      }
+    }
+
+    create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
+  }
+}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 59c86eecac5e8..67a660f662da5 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -18,13 +18,18 @@
 package org.apache.spark;
 
 import java.io.*;
-import java.nio.channels.FileChannel;
-import java.nio.ByteBuffer;
 import java.net.URI;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
 import java.util.*;
 import java.util.concurrent.*;
 
-import org.apache.spark.input.PortableDataStream;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -51,8 +56,10 @@
 import org.apache.spark.api.java.*;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.input.PortableDataStream;
 import org.apache.spark.partial.BoundedDouble;
 import org.apache.spark.partial.PartialResult;
+import org.apache.spark.rdd.JdbcRDD;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.util.StatCounter;
 
@@ -1508,4 +1515,77 @@ public void testRegisterKryoClasses() {
       conf.get("spark.kryo.classesToRegister"));
   }
 
+  private void setUpJdbc() throws Exception {
+    Class.forName("org.apache.derby.jdbc.EmbeddedDriver");
+    Connection connection =
+      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true");
+
+    try {
+      Statement create = connection.createStatement();
+      create.execute(
+        "CREATE TABLE FOO(" +
+        "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," +
+        "DATA INTEGER)");
+      create.close();
+
+      PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)");
+      for (int i = 1; i <= 100; i++) {
+        insert.setInt(1, i * 2);
+        insert.executeUpdate();
+      }
+    } catch (SQLException e) {
+      // SQL state X0Y32 means the table already exists; reuse it in that case.
+      if (e.getSQLState().compareTo("X0Y32") != 0) {
+        throw e;
+      }
+    } finally {
+      connection.close();
+    }
+  }
+
+  private void tearDownJdbc() throws SQLException {
+    try {
+      DriverManager.getConnection("jdbc:derby:;shutdown=true");
+    } catch (SQLException e) {
+      if (e.getSQLState().compareTo("XJ015") != 0) {
+        throw e;
+      }
+    }
+  }
+
+  @Test
+  public void testJavaJdbcRDD() throws Exception {
+    setUpJdbc();
+
+    try {
+      JavaRDD<Integer> rdd = JdbcRDD.create(
+        sc,
+        new JdbcRDD.ConnectionFactory() {
+          @Override
+          public Connection getConnection() throws SQLException {
+            return DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb");
+          }
+        },
+        "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
<= ID AND ID <= ?", + 1, 100, 3, + new Function() { + @Override + public Integer call(ResultSet r) throws Exception { + return r.getInt(1); + } + } + ).cache(); + + Assert.assertEquals(rdd.count(), 100); + Assert.assertEquals(rdd.reduce(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }), Integer.valueOf(10100)); + } finally { + tearDownJdbc(); + } + } }