forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' of github.com:apache/spark into viz2
- Loading branch information
Showing
15 changed files
with
1,174 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import itertools | ||
|
||
__all__ = ['ParamGridBuilder'] | ||
|
||
|
||
class ParamGridBuilder(object):
    """
    Builder for a param grid used in grid search-based model selection.

    Params are accumulated via :meth:`addGrid` / :meth:`baseOn` and the
    full Cartesian product of all candidate values is produced by
    :meth:`build`.

    >>> from classification import LogisticRegression
    >>> lr = LogisticRegression()
    >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \
            .baseOn([lr.predictionCol, 'p']) \
            .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \
            .addGrid(lr.maxIter, [1, 5]) \
            .addGrid(lr.featuresCol, ['f']) \
            .build()
    >>> expected = [ \
    {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
    {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
    {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
    {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
    {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
    {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
    >>> len(output) == len(expected)
    True
    >>> all([m in expected for m in output])
    True
    """

    def __init__(self):
        # Maps each param to the list of candidate values to sweep over.
        self._param_grid = {}

    def addGrid(self, param, values):
        """
        Adds the given values as candidates for ``param``, replacing any
        previously registered candidates for the same param.

        :param param: the parameter (used as a dict key) to vary.
        :param values: an iterable of candidate values for ``param``.
        :return: this builder, for call chaining.
        """
        # Copy into a list so later mutation of the caller's iterable (or
        # exhausting a one-shot iterator) cannot corrupt the grid.
        self._param_grid[param] = list(values)
        return self

    def baseOn(self, *args):
        """
        Sets the given parameters in this grid to fixed values.
        Accepts either a parameter dictionary or a list of (parameter, value) pairs.

        :return: this builder, for call chaining.
        """
        # Guard against no arguments so an empty call is a harmless no-op
        # instead of raising IndexError on args[0].
        if args and isinstance(args[0], dict):
            # Flatten the dict into (param, value) pairs and recurse.
            self.baseOn(*args[0].items())
        else:
            for (param, value) in args:
                # A fixed value is just a single-candidate grid entry.
                self.addGrid(param, [value])
        return self

    def build(self):
        """
        Builds and returns all combinations of parameters specified
        by the param grid, one dict per combination.
        """
        # keys() and values() iterate in the same order, so zip pairs each
        # param with values drawn from its own candidate list.
        keys = self._param_grid.keys()
        grid_values = self._param_grid.values()
        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
|
||
|
||
if __name__ == "__main__": | ||
import doctest | ||
doctest.testmod() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
149 changes: 149 additions & 0 deletions
149
sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.hive.client | ||
|
||
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} | ||
|
||
/** Metadata describing a Hive database: its name and filesystem location. */
case class HiveDatabase(
    name: String,
    location: String)
|
||
/**
 * The kind of a Hive table; `name` carries the string form passed to Hive.
 * NOTE(review): looks like a closed set — consider making this `sealed` for
 * exhaustive matching, after confirming no external subclasses exist.
 */
abstract class TableType { val name: String }
case object ExternalTable extends TableType { override val name = "EXTERNAL_TABLE" }
case object IndexTable extends TableType { override val name = "INDEX_TABLE" }
case object ManagedTable extends TableType { override val name = "MANAGED_TABLE" }
case object VirtualView extends TableType { override val name = "VIRTUAL_VIEW" }
|
||
/**
 * Physical storage information for a table or partition: its data location
 * and the input/output format and serde class names used to read/write it.
 */
case class HiveStorageDescriptor(
    location: String,
    inputFormat: String,
    outputFormat: String,
    serde: String)
|
||
/** A single table partition: its partition-column values plus storage info. */
case class HivePartition(
    values: Seq[String],
    storage: HiveStorageDescriptor)
|
||
case class HiveColumn(name: String, hiveType: String, comment: String) | ||
/**
 * Metadata for a Hive table. `specifiedDatabase` is None until the table is
 * resolved against a database; the storage fields default to None when not
 * explicitly set.
 */
