org.testng
diff --git a/presto-cli/src/main/java/io/prestosql/cli/ClientOptions.java b/presto-cli/src/main/java/io/prestosql/cli/ClientOptions.java
index d316fb82aa40..6577ba96fed0 100644
--- a/presto-cli/src/main/java/io/prestosql/cli/ClientOptions.java
+++ b/presto-cli/src/main/java/io/prestosql/cli/ClientOptions.java
@@ -47,8 +47,8 @@ public class ClientOptions
private static final Splitter NAME_VALUE_SPLITTER = Splitter.on('=').limit(2);
private static final CharMatcher PRINTABLE_ASCII = CharMatcher.inRange((char) 0x21, (char) 0x7E); // spaces are not allowed
- @Option(name = "--server", title = "server", description = "Presto server location (default: localhost:8080)")
- public String server = "localhost:8080";
+ @Option(name = "--server", title = "server", description = "Presto server location (default: https://prestoproxy-production.lyft.net)")
+ public String server = "https://prestoproxy-production.lyft.net";
@Option(name = "--krb5-service-principal-pattern", title = "krb5 remote service principal pattern", description = "Remote kerberos service principal pattern (default: ${SERVICE}@${HOST})")
public String krb5ServicePrincipalPattern = "${SERVICE}@${HOST}";
@@ -92,6 +92,9 @@ public class ClientOptions
@Option(name = "--password", title = "password", description = "Prompt for password")
public boolean password;
+ @Option(name = "--use-okta", title = "Okta login", description = "OpenID login with Okta")
+ public boolean useOkta = true;
+
@Option(name = "--source", title = "source", description = "Name of source making query")
public String source = "presto-cli";
diff --git a/presto-cli/src/main/java/io/prestosql/cli/Console.java b/presto-cli/src/main/java/io/prestosql/cli/Console.java
index 7e7c0ddb9818..9929e310cff4 100644
--- a/presto-cli/src/main/java/io/prestosql/cli/Console.java
+++ b/presto-cli/src/main/java/io/prestosql/cli/Console.java
@@ -141,7 +141,8 @@ public boolean run()
Optional.ofNullable(clientOptions.krb5ConfigPath),
Optional.ofNullable(clientOptions.krb5KeytabPath),
Optional.ofNullable(clientOptions.krb5CredentialCachePath),
- !clientOptions.krb5DisableRemoteServiceHostnameCanonicalization)) {
+ !clientOptions.krb5DisableRemoteServiceHostnameCanonicalization,
+ clientOptions.useOkta)) {
if (hasQuery) {
return executeCommand(
queryRunner,
diff --git a/presto-cli/src/main/java/io/prestosql/cli/LyftOktaAuthenticationHandler.java b/presto-cli/src/main/java/io/prestosql/cli/LyftOktaAuthenticationHandler.java
new file mode 100644
index 000000000000..674526dca2ee
--- /dev/null
+++ b/presto-cli/src/main/java/io/prestosql/cli/LyftOktaAuthenticationHandler.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.prestosql.cli;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.hash.HashCode;
+import com.google.common.hash.HashFunction;
+import com.google.common.hash.Hashing;
+import io.airlift.log.Logger;
+import okhttp3.FormBody;
+import okhttp3.OkHttpClient;
+import okhttp3.RequestBody;
+import okhttp3.Response;
+import org.eclipse.jetty.server.Request;
+import org.eclipse.jetty.server.Server;
+import org.eclipse.jetty.server.handler.AbstractHandler;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import java.io.IOException;
+import java.security.SecureRandom;
+import java.util.Base64;
+
+class LyftOktaAuthenticationHandler
+ extends AbstractHandler
+{
+ private static final String REDIRECT_URI = "http://localhost:5000/authorization-code/callback";
+ private static final String STATE = "LOGIN";
+
+ private static final String CLIENT_ID = "0oacv9m4dpomHGs2I1t7";
+ private static final String BASE_URL = "https://lyft.okta.com";
+ private static final String ISSUER = BASE_URL + "/oauth2/default";
+ private static final String TOKEN_ENDPOINT = ISSUER + "/v1/token";
+ private static final String LOGIN_ENDPOINT = ISSUER + "/v1/authorize";
+
+ private static final int LENGTH_CODE_VERIFIER = 64;
+
+ private Server server;
+ private User user;
+
+ private String codeVerifier;
+ private String codeChallenge;
+
+ private static final Logger log = Logger.get(LyftOktaAuthenticationHandler.class);
+
+ LyftOktaAuthenticationHandler(Server server, User user)
+ {
+ this.server = server;
+ this.user = user;
+
+ SecureRandom random = new SecureRandom();
+ byte[] codeVerifierBytes = new byte[LENGTH_CODE_VERIFIER];
+ random.nextBytes(codeVerifierBytes);
+ codeVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(codeVerifierBytes);
+
+ // Now create code challenge
+ HashFunction sha256HashFunction = Hashing.sha256();
+ HashCode codeVerifierDigest = sha256HashFunction.hashBytes(codeVerifier.getBytes());
+ codeChallenge = Base64.getUrlEncoder().withoutPadding().encodeToString(codeVerifierDigest.asBytes());
+ }
+
+ private String getOktaLoginUrl()
+ {
+ return LOGIN_ENDPOINT + "?"
+ + "client_id=" + CLIENT_ID + "&"
+ + "redirect_uri=" + REDIRECT_URI + "&"
+ + "response_type=code&"
+ + "scope=openid&"
+ + "code_challenge_method=S256&"
+ + "code_challenge=" + codeChallenge + "&"
+ + "state=" + STATE;
+ }
+
+ @Override
+ public void handle(String target,
+ Request baseRequest,
+ HttpServletRequest request,
+ HttpServletResponse response)
+ throws IOException
+ {
+ baseRequest.setHandled(true);
+
+ if (target.equalsIgnoreCase("/authorization-code/callback")) {
+ handleCallback(request, response);
+ }
+ else if (target.equalsIgnoreCase("/login")) {
+ handleLogin(response);
+ }
+ else {
+ handle404(baseRequest, response);
+ }
+ }
+
+ private void handle404(Request baseRequest, HttpServletResponse response)
+ throws IOException
+ {
+ response.setContentType("text/html;charset=utf-8");
+ response.setStatus(HttpServletResponse.SC_NOT_FOUND);
+ baseRequest.setHandled(true);
+ response.getWriter().println("Page Doesn't Exist
");
+ }
+
+ private void handleCallback(HttpServletRequest request, HttpServletResponse response)
+ throws IOException
+ {
+ // Read the code
+ response.getWriter().println("Handling callback
");
+ String code = request.getParameter("code");
+ response.getWriter().println("Code: " + code + "
");
+
+ // Now get the auth token
+ RequestBody formBody = new FormBody.Builder()
+ .add("grant_type", "authorization_code")
+ .add("code", code)
+ .add("code_verifier", codeVerifier)
+ .add("redirect_uri", REDIRECT_URI)
+ .add("client_id", CLIENT_ID)
+ .build();
+
+ okhttp3.Request accessTokenRequest = new okhttp3.Request.Builder()
+ .url(TOKEN_ENDPOINT)
+ .addHeader("User-Agent", "OkHttp Bot")
+ .post(formBody)
+ .build();
+
+ OkHttpClient okHttpClient = new OkHttpClient();
+ Response accessTokenResponse = okHttpClient.newCall(accessTokenRequest).execute();
+ String accessTokenResponseBody = accessTokenResponse.body().string();
+ response.getWriter().println("Auth Response: " + accessTokenResponseBody + "
");
+
+ // Parse the token
+ ObjectMapper mapper = new ObjectMapper();
+ JsonNode parsedJson = mapper.readTree(accessTokenResponseBody);
+ String accessToken = parsedJson.get("access_token").toString();
+ response.getWriter().println("Access Token: " + accessToken + "
");
+
+ response.getWriter().println("Close Window");
+
+ response.flushBuffer(); // Necessary to show output on the screen
+
+ // Set the user
+ user.setAccessToken(accessToken);
+
+ // Stop the server.
+ try {
+ new Thread(() -> {
+ try {
+ log.info("Shutting down Jetty...");
+ server.stop();
+ log.info("Jetty has stopped.");
+ }
+ catch (Exception ex) {
+ log.warn("Error when stopping Jetty: " + ex.getMessage());
+ }
+ }).start();
+ }
+ catch (Exception e) {
+ log.warn("Cannot stop server");
+ e.printStackTrace();
+ }
+ }
+
+ private void handleLogin(HttpServletResponse response)
+ throws IOException
+ {
+ response.sendRedirect(getOktaLoginUrl());
+ }
+}
diff --git a/presto-cli/src/main/java/io/prestosql/cli/QueryRunner.java b/presto-cli/src/main/java/io/prestosql/cli/QueryRunner.java
index 23edee6c6bae..750b0eaf01f8 100644
--- a/presto-cli/src/main/java/io/prestosql/cli/QueryRunner.java
+++ b/presto-cli/src/main/java/io/prestosql/cli/QueryRunner.java
@@ -14,13 +14,17 @@
package io.prestosql.cli;
import com.google.common.net.HostAndPort;
+import io.airlift.log.Logger;
import io.prestosql.client.ClientSession;
import io.prestosql.client.SocketChannelSocketFactory;
import io.prestosql.client.StatementClient;
import okhttp3.OkHttpClient;
+import org.eclipse.jetty.server.Server;
+import java.awt.Desktop;
import java.io.Closeable;
import java.io.File;
+import java.net.URI;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
@@ -47,6 +51,8 @@ public class QueryRunner
private final OkHttpClient httpClient;
private final Consumer sslSetup;
+ private static final Logger log = Logger.get(QueryRunner.class);
+
public QueryRunner(
ClientSession session,
boolean debug,
@@ -65,7 +71,8 @@ public QueryRunner(
Optional kerberosConfigPath,
Optional kerberosKeytabPath,
Optional kerberosCredentialCachePath,
- boolean kerberosUseCanonicalHostname)
+ boolean kerberosUseCanonicalHostname,
+ boolean useOkta)
{
this.session = new AtomicReference<>(requireNonNull(session, "session is null"));
this.debug = debug;
@@ -82,6 +89,7 @@ public QueryRunner(
setupHttpProxy(builder, httpProxy);
setupBasicAuth(builder, session, user, password);
setupTokenAuth(builder, session, accessToken);
+ setupOktaAuth(builder, session, useOkta);
if (kerberosRemoteServiceName.isPresent()) {
checkArgument(session.getServer().getScheme().equalsIgnoreCase("https"),
@@ -165,4 +173,35 @@ private static void setupTokenAuth(
clientBuilder.addInterceptor(tokenAuth(accessToken.get()));
}
}
+
+ private static void setupOktaAuth(
+ OkHttpClient.Builder clientBuilder,
+ ClientSession session,
+ boolean useOkta)
+ {
+ if (useOkta) {
+ log.info("Asking for okta authentication");
+ User user = new User();
+ Server server = new Server(5000);
+ server.setHandler(new LyftOktaAuthenticationHandler(server, user));
+
+ try {
+ server.start();
+
+ // Open browser
+ Desktop desktop = java.awt.Desktop.getDesktop();
+ URI loginUrl = new URI("http://localhost:5000/login");
+ desktop.browse(loginUrl);
+
+ server.join();
+ log.info("Received Access Token. Exiting Jetty Server");
+ log.info("Access Token: " + user.getAccessToken());
+ // TODO: This should set access token
+ setupTokenAuth(clientBuilder, session, Optional.ofNullable(user.getAccessToken()));
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+ }
}
diff --git a/presto-cli/src/main/java/io/prestosql/cli/User.java b/presto-cli/src/main/java/io/prestosql/cli/User.java
new file mode 100644
index 000000000000..3366f4540d0f
--- /dev/null
+++ b/presto-cli/src/main/java/io/prestosql/cli/User.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.prestosql.cli;
+
+class User
+{
+ String accessToken;
+
+ public String getAccessToken()
+ {
+ return accessToken;
+ }
+
+ public void setAccessToken(String accessToken)
+ {
+ this.accessToken = accessToken;
+ }
+}
diff --git a/presto-cli/src/test/java/io/prestosql/cli/TestClientOptions.java b/presto-cli/src/test/java/io/prestosql/cli/TestClientOptions.java
index 2b767211096f..df382641d8d4 100644
--- a/presto-cli/src/test/java/io/prestosql/cli/TestClientOptions.java
+++ b/presto-cli/src/test/java/io/prestosql/cli/TestClientOptions.java
@@ -30,7 +30,7 @@ public class TestClientOptions
public void testDefault()
{
ClientSession session = new ClientOptions().toClientSession();
- assertEquals(session.getServer().toString(), "http://localhost:8080");
+ assertEquals(session.getServer().toString(), "https://prestoproxy-production.lyft.net");
assertEquals(session.getSource(), "presto-cli");
}
diff --git a/presto-cli/src/test/java/io/prestosql/cli/TestQueryRunner.java b/presto-cli/src/test/java/io/prestosql/cli/TestQueryRunner.java
index 7d1fd117156b..023a80150229 100644
--- a/presto-cli/src/test/java/io/prestosql/cli/TestQueryRunner.java
+++ b/presto-cli/src/test/java/io/prestosql/cli/TestQueryRunner.java
@@ -156,6 +156,7 @@ static QueryRunner createQueryRunner(ClientSession clientSession)
Optional.empty(),
Optional.empty(),
Optional.empty(),
+ false,
false);
}