Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8944][SQL] Support casting between IntervalType and StringType #7355

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{Interval, UTF8String}


object Cast {
Expand Down Expand Up @@ -55,6 +55,9 @@ object Cast {

case (_, DateType) => true

case (StringType, IntervalType) => true
case (IntervalType, StringType) => true

case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
case (DateType, _: NumericType) => true
Expand Down Expand Up @@ -232,6 +235,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case _ => _ => null
}

// IntervalConverter
/**
 * Returns the converter used when the target type is IntervalType.
 * A string source is parsed with Interval.fromString; every other source type converts to null.
 */
private[this] def castToInterval(from: DataType): Any => Any = from match {
  case StringType =>
    (value: Any) =>
      buildCast[UTF8String](value, (str: UTF8String) => Interval.fromString(str.toString))
  case _ =>
    (_: Any) => null
}

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
Expand Down Expand Up @@ -405,6 +415,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case IntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
Expand Down Expand Up @@ -442,6 +453,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")

case (StringType, IntervalType) =>
defineCodeGen(ctx, ev, c =>
s"org.apache.spark.unsafe.types.Interval.fromString($c.toString)")

// fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) =>
super.genCode(ctx, ev)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,4 +563,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
InternalRow(0L)))
}

// Round-trips between StringType and IntervalType through Cast.
// (Test name corrected from "case" to "cast".)
test("cast between string and interval") {
  import org.apache.spark.unsafe.types.Interval

  // String -> Interval: singular and plural unit names are both accepted,
  // and a negative value is allowed.
  checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType),
    new Interval(-3, 7 * Interval.MICROS_PER_HOUR))

  // Interval -> String: 15 months is rendered as "1 years 3 months".
  checkEvaluation(Cast(Literal.create(
    new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType),
    "interval 1 years 3 months -3 days")
}

}
48 changes: 48 additions & 0 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.spark.unsafe.types;

import java.io.Serializable;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* The internal representation of interval type.
Expand All @@ -30,6 +32,52 @@ public final class Interval implements Serializable {
public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;

/**
* Builds the regex fragment that matches a single unit of an interval string, such as "3 years".
*
* Any unit may be omitted from an interval string, so the whole fragment is optional. Only the
* numeric value is needed afterwards, so the fragment itself is wrapped in a non-capturing group.
* Inside that group, the fragment first matches the whitespace that precedes the unit part.
* Next comes the number, which may start with "-" to denote a negative value; it is wrapped in a
* capturing group because the value is read back later.
* Last comes the unit name, followed by an optional "s".
*/
private static String unitRegex(String unit) {
// For unit "year" the returned fragment is: (?:\s+(-?\d+)\s+years?)?
// The outer (?:...)? makes the whole unit optional without adding a capture group, so the only
// capturing group contributed by this fragment is the signed integer (-?\d+).
return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add more comments explaining the regex, since most people would need to look up documentation on the regex.

}

/**
 * Compiled pattern for a complete interval string: the literal "interval" followed by the
 * optional unit parts in a fixed order (year, month, week, day, hour, minute, second,
 * millisecond, microsecond). Capturing group N holds the number of the N-th unit in that order.
 *
 * Declared final because the pattern is built once and is never reassigned.
 */
private static final Pattern p = Pattern.compile(
  "interval" +
  unitRegex("year") + unitRegex("month") + unitRegex("week") +
  unitRegex("day") + unitRegex("hour") + unitRegex("minute") +
  unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"));

/**
 * Converts the text of a regex group to a long.
 *
 * @param s decimal digits with an optional leading "-", or null when the group was absent
 * @return the parsed value, or 0 when {@code s} is null
 * @throws NumberFormatException if {@code s} is not a valid long (for example, out of range)
 */
private static long toLong(String s) {
  if (s == null) {
    return 0;
  }
  // parseLong yields the primitive directly; valueOf would box a Long only to unbox it again.
  return Long.parseLong(s);
}

/**
 * Parses an interval string such as "interval 3 years -2 months 7 hours".
 *
 * @param s the text to parse; may be null
 * @return the parsed interval, or null when {@code s} is null, does not match the interval
 *         pattern, names no unit at all, or carries a value that cannot be represented
 */
public static Interval fromString(String s) {
  if (s == null) {
    return null;
  }
  Matcher m = p.matcher(s);
  // Every unit in the pattern is optional, so a bare "interval" matches; reject it explicitly.
  if (!m.matches() || s.equals("interval")) {
    return null;
  }
  try {
    long years = toLong(m.group(1));
    long extraMonths = toLong(m.group(2));
    // Interval stores months as an int. Bounding both inputs to int first guarantees that
    // years * 12 + extraMonths cannot overflow a long before the final range check.
    if (years != (int) years || extraMonths != (int) extraMonths) {
      return null;
    }
    long months = years * 12 + extraMonths;
    if (months != (int) months) {
      // Previously the narrowing cast silently truncated the value.
      return null;
    }

    // NOTE(review): the microsecond accumulation below can still wrap around for extreme
    // values; that behavior is unchanged here.
    long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
    microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
    microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
    microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
    microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
    microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
    microseconds += toLong(m.group(9));
    return new Interval((int) months, microseconds);
  } catch (NumberFormatException e) {
    // The regex accepts digit runs of any length; one that does not fit in a long is treated
    // like any other unparsable input instead of escaping as an exception.
    return null;
  }
}

public final int months;
public final long microseconds;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,53 @@ public void toStringTest() {
i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds");
}

/**
 * Exercises Interval.fromString: each unit on its own (via testSingleUnit), one multi-unit
 * string with mixed signs, and a series of inputs for which null is the expected result.
 */
@Test
public void fromStringTest() {
// One call per supported unit; the last two arguments are the expected months and microseconds.
testSingleUnit("year", 3, 36, 0);
testSingleUnit("month", 3, 3, 0);
testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK);
testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY);
testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR);
testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE);
testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND);
testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI);
testSingleUnit("microsecond", 3, 0, 3);

