Add support for auto columns and user provided column mapping #148

Open · wants to merge 9 commits into master
README.md (1 change: 1 addition, 0 deletions)
@@ -56,6 +56,7 @@ In addition following options are supported
| isolationLevel | "READ_COMMITTED" | Specify the isolation level |
| tableLock | "false" | Implements an insert with TABLOCK option to improve write performance |
| schemaCheckEnabled | "true" | Disables strict dataframe and sql table schema check when set to false |
| columnsToWrite | "" | Enables user-defined column mapping. Users provide a comma-separated list of column names; only these columns are written. |

Other [Bulk api options](https://docs.microsoft.com/en-us/sql/connect/jdbc/using-bulk-copy-with-the-jdbc-driver?view=sql-server-2017#sqlserverbulkcopyoptions) can be set as options on the dataframe and will be passed to bulkcopy apis on write
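As a rough usage sketch (the connection string, table name, and column names below are placeholders; `columnsToWrite` is the option added in this PR), the option is set like any other write option. Only the listed columns are written; per the PR title, the remaining table columns, such as identity or computed ones, are meant to be populated by SQL Server:

```scala
// Hypothetical example: write only the id, name, and amount columns of df.
df.write
  .format("com.microsoft.sqlserver.jdbc.spark")
  .mode("append")
  .option("url", "jdbc:sqlserver://myserver;databaseName=mydb")
  .option("user", "user")
  .option("password", "****")
  .option("dbtable", "dbo.Orders")
  .option("columnsToWrite", "id, name, amount")
  .save()
```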

SQLServerBulkJdbcOptions.scala
@@ -72,6 +72,10 @@ class SQLServerBulkJdbcOptions(val params: CaseInsensitiveMap[String])
val schemaCheckEnabled =
params.getOrElse("schemaCheckEnabled", "true").toBoolean

// comma-separated list of user-provided column names; only these dataframe columns are written to the table
val columnsToWrite =
params.getOrElse("columnsToWrite", "").toString

// Not a feature
// Only used for internally testing data idempotency
val testDataIdempotency =
BulkCopyUtils.scala
@@ -22,8 +22,6 @@ import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.{createConnectionFactory, getSchema, schemaString}
import com.microsoft.sqlserver.jdbc.{SQLServerBulkCopy, SQLServerBulkCopyOptions}

import scala.collection.mutable.ListBuffer

/**
* BulkCopyUtils object implements common utility functions used by both the master instance and data pools.
*/
@@ -35,7 +33,7 @@ object BulkCopyUtils extends Logging {
* a connection, sets connection properties and does a BulkWrite. Called when writing data to
* master instance and data pools both. URL in options is used to create the relevant connection.
*
* @param itertor - iterator for row of the partition.
* @param iterator - iterator for row of the partition.
* @param dfColMetadata - array of ColumnMetadata type
* @param options - SQLServerBulkJdbcOptions with url for the connection
*/
@@ -179,47 +177,6 @@ object BulkCopyUtils extends Logging {
conn.createStatement.executeQuery(queryStr)
}

/**
* getComputedCols
* utility function to get computed columns.
* Use computed column names to exclude computed column when matching schema.
*/
private[spark] def getComputedCols(
conn: Connection,
table: String): List[String] = {
val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
val computedColRs = conn.createStatement.executeQuery(queryStr)
val computedCols = ListBuffer[String]()
while (computedColRs.next()) {
val colName = computedColRs.getString("name")
computedCols.append(colName)
}
computedCols.toList
}

/**
* dfComputedColCount
* utility function to get number of computed columns in dataframe.
* Use number of computed columns in dataframe to get number of non computed column in df,
* and compare with the number of non computed column in sql table
*/
private[spark] def dfComputedColCount(
dfColNames: List[String],
computedCols: List[String],
dfColCaseMap: Map[String, String],
isCaseSensitive: Boolean): Int ={
var dfComputedColCt = 0
for (j <- 0 to computedCols.length-1){
if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
!isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
&& dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
dfComputedColCt += 1
}
}
dfComputedColCt
}


/**
* getColMetadataMap
* Utility function convert result set meta data to array.
@@ -263,7 +220,7 @@ object BulkCopyUtils extends Logging {
val colMetaData = {
if(checkSchema) {
checkExTableType(conn, options)
matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled)
matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled, options.columnsToWrite)
} else {
defaultColMetadataMap(rs.getMetaData())
}
@@ -289,6 +246,7 @@ object BulkCopyUtils extends Logging {
* @param url: String,
* @param isCaseSensitive: Boolean
* @param strictSchemaCheck: Boolean
* @param columnsToWrite: String
*/
private[spark] def matchSchemas(
conn: Connection,
@@ -297,39 +255,37 @@ object BulkCopyUtils extends Logging {
rs: ResultSet,
url: String,
isCaseSensitive: Boolean,
strictSchemaCheck: Boolean): Array[ColumnMetadata]= {
strictSchemaCheck: Boolean,
columnsToWrite: String): Array[ColumnMetadata] = {
val dfColCaseMap = (df.schema.fieldNames.map(item => item.toLowerCase)
zip df.schema.fieldNames.toList).toMap
val dfCols = df.schema

val tableCols = getSchema(rs, JdbcDialects.get(url))
val computedCols = getComputedCols(conn, dbtable)

val prefix = "Spark Dataframe and SQL Server table have differing"

if (computedCols.length == 0) {
// If columnsToWrite is provided by the user, use it for the metadata mapping; otherwise map all SQL table columns.
var metadataLen = tableCols.length
var columnsToWriteSet: Set[String] = Set()
if (columnsToWrite.isEmpty) {
assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
s"${prefix} numbers of columns")
} else if (strictSchemaCheck) {
val dfColNames = df.schema.fieldNames.toList
val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
// if df has computed column(s), check column length using non computed column in df and table.
// non computed column number in df: dfCols.length - dfComputedColCt
// non computed column number in table: tableCols.length - computedCols.length
assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
s"${prefix} numbers of columns")
} else {
columnsToWriteSet = columnsToWrite.split(",").map(_.trim).toSet
logDebug(s"columnsToWrite: $columnsToWriteSet")
metadataLen = columnsToWriteSet.size
}


val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
var nonAutoColIndex = 0
var colMappingIndex = 0
val result = new Array[ColumnMetadata](metadataLen)

for (i <- 0 to tableCols.length-1) {
val tableColName = tableCols(i).name
var dfFieldIndex = -1
// set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
if (computedCols.contains(tableColName)) {
logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
// If the columnsToWrite option is provided and this SQL column name is not in it, skip its column mapping and ColumnMetadata.
if (!columnsToWrite.isEmpty && !columnsToWriteSet.contains(tableColName)) {
logDebug(s"skipping col index $i col name $tableColName, user not provided in columnsToWrite list")
} else {
var dfColName:String = ""
if (isCaseSensitive) {
@@ -372,15 +328,15 @@ object BulkCopyUtils extends Logging {
s" DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
s" Table col ${tableColName} nullable config is ${tableCols(i).nullable}")

// Schema check passed for element, Create ColMetaData only for non auto generated column
result(nonAutoColIndex) = new ColumnMetadata(
// Schema check passed for this column; create its ColumnMetadata entry
result(colMappingIndex) = new ColumnMetadata(
rs.getMetaData().getColumnName(i+1),
rs.getMetaData().getColumnType(i+1),
rs.getMetaData().getPrecision(i+1),
rs.getMetaData().getScale(i+1),
dfFieldIndex
)
nonAutoColIndex += 1
colMappingIndex += 1
}
}
result
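For reviewers, a self-contained sketch of the new filtering step in matchSchemas (plain strings instead of ResultSet metadata, hypothetical column names; not the connector code itself): the option value is split on commas into a trimmed set, and table columns absent from that set are skipped rather than matched against the dataframe.

```scala
// Simplified illustration of the columnsToWrite filtering; not the actual connector code.
object ColumnsToWriteSketch {
  def main(args: Array[String]): Unit = {
    val columnsToWrite = "id, name, amount"                    // user-supplied option value
    val tableCols = Seq("id", "name", "amount", "rowVersion")  // hypothetical SQL table columns

    // Same parsing as in matchSchemas: split on commas and trim whitespace.
    val columnsToWriteSet: Set[String] =
      if (columnsToWrite.isEmpty) Set.empty[String]
      else columnsToWrite.split(",").map(_.trim).toSet

    // Columns missing from the set get no ColumnMetadata entry and are skipped,
    // mirroring the `!columnsToWrite.isEmpty && !columnsToWriteSet.contains(tableColName)` branch.
    val mappedCols = tableCols.filter { c =>
      columnsToWrite.isEmpty || columnsToWriteSet.contains(c)
    }
    println(mappedCols.mkString(", "))  // prints: id, name, amount
  }
}
```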