Java API for JdbcRDD
liancheng committed Nov 26, 2014
1 parent bf1a6aa commit ffcdf2e
Showing 2 changed files with 165 additions and 5 deletions.
84 changes: 82 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}

import scala.reflect.ClassTag

import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.util.NextIterator
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}

private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}

trait ConnectionFactory extends Serializable {
@throws[Exception]
def getConnection: Connection
}

/**
* Create an RDD that executes an SQL query on a JDBC connection and reads results.
 * For a usage example, see the test case JavaAPISuite.testJavaJdbcRDD.
*
* @param connectionFactory a factory that returns an open Connection.
* The RDD takes care of closing the connection.
* @param sql the text of the query.
* The query must contain two ? placeholders for parameters used to partition the results.
* E.g. "select title, author from books where ? <= id and id <= ?"
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
* @param numPartitions the number of partitions.
* Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
* the query would be executed twice, once with (1, 10) and once with (11, 20)
* @param mapRow a function from a ResultSet to a single row of the desired result type(s).
 * This should only call getInt, getString, etc.; the RDD takes care of calling next.
* The default maps a ResultSet to an array of Object.
*/
def create[T](
sc: JavaSparkContext,
connectionFactory: ConnectionFactory,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {

val jdbcRDD = new JdbcRDD[T](
sc.sc,
() => connectionFactory.getConnection,
sql,
lowerBound,
upperBound,
numPartitions,
(resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)

new JavaRDD[T](jdbcRDD)(fakeClassTag)
}

/**
 * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
 * converted into an `Object` array. For a usage example, see the test case
 * JavaAPISuite.testJavaJdbcRDD.
*
* @param connectionFactory a factory that returns an open Connection.
* The RDD takes care of closing the connection.
* @param sql the text of the query.
* The query must contain two ? placeholders for parameters used to partition the results.
* E.g. "select title, author from books where ? <= id and id <= ?"
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
* @param numPartitions the number of partitions.
* Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
* the query would be executed twice, once with (1, 10) and once with (11, 20)
*/
def create(
sc: JavaSparkContext,
connectionFactory: ConnectionFactory,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int): JavaRDD[Array[Object]] = {

val mapRow = new JFunction[ResultSet, Array[Object]] {
override def call(resultSet: ResultSet): Array[Object] = {
resultSetToObjectArray(resultSet)
}
}

create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
}
}
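
The scaladoc above fixes the partitioning contract: the bounds are inclusive, and the [lowerBound, upperBound] range is split as evenly as possible across numPartitions, with each resulting pair bound to the query's two ? placeholders. A minimal, self-contained sketch of that arithmetic (an illustration of the documented behavior only, not Spark's private getPartitions code; the class name is made up for the example):

public class PartitionBoundsSketch {
  public static void main(String[] args) {
    long lowerBound = 1, upperBound = 20;
    int numPartitions = 2;

    // Bounds are inclusive, hence the + 1 on the length and the - 1 on each end.
    long length = 1 + upperBound - lowerBound;
    for (int i = 0; i < numPartitions; i++) {
      long start = lowerBound + (i * length) / numPartitions;
      long end = lowerBound + ((i + 1) * length) / numPartitions - 1;
      // start and end are bound to the first and second ? placeholder respectively.
      System.out.printf("partition %d: [%d, %d]%n", i, start, end);
    }
    // With (1, 20, 2) this prints [1, 10] and [11, 20], matching the scaladoc example.
  }
}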
86 changes: 83 additions & 3 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -18,13 +18,18 @@
package org.apache.spark;

import java.io.*;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.concurrent.*;

import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -51,8 +56,10 @@
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.*;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.partial.BoundedDouble;
import org.apache.spark.partial.PartialResult;
import org.apache.spark.rdd.JdbcRDD;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.util.StatCounter;

@@ -1508,4 +1515,77 @@ public void testRegisterKryoClasses() {
conf.get("spark.kryo.classesToRegister"));
}


private void setUpJdbc() throws Exception {
Class.forName("org.apache.derby.jdbc.EmbeddedDriver");
Connection connection =
DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true");

try {
Statement create = connection.createStatement();
create.execute(
"CREATE TABLE FOO(" +
"ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," +
"DATA INTEGER)");
create.close();

      PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)");
      for (int i = 1; i <= 100; i++) {
        insert.setInt(1, i * 2);  // parameter index 1 binds the single ? placeholder
        insert.executeUpdate();
      }
      insert.close();
    } catch (SQLException e) {
      // X0Y32 is Derby's "object already exists" state: the table survives from
      // a previous run and the data is already in place. Rethrow anything else.
      if (e.getSQLState().compareTo("X0Y32") != 0) {
        throw e;
      }
} finally {
connection.close();
}
}

  private void tearDownJdbc() throws SQLException {
    try {
      DriverManager.getConnection("jdbc:derby:;shutdown=true");
    } catch (SQLException e) {
      // Derby signals a successful system shutdown with SQLState XJ015;
      // anything else is a real error.
      if (e.getSQLState().compareTo("XJ015") != 0) {
        throw e;
      }
    }
  }

@Test
public void testJavaJdbcRDD() throws Exception {
setUpJdbc();

try {
JavaRDD<Integer> rdd = JdbcRDD.create(
sc,
new JdbcRDD.ConnectionFactory() {
@Override
public Connection getConnection() throws SQLException {
return DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb");
}
},
"SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
1, 100, 3,
new Function<ResultSet, Integer>() {
@Override
public Integer call(ResultSet r) throws Exception {
return r.getInt(1);
}
}
).cache();

      Assert.assertEquals(100, rdd.count());
      Assert.assertEquals(Integer.valueOf(10100), rdd.reduce(new Function2<Integer, Integer, Integer>() {
        @Override
        public Integer call(Integer i1, Integer i2) {
          return i1 + i2;
        }
      }));
} finally {
tearDownJdbc();
}
}
}
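
The test above exercises the mapRow overload. The other create overload omits mapRow and yields each row as an Object[] via resultSetToObjectArray. A hedged usage sketch of that variant (the Derby URL, the BOOKS table, and its column names are hypothetical, not part of this commit):

import java.sql.Connection;
import java.sql.DriverManager;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.JdbcRDD;

public class JdbcRDDObjectArrayExample {
  public static JavaRDD<Object[]> books(JavaSparkContext sc) {
    return JdbcRDD.create(
      sc,
      new JdbcRDD.ConnectionFactory() {
        @Override
        public Connection getConnection() throws Exception {
          // Hypothetical embedded Derby database; any JDBC source works the same way.
          return DriverManager.getConnection("jdbc:derby:target/ExampleDb");
        }
      },
      // Each partition binds its inclusive ID range to the two ? placeholders.
      "SELECT TITLE, AUTHOR FROM BOOKS WHERE ? <= ID AND ID <= ?",
      1, 1000, 10);  // lowerBound, upperBound (both inclusive), numPartitions
  }
}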
