MapR [SPARK-128] MapRDB connector - wrong handle of null fields when nullable is false (apache#194)

(cherry picked from commit 2a8a6c1)
ekrivokonmapr authored and Mikhail Gorbov committed Jan 2, 2018
1 parent 26cb465 commit b2dd7e6
Showing 10 changed files with 50 additions and 44 deletions.
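
The heart of the change is in DefaultSource: before MapRDBRelation is built, the schema (whether user-supplied or inferred by GenerateSchema) is now coerced so that every field is nullable, and documents that omit a field no longer clash with a nullable = false declaration. A minimal sketch of that coercion using only standard Spark SQL types (the example schema is invented for illustration; only makeSchemaNullable mirrors code actually added in the diff below):

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical user-supplied schema: `age` is declared non-nullable even though
// some MapR-DB documents may omit that field.
val userSchema = StructType(Seq(
  StructField("_id", StringType, nullable = false),
  StructField("age", IntegerType, nullable = false)))

// Same transformation as the makeSchemaNullable added to DefaultSource below:
// keep each field's name, type and metadata, but force nullable = true.
def makeSchemaNullable(schema: StructType): StructType =
  StructType(schema.map(f => StructField(f.name, f.dataType, nullable = true, f.metadata)))

val coerced = makeSchemaNullable(userSchema)
assert(coerced.forall(_.nullable))  // absent fields read back as null instead of violating the schema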
@@ -8,7 +8,6 @@ import com.mapr.db.spark.configuration.SerializableConfiguration
import com.mapr.db.spark.dbclient.DBClient
import com.mapr.db.spark.utils.{LoggingTrait, MapRDBUtils}
import com.mapr.db.spark.writers._

import org.apache.hadoop.conf.Configuration
import org.ojai.{Document, DocumentConstants, Value}
import org.ojai.store.DocumentMutation
@@ -1,15 +1,16 @@
/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.db.spark.RDD

import scala.language.existentials
import scala.reflect.ClassTag

import com.mapr.db.exceptions.DBException
import com.mapr.db.spark.condition.{DBQueryCondition, Predicate}
import com.mapr.db.spark.dbclient.DBClient
import org.apache.spark.rdd.RDD
import org.ojai.store.QueryCondition

import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.ojai.store.QueryCondition

private[spark] abstract class MapRDBBaseRDD[T: ClassTag](
@transient val sc: SparkContext,
@@ -4,7 +4,6 @@ package com.mapr.db.spark.RDD
import scala.language.existentials
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._

import com.mapr.db.impl.{ConditionImpl, IdCodec}
import com.mapr.db.spark.RDD.partition.MaprDBPartition
import com.mapr.db.spark.RDD.partitioner.MapRDBPartitioner
@@ -14,14 +13,10 @@ import com.mapr.db.spark.dbclient.DBClient
import com.mapr.db.spark.impl.OJAIDocument
import com.mapr.db.spark.utils.DefaultClass.DefaultType
import com.mapr.db.spark.utils.MapRSpark

import org.ojai.{Document, Value}

import org.apache.spark.{Partition, Partitioner, SparkContext, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession


import org.apache.spark.sql.{DataFrame, SparkSession}

private[spark] class MapRDBTableScanRDD[T: ClassTag](
@transient sparkSession: SparkSession,
@@ -36,9 +31,9 @@ private[spark] class MapRDBTableScanRDD[T: ClassTag](

@transient private lazy val table = DBClient().getTable(tableName)
@transient private lazy val tabletinfos =
- if (condition == null || condition.condition.isEmpty)
+ if (condition == null || condition.condition.isEmpty) {
DBClient().getTabletInfos(tableName)
- else DBClient().getTabletInfos(tableName, condition.condition)
+ } else DBClient().getTabletInfos(tableName, condition.condition)
@transient private lazy val getSplits: Seq[Value] = {
val keys = tabletinfos.map(
tableinfo =>
@@ -52,18 +47,18 @@ private[spark] class MapRDBTableScanRDD[T: ClassTag](
}

private def getPartitioner: Partitioner = {
- if (getSplits.isEmpty)
- return null
- if (getSplits(0).getType == Value.Type.STRING) {
- return MapRDBPartitioner(getSplits.map(_.getString))
+ if (getSplits.isEmpty) {
+ null
+ } else if (getSplits(0).getType == Value.Type.STRING) {
+ MapRDBPartitioner(getSplits.map(_.getString))
} else {
- return MapRDBPartitioner(getSplits.map(_.getBinary))
+ MapRDBPartitioner(getSplits.map(_.getBinary))
}
}

- def toDF[T <: Product: TypeTag]() = maprspark[T]()
+ def toDF[T <: Product: TypeTag](): DataFrame = maprspark[T]()

- def maprspark[T <: Product: TypeTag]() = {
+ def maprspark[T <: Product: TypeTag](): DataFrame = {
MapRSpark.builder
.sparkSession(sparkSession)
.configuration()
@@ -89,7 +84,7 @@ private[spark] class MapRDBTableScanRDD[T: ClassTag](
DBQueryCondition(tabcond)).asInstanceOf[Partition]
})
logDebug("Partitions for the table:" + tableName + " are " + splits)
- return splits.toArray
+ splits.toArray
}

override def getPreferredLocations(split: Partition): Seq[String] = {
@@ -125,8 +120,9 @@ private[spark] class MapRDBTableScanRDD[T: ClassTag](
if (columns != null) {
logDebug("Columns projected from table:" + columns)
itrs = table.find(combinedCond.build(), columns.toArray: _*).iterator()
- } else
+ } else {
itrs = table.find(combinedCond.build()).iterator()
+ }
val ojaiCursor = reqType.getValue(itrs, beanClass)

context.addTaskCompletionListener((ctx: TaskContext) => {
@@ -1,8 +1,9 @@
/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.db.spark.RDD.partition

- import org.apache.spark.Partition
import com.mapr.db.spark.condition.DBQueryCondition
+
+ import org.apache.spark.Partition
/**
* An identifier for a partition in a MapRTableScanRDD.
*
@@ -8,12 +8,12 @@ import com.mapr.db.impl.ConditionImpl
import com.mapr.db.spark.dbclient.DBClient
import com.mapr.db.spark.utils.MapRSpark
import org.ojai.DocumentConstants
import org.ojai.store.QueryCondition

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.SaveMode._
import org.ojai.store.QueryCondition

class DefaultSource
extends DataSourceRegister
@@ -156,7 +156,6 @@ class DefaultSource
val failureOnConflict = failOnConflict.toBoolean

val rdd = MapRSpark.builder
- .sparkContext(sqlContext.sparkContext)
.sparkSession(sqlContext.sparkSession)
.configuration()
.setTable(tableName.get)
@@ -165,15 +164,21 @@ class DefaultSource
.build
.toRDD(null)

- val schema: StructType = userSchema match {
+ val schema: StructType = makeSchemaNullable(userSchema match {
case Some(s) => s
case None =>
GenerateSchema(
rdd,
sampleSize.map(_.toDouble).getOrElse(GenerateSchema.SAMPLE_SIZE),
failureOnConflict)
- }
+ })

MapRDBRelation(tableName.get, schema, rdd, Operation)(sqlContext)
}

+ private def makeSchemaNullable(schema: StructType): StructType = {
+ StructType(schema.map(field => {
+ StructField(field.name, field.dataType, nullable = true, field.metadata)
+ }))
+ }
}
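
For context, the user-visible effect of makeSchemaNullable above: any DataFrame produced by the connector now reports every field as nullable, whether the schema was passed in or sampled. A usage sketch per the connector's documented loadFromMapRDB helper; the import, table path, and call signature here are assumptions for illustration, not part of this commit:

import com.mapr.db.spark.sql._  // assumed location of the connector's SparkSession helpers
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("maprdb-nullable-check").getOrCreate()

// Documents in the table need not all carry the same fields; with this fix the
// scan schema is fully nullable, so missing fields surface as null values.
val df = spark.loadFromMapRDB("/apps/user_profiles")  // hypothetical table path
df.printSchema()  // every field is reported with nullable = true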
@@ -1,25 +1,25 @@
/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.db.spark.sql

import com.mapr.db.spark.RDD.MapRDBBaseRDD
import com.mapr.db.spark.impl.OJAIDocument
import com.mapr.db.spark.utils.MapRSpark
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.ojai.DocumentReader
import java.util.Arrays.sort
import java.util.Comparator

import scala.reflect.runtime.universe._
import scala.Array._
import java.util.Arrays.sort
import java.util.Comparator

import com.mapr.db.spark.exceptions.SchemaMappingException
import com.mapr.db.spark.RDD.MapRDBBaseRDD
import com.mapr.db.spark.impl.OJAIDocument
import com.mapr.db.spark.utils.MapRSpark
import org.ojai.DocumentReader

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

import scala.annotation.switch

object GenerateSchema {

@@ -1,13 +1,14 @@
/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.db.spark.streaming

import com.mapr.db.spark._
import com.mapr.db.spark.dbclient.DBClient
import com.mapr.db.spark.utils.LoggingTrait
import org.apache.spark.SparkContext
import org.apache.spark.streaming.dstream.DStream
import org.ojai.DocumentConstants
import com.mapr.db.spark._
import com.mapr.db.spark.writers.OJAIValue
import org.ojai.DocumentConstants

import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.SparkContext

class DStreamFunctions[T](dStream: DStream[T])(implicit fv: OJAIValue[T])
extends Serializable
@@ -3,6 +3,7 @@ package com.mapr.db.spark.types

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.nio.{ByteBuffer, ByteOrder}

import com.mapr.db.impl.IdCodec
import com.mapr.db.spark.utils.MapRDBUtils
import com.mapr.db.util.ByteBufs
@@ -5,9 +5,10 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.nio._
import java.util

import scala.language.implicitConversions
import scala.collection.JavaConverters._
import scala.collection.MapLike
import scala.language.implicitConversions

import com.mapr.db.rowcol.RowcolCodec
import com.mapr.db.spark.dbclient.DBClient
import com.mapr.db.spark.utils.MapRDBUtils
@@ -2,9 +2,10 @@
package com.mapr.db.spark.writers

import java.nio.ByteBuffer
import org.ojai.Document

import com.mapr.db.mapreduce.BulkLoadRecordWriter
import com.mapr.db.rowcol.DBValueBuilderImpl
import org.ojai.Document

private[spark] case class BulkTableWriter(@transient table: BulkLoadRecordWriter) extends Writer {