String s;
Interval i;

// Mixed signs across units: -5 years plus 23 months is -37 months in total.
s = "interval -5 years 23 month";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s -> input

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and i -> interval

i = new Interval(-5 * 12 + 23, 0);
assertEquals(Interval.fromString(s), i);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to test all the combinations of units?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably not - but we should definitely test each unit at least once.


// Error cases: every input below is expected to produce null.
i = null;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove i and use null directly.

"i" is a bad name since it is usually used for iterating over an array.


// No whitespace between the number and the unit name.
s = "interval 3month 1 hour";
assertEquals(Interval.fromString(s), i);

// Misspelled unit name.
s = "interval 3 moth 1 hour";
assertEquals(Interval.fromString(s), i);

// The keyword alone, with no unit part.
s = "interval";
assertEquals(Interval.fromString(s), i);

// Not the "interval" keyword at all.
s = "int";
assertEquals(Interval.fromString(s), i);

// Empty string.
s = "";
assertEquals(Interval.fromString(s), i);

// Null input.
s = null;
assertEquals(Interval.fromString(s), i);
}

/**
 * Asserts that "interval <number> <unit>" and its plural form "interval <number> <unit>s"
 * both parse to an interval with the given months and microseconds.
 *
 * @param unit         singular unit name, e.g. "year"
 * @param number       value placed in front of the unit
 * @param months       expected months component of the parsed interval
 * @param microseconds expected microseconds component of the parsed interval
 */
private void testSingleUnit(String unit, int number, int months, long microseconds) {
  Interval expected = new Interval(months, microseconds);
  String singular = "interval " + number + " " + unit;
  String plural = singular + "s";
  // JUnit's contract is assertEquals(expected, actual); the original passed them reversed,
  // which mislabels the two values in a failure message.
  assertEquals(expected, Interval.fromString(singular));
  assertEquals(expected, Interval.fromString(plural));
}
}