Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-3412] [SQL] Add 3 missing types for Row API #2284

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

/**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
Expand Down Expand Up @@ -137,6 +138,15 @@ class JoinedRow extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)

def getDecimal(i: Int): BigDecimal =
if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)

def getTimestamp(i: Int): Timestamp =
if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)

def getBinary(i: Int): Array[Byte] =
if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)

def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
Expand Down Expand Up @@ -226,6 +236,15 @@ class JoinedRow2 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)

def getDecimal(i: Int): BigDecimal =
if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)

def getTimestamp(i: Int): Timestamp =
if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)

def getBinary(i: Int): Array[Byte] =
if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)

def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
Expand Down Expand Up @@ -309,6 +328,15 @@ class JoinedRow3 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)

def getDecimal(i: Int): BigDecimal =
if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)

def getTimestamp(i: Int): Timestamp =
if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)

def getBinary(i: Int): Array[Byte] =
if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)

def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
Expand Down Expand Up @@ -392,6 +420,15 @@ class JoinedRow4 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)

def getDecimal(i: Int): BigDecimal =
if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)

def getTimestamp(i: Int): Timestamp =
if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)

def getBinary(i: Int): Array[Byte] =
if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)

def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
Expand Down Expand Up @@ -475,6 +512,15 @@ class JoinedRow5 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)

def getDecimal(i: Int): BigDecimal =
if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)

def getTimestamp(i: Int): Timestamp =
if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)

def getBinary(i: Int): Array[Byte] =
if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)

def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.types.NativeType

object Row {
Expand Down Expand Up @@ -64,6 +66,9 @@ trait Row extends Seq[Any] with Serializable {
def getShort(i: Int): Short
def getByte(i: Int): Byte
def getString(i: Int): String
def getDecimal(i: Int): BigDecimal
def getTimestamp(i: Int): Timestamp
def getBinary(i: Int): Array[Byte]

override def toString() =
s"[${this.mkString(",")}]"
Expand Down Expand Up @@ -98,6 +103,9 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
def setDecimal(ordinal: Int, value: BigDecimal)
def setTimestamp(ordinal: Int, value: Timestamp)
def setBinary(ordinal: Int, value: Array[Byte])
}

/**
Expand All @@ -118,6 +126,9 @@ object EmptyRow extends Row {
def getShort(i: Int): Short = throw new UnsupportedOperationException
def getByte(i: Int): Byte = throw new UnsupportedOperationException
def getString(i: Int): String = throw new UnsupportedOperationException
def getDecimal(i: Int): BigDecimal = throw new UnsupportedOperationException
def getTimestamp(i: Int): Timestamp = throw new UnsupportedOperationException
def getBinary(i: Int): Array[Byte] = throw new UnsupportedOperationException

def copy() = this
}
Expand Down Expand Up @@ -181,6 +192,21 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
values(i).asInstanceOf[String]
}

def getDecimal(i: Int): BigDecimal = {
if (values(i) == null) sys.error("Failed to check null bit for primitive Decimal value.")
values(i).asInstanceOf[BigDecimal]
}

def getTimestamp(i: Int): Timestamp = {
if (values(i) == null) sys.error("Failed to check null bit for primitive Timestamp value.")
values(i).asInstanceOf[Timestamp]
}

def getBinary(i: Int): Array[Byte] = {
if (values(i) == null) sys.error("Failed to check null bit for primitive Binary value.")
values(i).asInstanceOf[Array[Byte]]
}

// Custom hashCode function that matches the efficient code generated version.
override def hashCode(): Int = {
var result: Int = 37
Expand All @@ -201,6 +227,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
case d: Double =>
val b = java.lang.Double.doubleToLongBits(d)
(b ^ (b >>> 32)).toInt
case b: Array[Byte] => 123 // TODO need to figure out how to compute the hashcode
case other => other.hashCode()
}
}
Expand All @@ -224,6 +251,9 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value }
override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value }
override def setDecimal(ordinal: Int, value: BigDecimal): Unit = { values(ordinal) = value }
override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { values(ordinal) = value }
override def setBinary(ordinal: Int, value: Array[Byte]): Unit = { values(ordinal) = value }

override def setNullAt(i: Int): Unit = { values(i) = null }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.types._

/**
Expand Down Expand Up @@ -231,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR

override def iterator: Iterator[Any] = values.map(_.boxed).iterator

def setString(ordinal: Int, value: String) = update(ordinal, value)
override def setString(ordinal: Int, value: String) = update(ordinal, value)

def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]

override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
Expand Down Expand Up @@ -304,4 +306,16 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def getByte(i: Int): Byte = {
values(i).asInstanceOf[MutableByte].value
}

override def setDecimal(ordinal: Int, value: BigDecimal): Unit = update(ordinal, value)

override def getDecimal(i: Int): BigDecimal = apply(i).asInstanceOf[BigDecimal]

override def setTimestamp(ordinal: Int, value: Timestamp): Unit = update(ordinal, value)

override def getTimestamp(i: Int): Timestamp = apply(i).asInstanceOf[Timestamp]

override def setBinary(ordinal: Int, value: Array[Byte]): Unit = update(ordinal, value)

override def getBinary(i: Int): Array[Byte] = apply(i).asInstanceOf[Array[Byte]]
}