[SPARK-24196][SQL] Implement Spark's own GetSchemasOperation
## What changes were proposed in this pull request?

This PR fixes an issue where SQL client tools cannot show databases, by implementing Spark's own `GetSchemasOperation`.
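
With this change, a plain JDBC metadata call can list Spark databases through the Thrift server. A minimal client sketch (not part of this patch; the connection URL and user are placeholders):

```scala
import java.sql.DriverManager

// Connect to a running Spark Thrift server (placeholder URL/credentials).
val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000", "user", "")
try {
  // DatabaseMetaData.getSchemas is what SQL client tools call to list databases;
  // on the server side that request is handled by a GetSchemasOperation.
  val rs = conn.getMetaData.getSchemas()
  while (rs.next()) {
    // TABLE_SCHEM is the database name, TABLE_CATALOG the catalog name.
    println(rs.getString("TABLE_SCHEM") + "\t" + rs.getString("TABLE_CATALOG"))
  }
} finally {
  conn.close()
}
```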

## How was this patch tested?
Unit tests and manual tests:
![image](https://user-images.githubusercontent.com/5399861/47782885-3dd5d400-dd3c-11e8-8586-59a8c15c7020.png)
![image](https://user-images.githubusercontent.com/5399861/47782899-4928ff80-dd3c-11e8-9d2d-ba9580ba4301.png)

Closes #22903 from wangyum/SPARK-24196.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
wangyum authored and gatorsmile committed Jan 8, 2019
1 parent 6f35ede commit 29a7d2d
Showing 5 changed files with 201 additions and 3 deletions.
@@ -41,7 +41,7 @@ public class GetSchemasOperation extends MetadataOperation {
.addStringColumn("TABLE_SCHEM", "Schema name.")
.addStringColumn("TABLE_CATALOG", "Catalog name.");

-  private RowSet rowSet;
+  protected RowSet rowSet;

protected GetSchemasOperation(HiveSession parentSession,
String catalogName, String schemaName) {
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.thriftserver

import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.GetSchemasOperation
import org.apache.hive.service.cli.operation.MetadataOperation.DEFAULT_HIVE_CATALOG
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.sql.SQLContext

/**
 * Spark's own GetSchemasOperation
 *
 * @param sqlContext SQLContext to use
 * @param parentSession a HiveSession from SessionManager
 * @param catalogName catalog name. null if not applicable.
 * @param schemaName database name, null or a concrete database name
 */
private[hive] class SparkGetSchemasOperation(
    sqlContext: SQLContext,
    parentSession: HiveSession,
    catalogName: String,
    schemaName: String)
  extends GetSchemasOperation(parentSession, catalogName, schemaName) {

  override def runInternal(): Unit = {
    setState(OperationState.RUNNING)
    // Always use the latest class loader provided by executionHive's state.
    val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)

    if (isAuthV2Enabled) {
      val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName"
      authorizeMetaGets(HiveOperationType.GET_TABLES, null, cmdStr)
    }

    try {
      val schemaPattern = convertSchemaPattern(schemaName)
      sqlContext.sessionState.catalog.listDatabases(schemaPattern).foreach { dbName =>
        rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG))
      }
      setState(OperationState.FINISHED)
    } catch {
      case e: HiveSQLException =>
        setState(OperationState.ERROR)
        throw e
    }
  }
}
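
For context, `convertSchemaPattern` (inherited from Hive's `MetadataOperation`) translates the JDBC-style schema pattern into the wildcard form that `SessionCatalog.listDatabases` understands. A minimal sketch of that translation, assuming the usual mapping of SQL `%`/`_` wildcards to the catalog's `*`-style patterns (the helper name `toCatalogPattern` is invented here for illustration only):

```scala
// Illustrative sketch only -- the real conversion lives in Hive's MetadataOperation.
// SQL metadata patterns use '%' (any string) and '_' (any single character);
// Spark's SessionCatalog.listDatabases expects '*'-style wildcard patterns.
def toCatalogPattern(schemaPattern: String): String = {
  if (schemaPattern == null || schemaPattern.isEmpty) {
    "*" // no pattern given: match every database
  } else {
    schemaPattern.replace("%", "*").replace("_", ".")
  }
}

assert(toCatalogPattern("%") == "*")      // every database
assert(toCatalogPattern("db%") == "db*")  // db1, db2, ...
```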
@@ -21,13 +21,13 @@ import java.util.{Map => JMap}
import java.util.concurrent.ConcurrentHashMap

import org.apache.hive.service.cli._
-import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
+import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation}
+import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation}
import org.apache.spark.sql.internal.SQLConf

/**
@@ -63,6 +63,19 @@ private[thriftserver] class SparkSQLOperationManager()
operation
}

  override def newGetSchemasOperation(
      parentSession: HiveSession,
      catalogName: String,
      schemaName: String): GetSchemasOperation = synchronized {
    val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
    require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
      " initialized or had already closed.")
    val operation = new SparkGetSchemasOperation(sqlContext, parentSession, catalogName, schemaName)
    handleToOperation.put(operation.getHandle, operation)
    logDebug(s"Created GetSchemasOperation with session=$parentSession.")
    operation
  }

def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
val iterator = confMap.entrySet().iterator()
while (iterator.hasNext) {
@@ -818,6 +818,22 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
}
}

  def withDatabase(dbNames: String*)(fs: (Statement => Unit)*) {
    val user = System.getProperty("user.name")
    val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
    val statements = connections.map(_.createStatement())

    try {
      statements.zip(fs).foreach { case (s, f) => f(s) }
    } finally {
      dbNames.foreach { name =>
        statements(0).execute(s"DROP DATABASE IF EXISTS $name")
      }
      statements.foreach(_.close())
      connections.foreach(_.close())
    }
  }
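
A usage sketch of the new helper (the database name `db_example` is illustrative): each function in `fs` gets its own connection and `Statement`, and the listed databases are dropped in the `finally` block even if the test body fails.

```scala
// Hypothetical usage of withDatabase in a test body.
withDatabase("db_example") { statement =>
  statement.execute("CREATE DATABASE db_example")
  // ... run assertions that should see db_example ...
}
```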

  def withJdbcStatement(tableNames: String*)(f: Statement => Unit) {
    withMultipleConnectionJdbcStatement(tableNames: _*)(f)
  }
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.thriftserver

import java.util.Properties

import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet, Utils => JdbcUtils}
import org.apache.hive.service.auth.PlainSaslHelper
import org.apache.hive.service.cli.thrift._
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.transport.TSocket

class SparkMetadataOperationSuite extends HiveThriftJdbcTest {

  override def mode: ServerMode.Value = ServerMode.binary

  test("Spark's own GetSchemasOperation(SparkGetSchemasOperation)") {
    def testGetSchemasOperation(
        catalog: String,
        schemaPattern: String)(f: HiveQueryResultSet => Unit): Unit = {
      val rawTransport = new TSocket("localhost", serverPort)
      val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
      val user = System.getProperty("user.name")
      val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
      val client = new TCLIService.Client(new TBinaryProtocol(transport))
      transport.open()
      var rs: HiveQueryResultSet = null
      try {
        val openResp = client.OpenSession(new TOpenSessionReq)
        val sessHandle = openResp.getSessionHandle
        val schemaReq = new TGetSchemasReq(sessHandle)

        if (catalog != null) {
          schemaReq.setCatalogName(catalog)
        }

        if (schemaPattern == null) {
          schemaReq.setSchemaName("%")
        } else {
          schemaReq.setSchemaName(schemaPattern)
        }

        val schemaResp = client.GetSchemas(schemaReq)
        JdbcUtils.verifySuccess(schemaResp.getStatus)

        rs = new HiveQueryResultSet.Builder(connection)
          .setClient(client)
          .setSessionHandle(sessHandle)
          .setStmtHandle(schemaResp.getOperationHandle)
          .build()
        f(rs)
      } finally {
        rs.close()
        connection.close()
        transport.close()
        rawTransport.close()
      }
    }

    def checkResult(dbNames: Seq[String], rs: HiveQueryResultSet): Unit = {
      if (dbNames.nonEmpty) {
        for (i <- dbNames.indices) {
          assert(rs.next())
          assert(rs.getString("TABLE_SCHEM") === dbNames(i))
        }
      } else {
        assert(!rs.next())
      }
    }

    withDatabase("db1", "db2") { statement =>
      Seq("CREATE DATABASE db1", "CREATE DATABASE db2").foreach(statement.execute)

      testGetSchemasOperation(null, "%") { rs =>
        checkResult(Seq("db1", "db2"), rs)
      }
      testGetSchemasOperation(null, "db1") { rs =>
        checkResult(Seq("db1"), rs)
      }
      testGetSchemasOperation(null, "db_not_exist") { rs =>
        checkResult(Seq.empty, rs)
      }
      testGetSchemasOperation(null, "db*") { rs =>
        checkResult(Seq("db1", "db2"), rs)
      }
    }
  }
}
