diff --git a/README.md b/README.md
index f4f8d83a..cde17ea9 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,7 @@ In addition following options are supported
 | isolationLevel | "READ_COMMITTED" | Specify the isolation level |
 | tableLock | "false" | Implements an insert with TABLOCK option to improve write performance |
 | schemaCheckEnabled | "true" | Disables strict dataframe and sql table schema check when set to false |
+| columnsToWrite | "" | Enables user-defined column mapping. Users provide a comma-separated list of column names, and only those columns are written. |
 
 Other [Bulk api options](https://docs.microsoft.com/en-us/sql/connect/jdbc/using-bulk-copy-with-the-jdbc-driver?view=sql-server-2017#sqlserverbulkcopyoptions) can be set as options on the dataframe and will be passed to bulkcopy apis on write
 
diff --git a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala
index 710f570f..4a2d095c 100644
--- a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala
+++ b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala
@@ -72,6 +72,10 @@ class SQLServerBulkJdbcOptions(val params: CaseInsensitiveMap[String])
   val schemaCheckEnabled =
     params.getOrElse("schemaCheckEnabled", "true").toBoolean
 
+  // User-provided, comma-separated list of column names; only these columns are written and mapped to dataframe columns
+  val columnsToWrite =
+    params.getOrElse("columnsToWrite", "").toString
+
   // Not a feature
   // Only used for internally testing data idempotency
   val testDataIdempotency =
diff --git a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
index 8b06e8e6..342b2ae1 100644
--- a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
+++ b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
@@ -22,8 +22,6 @@ import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.{createConnectionFactory, getSchema, schemaString}
 import com.microsoft.sqlserver.jdbc.{SQLServerBulkCopy, SQLServerBulkCopyOptions}
 
-import scala.collection.mutable.ListBuffer
-
 /**
  * BulkCopyUtils Object implements common utility function used by both datapool and
  */
@@ -35,7 +33,7 @@ object BulkCopyUtils extends Logging {
      * a connection, sets connection properties and does a BulkWrite. Called when writing data to
      * master instance and data pools both. URL in options is used to create the relevant connection.
      *
-     * @param itertor - iterator for row of the partition.
+     * @param iterator - iterator for row of the partition.
      * @param dfColMetadata - array of ColumnMetadata type
      * @param options - SQLServerBulkJdbcOptions with url for the connection
      */
@@ -179,47 +177,6 @@ object BulkCopyUtils extends Logging {
         conn.createStatement.executeQuery(queryStr)
     }
 
-    /**
-     * getComputedCols
-     * utility function to get computed columns.
-     * Use computed column names to exclude computed column when matching schema.
-     */
-    private[spark] def getComputedCols(
-        conn: Connection,
-        table: String): List[String] = {
-        val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
-        val computedColRs = conn.createStatement.executeQuery(queryStr)
-        val computedCols = ListBuffer[String]()
-        while (computedColRs.next()) {
-            val colName = computedColRs.getString("name")
-            computedCols.append(colName)
-        }
-        computedCols.toList
-    }
-
-    /**
-     * dfComputedColCount
-     * utility function to get number of computed columns in dataframe.
-     * Use number of computed columns in dataframe to get number of non computed column in df,
-     * and compare with the number of non computed column in sql table
-     */
-    private[spark] def dfComputedColCount(
-        dfColNames: List[String],
-        computedCols: List[String],
-        dfColCaseMap: Map[String, String],
-        isCaseSensitive: Boolean): Int ={
-        var dfComputedColCt = 0
-        for (j <- 0 to computedCols.length-1){
-            if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
-                !isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
-                && dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
-                dfComputedColCt += 1
-            }
-        }
-        dfComputedColCt
-    }
-
-
     /**
      * getColMetadataMap
      * Utility function convert result set meta data to array.
@@ -263,7 +220,7 @@ object BulkCopyUtils extends Logging {
         val colMetaData = {
             if(checkSchema) {
                 checkExTableType(conn, options)
-                matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled)
+                matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled, options.columnsToWrite)
             } else {
                 defaultColMetadataMap(rs.getMetaData())
             }
@@ -289,6 +246,7 @@ object BulkCopyUtils extends Logging {
      * @param url: String,
      * @param isCaseSensitive: Boolean
      * @param strictSchemaCheck: Boolean
+     * @param columnsToWrite: String
      */
     private[spark] def matchSchemas(
             conn: Connection,
@@ -297,39 +255,37 @@ object BulkCopyUtils extends Logging {
             rs: ResultSet,
             url: String,
             isCaseSensitive: Boolean,
-            strictSchemaCheck: Boolean): Array[ColumnMetadata]= {
+            strictSchemaCheck: Boolean,
+            columnsToWrite: String): Array[ColumnMetadata]= {
         val dfColCaseMap = (df.schema.fieldNames.map(item => item.toLowerCase)
             zip df.schema.fieldNames.toList).toMap
         val dfCols = df.schema
 
         val tableCols = getSchema(rs, JdbcDialects.get(url))
-        val computedCols = getComputedCols(conn, dbtable)
         val prefix = "Spark Dataframe and SQL Server table have differing"
 
-        if (computedCols.length == 0) {
+        // If columnsToWrite is provided by the user, use it for the metadata mapping; otherwise use the SQL table columns.
+        var metadataLen = tableCols.length
+        var columnsToWriteSet: Set[String] = Set()
+        if (columnsToWrite.isEmpty) {
             assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
                 s"${prefix} numbers of columns")
-        } else if (strictSchemaCheck) {
-            val dfColNames = df.schema.fieldNames.toList
-            val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
-            // if df has computed column(s), check column length using non computed column in df and table.
-            // non computed column number in df: dfCols.length - dfComputedColCt
-            // non computed column number in table: tableCols.length - computedCols.length
-            assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
-                s"${prefix} numbers of columns")
+        } else {
+            columnsToWriteSet = columnsToWrite.split(",").map(_.trim).toSet
+            logDebug(s"columnsToWrite: $columnsToWriteSet")
+            metadataLen = columnsToWriteSet.size
         }
-
-        val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
-        var nonAutoColIndex = 0
+        var colMappingIndex = 0
+        val result = new Array[ColumnMetadata](metadataLen)
 
         for (i <- 0 to tableCols.length-1) {
             val tableColName = tableCols(i).name
             var dfFieldIndex = -1
-            // set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
-            if (computedCols.contains(tableColName)) {
-                logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
+            // If the columnsToWrite option is provided and the SQL column name is not in it, skip this column mapping and ColumnMetadata.
+            if (!columnsToWrite.isEmpty && !columnsToWriteSet.contains(tableColName)) {
+                logDebug(s"skipping col index $i col name $tableColName, not in the user-provided columnsToWrite list")
             }else{
                 var dfColName:String = ""
                 if (isCaseSensitive) {
@@ -372,15 +328,15 @@ object BulkCopyUtils extends Logging {
                     s" DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
                     s" Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
 
-                // Schema check passed for element, Create ColMetaData only for non auto generated column
-                result(nonAutoColIndex) = new ColumnMetadata(
+                // Schema check passed for element; create ColumnMetadata for the column
+                result(colMappingIndex) = new ColumnMetadata(
                     rs.getMetaData().getColumnName(i+1),
                     rs.getMetaData().getColumnType(i+1),
                     rs.getMetaData().getPrecision(i+1),
                     rs.getMetaData().getScale(i+1),
                     dfFieldIndex
                 )
-                nonAutoColIndex += 1
+                colMappingIndex += 1
             }
         }
         result
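
Below is a minimal usage sketch of the new columnsToWrite option from the Spark side. The server URL, credentials, table name, and column names are placeholders rather than part of this change, and the sketch assumes the connector (format "com.microsoft.sqlserver.jdbc.spark") is on the classpath. Only the columns listed in columnsToWrite are mapped and bulk-copied; any other table columns are skipped.

// Minimal usage sketch for the columnsToWrite option.
// Table, URL, and credential values below are placeholders.
import org.apache.spark.sql.SparkSession

object ColumnsToWriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("columnsToWriteExample").getOrCreate()
    import spark.implicits._

    // Dataframe containing only the columns we intend to write.
    val df = Seq((1, "Alice"), (2, "Bob")).toDF("id", "name")

    df.write
      .format("com.microsoft.sqlserver.jdbc.spark")
      .mode("append")
      .option("url", "jdbc:sqlserver://<server>;databaseName=<db>")
      .option("dbtable", "dbo.Users")
      .option("user", "<user>")
      .option("password", "<password>")
      // Comma-separated list of target table columns; table columns not in this
      // list (for example computed columns) are skipped during the bulk copy.
      .option("columnsToWrite", "id, name")
      .save()
  }
}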