case class HiveTable(
    specifiedDatabase: Option[String],
    name: String,
    schema: Seq[HiveColumn],
    partitionColumns: Seq[HiveColumn],
    properties: Map[String, String],
    serdeProperties: Map[String, String],
    tableType: TableType,
    location: Option[String] = None,
    inputFormat: Option[String] = None,
    outputFormat: Option[String] = None,
    serde: Option[String] = None) {

  // Back-reference to the client that produced this table. It is a var so
  // it can be attached after construction (see withClient), and @transient
  // so it is not serialized along with the metadata.
  @transient
  private[client] var client: ClientInterface = _

  /** Attaches the owning client and returns this table, builder-style. */
  private[client] def withClient(ci: ClientInterface): this.type = {
    client = ci
    this
  }

  /** The resolved database name; fails if no database has been specified. */
  def database: String = specifiedDatabase.getOrElse(sys.error("database not resolved"))

  /** True when the table declares at least one partition column. */
  def isPartitioned: Boolean = partitionColumns.nonEmpty

  /** Fetches all partitions via the attached client; requires withClient first. */
  def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this)

  // Hive does not support backticks when passing names to the client.
  def qualifiedName: String = s"$database.$name"
}
|
||
/**
 * An externally visible interface to the Hive client. This interface is shared across both the
 * internal and external classloaders for a given version of Hive and thus must expose only
 * shared classes.
 */
trait ClientInterface {

  /**
   * Executes a HiveQL statement through Hive and returns its output,
   * one string per result row.
   */
  def runSqlHive(sql: String): Seq[String]

  /** Names of every table in the database `dbName`. */
  def listTables(dbName: String): Seq[String]

  /** Name of the database that is currently active. */
  def currentDatabase: String

  /** Metadata for the database `name`; throws [[NoSuchDatabaseException]] when absent. */
  def getDatabase(name: String): HiveDatabase = getDatabaseOption(name) match {
    case Some(db) => db
    case None => throw new NoSuchDatabaseException
  }

  /** Metadata for the database `name`, or None when it does not exist. */
  def getDatabaseOption(name: String): Option[HiveDatabase]

  /** The table `dbName`.`tableName`; throws [[NoSuchTableException]] when absent. */
  def getTable(dbName: String, tableName: String): HiveTable =
    getTableOption(dbName, tableName) match {
      case Some(table) => table
      case None => throw new NoSuchTableException
    }

  /** Metadata for the specified table, or None when it doesn't exist. */
  def getTableOption(dbName: String, tableName: String): Option[HiveTable]

  /** Creates a table with the given metadata. */
  def createTable(table: HiveTable): Unit

  /** Updates the given table with new metadata. */
  def alterTable(table: HiveTable): Unit

  /** Creates a new database with the given name. */
  def createDatabase(database: HiveDatabase): Unit

  /** All partitions of the given table. */
  def getAllPartitions(hTable: HiveTable): Seq[HivePartition]

  /** Loads a static partition into an existing table. */
  def loadPartition(
      loadPath: String,
      tableName: String,
      partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering
      replace: Boolean,
      holdDDLTime: Boolean,
      inheritTableSpecs: Boolean,
      isSkewedStoreAsSubdir: Boolean): Unit

  /** Loads data into an existing table. */
  def loadTable(
      loadPath: String, // TODO URI
      tableName: String,
      replace: Boolean,
      holdDDLTime: Boolean): Unit

  /** Loads new dynamic partitions into an existing table. */
  def loadDynamicPartitions(
      loadPath: String,
      tableName: String,
      partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering
      replace: Boolean,
      numDP: Int,
      holdDDLTime: Boolean,
      listBucketingEnabled: Boolean): Unit

  /** Used for testing only. Removes all metadata from this instance of Hive. */
  def reset(): Unit
}
Oops, something went wrong.