Skip to content

Commit

Permalink
add support for temporal table / graph table
Browse files Browse the repository at this point in the history
  • Loading branch information
luxu1-ms committed Oct 13, 2021
1 parent 1d57c39 commit c1e3f41
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,10 @@ class SQLServerBulkJdbcOptions(val params: CaseInsensitiveMap[String])
val allowEncryptedValueModifications =
params.getOrElse("allowEncryptedValueModifications", "false").toBoolean


val schemaCheckEnabled =
params.getOrElse("schemaCheckEnabled", "true").toBoolean

val hideGraphColumns =
params.getOrElse("hideGraphColumns", "true").toBoolean

// Not a feature
// Only used for internally testing data idempotency
val testDataIdempotency =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,60 +180,47 @@ object BulkCopyUtils extends Logging {
}

/**
* getComputedCols
* utility function to get computed columns.
* Use computed column names to exclude computed column when matching schema.
* getAutoCols
* utility function to get auto generated columns.
* Use auto generated column names to exclude them when matching schema.
*/
private[spark] def getComputedCols(
private[spark] def getAutoCols(
conn: Connection,
table: String,
hideGraphColumns: Boolean): List[String] = {
// TODO can optimize this, also evaluate SQLi issues
val queryStr = if (hideGraphColumns) s"""IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14)
exec sp_executesql N'SELECT name
FROM sys.computed_columns
table: String): List[String] = {
// auto cols are the union of computed cols, generated always cols, and node / edge table auto cols
val queryStr = s"""SELECT name
FROM sys.columns
WHERE object_id = OBJECT_ID(''${table}'')
UNION ALL
SELECT C.name
FROM sys.tables AS T
JOIN sys.columns AS C
ON T.object_id = C.object_id
WHERE T.object_id = OBJECT_ID(''${table}'')
AND (T.is_edge = 1 OR T.is_node = 1)
AND C.is_hidden = 0
AND C.graph_type = 2'
ELSE
SELECT name
FROM sys.computed_columns
WHERE object_id = OBJECT_ID('${table}')
AND (is_computed = 1 -- computed column
OR generated_always_type > 0 -- generated always / temporal table
OR (is_hidden = 0 AND graph_type = 2)) -- graph table
"""
else s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"

val computedColRs = conn.createStatement.executeQuery(queryStr)
val computedCols = ListBuffer[String]()
while (computedColRs.next()) {
val colName = computedColRs.getString("name")
computedCols.append(colName)
val autoColRs = conn.createStatement.executeQuery(queryStr)
val autoCols = ListBuffer[String]()
while (autoColRs.next()) {
val colName = autoColRs.getString("name")
autoCols.append(colName)
}
computedCols.toList
autoCols.toList
}

/**
* dfComputedColCount
* dfAutoColCount
* utility function to count the auto generated columns that are present in the dataframe.
* Subtracting that count from the dataframe column count gives the number of non auto generated columns in the df,
* which is compared with the number of non auto generated columns in the sql table
*/
private[spark] def dfComputedColCount(
private[spark] def dfAutoColCount(
dfColNames: List[String],
computedCols: List[String],
autoCols: List[String],
dfColCaseMap: Map[String, String],
isCaseSensitive: Boolean): Int ={
var dfComputedColCt = 0
for (j <- 0 to computedCols.length-1){
if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
!isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
&& dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
for (j <- 0 to autoCols.length-1){
if (isCaseSensitive && dfColNames.contains(autoCols(j)) ||
!isCaseSensitive && dfColCaseMap.contains(autoCols(j).toLowerCase())
&& dfColCaseMap(autoCols(j).toLowerCase()) == autoCols(j)) {
dfComputedColCt += 1
}
}
Expand Down Expand Up @@ -284,7 +271,7 @@ SELECT name
val colMetaData = {
if(checkSchema) {
checkExTableType(conn, options)
matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled, options.hideGraphColumns)
matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled)
} else {
defaultColMetadataMap(rs.getMetaData())
}
Expand All @@ -310,7 +297,6 @@ SELECT name
* @param url: String,
* @param isCaseSensitive: Boolean
* @param strictSchemaCheck: Boolean
* @param hideGraphColumns - Whether to hide the $node_id, $from_id, $to_id, $edge_id columns in SQL graph tables
*/
private[spark] def matchSchemas(
conn: Connection,
Expand All @@ -319,40 +305,39 @@ SELECT name
rs: ResultSet,
url: String,
isCaseSensitive: Boolean,
strictSchemaCheck: Boolean,
hideGraphColumns: Boolean): Array[ColumnMetadata]= {
strictSchemaCheck: Boolean): Array[ColumnMetadata]= {
val dfColCaseMap = (df.schema.fieldNames.map(item => item.toLowerCase)
zip df.schema.fieldNames.toList).toMap
val dfCols = df.schema

val tableCols = getSchema(rs, JdbcDialects.get(url))
val computedCols = getComputedCols(conn, dbtable, hideGraphColumns)
val autoCols = getAutoCols(conn, dbtable)

val prefix = "Spark Dataframe and SQL Server table have differing"

if (computedCols.length == 0) {
if (autoCols.length == 0) {
assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
s"${prefix} numbers of columns")
} else if (strictSchemaCheck) {
val dfColNames = df.schema.fieldNames.toList
val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
val dfComputedColCt = dfAutoColCount(dfColNames, autoCols, dfColCaseMap, isCaseSensitive)
// if the table has auto generated column(s), compare column counts using only the non auto generated columns in df and table.
// non auto generated column count in df: dfCols.length - dfComputedColCt
// non computed column number in table: tableCols.length - computedCols.length
assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
// non computed column number in table: tableCols.length - autoCols.length
assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-autoCols.length, strictSchemaCheck,
s"${prefix} numbers of columns")
}


val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
val result = new Array[ColumnMetadata](tableCols.length - autoCols.length)
var nonAutoColIndex = 0

for (i <- 0 to tableCols.length-1) {
val tableColName = tableCols(i).name
var dfFieldIndex = -1
// leave dfFieldIndex = -1 for all auto generated columns to skip ColumnMetadata
if (computedCols.contains(tableColName)) {
logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
if (autoCols.contains(tableColName)) {
logDebug(s"skipping auto generated col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
}else{
var dfColName:String = ""
if (isCaseSensitive) {
Expand Down

0 comments on commit c1e3f41

Please sign in to comment.