diff --git a/presto-cli/pom.xml b/presto-cli/pom.xml index b74e62f2ba60..593d0822963b 100644 --- a/presto-cli/pom.xml +++ b/presto-cli/pom.xml @@ -63,6 +63,11 @@ javax.inject + + javax.servlet + javax.servlet-api + + com.google.guava guava @@ -93,6 +98,17 @@ jackson-core + + com.fasterxml.jackson.core + jackson-databind + + + + org.eclipse.jetty + jetty-server + 9.4.14.v20181114 + + org.testng diff --git a/presto-cli/src/main/java/io/prestosql/cli/ClientOptions.java b/presto-cli/src/main/java/io/prestosql/cli/ClientOptions.java index d316fb82aa40..6577ba96fed0 100644 --- a/presto-cli/src/main/java/io/prestosql/cli/ClientOptions.java +++ b/presto-cli/src/main/java/io/prestosql/cli/ClientOptions.java @@ -47,8 +47,8 @@ public class ClientOptions private static final Splitter NAME_VALUE_SPLITTER = Splitter.on('=').limit(2); private static final CharMatcher PRINTABLE_ASCII = CharMatcher.inRange((char) 0x21, (char) 0x7E); // spaces are not allowed - @Option(name = "--server", title = "server", description = "Presto server location (default: localhost:8080)") - public String server = "localhost:8080"; + @Option(name = "--server", title = "server", description = "Presto server location (default: https://prestoproxy-production.lyft.net)") + public String server = "https://prestoproxy-production.lyft.net"; @Option(name = "--krb5-service-principal-pattern", title = "krb5 remote service principal pattern", description = "Remote kerberos service principal pattern (default: ${SERVICE}@${HOST})") public String krb5ServicePrincipalPattern = "${SERVICE}@${HOST}"; @@ -92,6 +92,9 @@ public class ClientOptions @Option(name = "--password", title = "password", description = "Prompt for password") public boolean password; + @Option(name = "--use-okta", title = "Okta login", description = "OpenID login with Okta") + public boolean useOkta = true; + @Option(name = "--source", title = "source", description = "Name of source making query") public String source = "presto-cli"; diff --git a/presto-cli/src/main/java/io/prestosql/cli/Console.java b/presto-cli/src/main/java/io/prestosql/cli/Console.java index 7e7c0ddb9818..9929e310cff4 100644 --- a/presto-cli/src/main/java/io/prestosql/cli/Console.java +++ b/presto-cli/src/main/java/io/prestosql/cli/Console.java @@ -141,7 +141,8 @@ public boolean run() Optional.ofNullable(clientOptions.krb5ConfigPath), Optional.ofNullable(clientOptions.krb5KeytabPath), Optional.ofNullable(clientOptions.krb5CredentialCachePath), - !clientOptions.krb5DisableRemoteServiceHostnameCanonicalization)) { + !clientOptions.krb5DisableRemoteServiceHostnameCanonicalization, + clientOptions.useOkta)) { if (hasQuery) { return executeCommand( queryRunner, diff --git a/presto-cli/src/main/java/io/prestosql/cli/LyftOktaAuthenticationHandler.java b/presto-cli/src/main/java/io/prestosql/cli/LyftOktaAuthenticationHandler.java new file mode 100644 index 000000000000..674526dca2ee --- /dev/null +++ b/presto-cli/src/main/java/io/prestosql/cli/LyftOktaAuthenticationHandler.java @@ -0,0 +1,181 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.cli; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.hash.HashCode; +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import io.airlift.log.Logger; +import okhttp3.FormBody; +import okhttp3.OkHttpClient; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.handler.AbstractHandler; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import java.io.IOException; +import java.security.SecureRandom; +import java.util.Base64; + +class LyftOktaAuthenticationHandler + extends AbstractHandler +{ + private static final String REDIRECT_URI = "http://localhost:5000/authorization-code/callback"; + private static final String STATE = "LOGIN"; + + private static final String CLIENT_ID = "0oacv9m4dpomHGs2I1t7"; + private static final String BASE_URL = "https://lyft.okta.com"; + private static final String ISSUER = BASE_URL + "/oauth2/default"; + private static final String TOKEN_ENDPOINT = ISSUER + "/v1/token"; + private static final String LOGIN_ENDPOINT = ISSUER + "/v1/authorize"; + + private static final int LENGTH_CODE_VERIFIER = 64; + + private Server server; + private User user; + + private String codeVerifier; + private String codeChallenge; + + private static final Logger log = Logger.get(LyftOktaAuthenticationHandler.class); + + LyftOktaAuthenticationHandler(Server server, User user) + { + this.server = server; + this.user = user; + + SecureRandom random = new SecureRandom(); + byte[] codeVerifierBytes = new byte[LENGTH_CODE_VERIFIER]; + random.nextBytes(codeVerifierBytes); + codeVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(codeVerifierBytes); + + // Now create code challenge + HashFunction sha256HashFunction = Hashing.sha256(); + HashCode codeVerifierDigest = sha256HashFunction.hashBytes(codeVerifier.getBytes()); + codeChallenge = Base64.getUrlEncoder().withoutPadding().encodeToString(codeVerifierDigest.asBytes()); + } + + private String getOktaLoginUrl() + { + return LOGIN_ENDPOINT + "?" + + "client_id=" + CLIENT_ID + "&" + + "redirect_uri=" + REDIRECT_URI + "&" + + "response_type=code&" + + "scope=openid&" + + "code_challenge_method=S256&" + + "code_challenge=" + codeChallenge + "&" + + "state=" + STATE; + } + + @Override + public void handle(String target, + Request baseRequest, + HttpServletRequest request, + HttpServletResponse response) + throws IOException + { + baseRequest.setHandled(true); + + if (target.equalsIgnoreCase("/authorization-code/callback")) { + handleCallback(request, response); + } + else if (target.equalsIgnoreCase("/login")) { + handleLogin(response); + } + else { + handle404(baseRequest, response); + } + } + + private void handle404(Request baseRequest, HttpServletResponse response) + throws IOException + { + response.setContentType("text/html;charset=utf-8"); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + baseRequest.setHandled(true); + response.getWriter().println("

Page Doesn't Exist

"); + } + + private void handleCallback(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + // Read the code + response.getWriter().println("

Handling callback

"); + String code = request.getParameter("code"); + response.getWriter().println("

Code: " + code + "

"); + + // Now get the auth token + RequestBody formBody = new FormBody.Builder() + .add("grant_type", "authorization_code") + .add("code", code) + .add("code_verifier", codeVerifier) + .add("redirect_uri", REDIRECT_URI) + .add("client_id", CLIENT_ID) + .build(); + + okhttp3.Request accessTokenRequest = new okhttp3.Request.Builder() + .url(TOKEN_ENDPOINT) + .addHeader("User-Agent", "OkHttp Bot") + .post(formBody) + .build(); + + OkHttpClient okHttpClient = new OkHttpClient(); + Response accessTokenResponse = okHttpClient.newCall(accessTokenRequest).execute(); + String accessTokenResponseBody = accessTokenResponse.body().string(); + response.getWriter().println("

Auth Response: " + accessTokenResponseBody + "

"); + + // Parse the token + ObjectMapper mapper = new ObjectMapper(); + JsonNode parsedJson = mapper.readTree(accessTokenResponseBody); + String accessToken = parsedJson.get("access_token").toString(); + response.getWriter().println("

Access Token: " + accessToken + "

"); + + response.getWriter().println("Close Window"); + + response.flushBuffer(); // Necessary to show output on the screen + + // Set the user + user.setAccessToken(accessToken); + + // Stop the server. + try { + new Thread(() -> { + try { + log.info("Shutting down Jetty..."); + server.stop(); + log.info("Jetty has stopped."); + } + catch (Exception ex) { + log.warn("Error when stopping Jetty: " + ex.getMessage()); + } + }).start(); + } + catch (Exception e) { + log.warn("Cannot stop server"); + e.printStackTrace(); + } + } + + private void handleLogin(HttpServletResponse response) + throws IOException + { + response.sendRedirect(getOktaLoginUrl()); + } +} diff --git a/presto-cli/src/main/java/io/prestosql/cli/QueryRunner.java b/presto-cli/src/main/java/io/prestosql/cli/QueryRunner.java index 23edee6c6bae..750b0eaf01f8 100644 --- a/presto-cli/src/main/java/io/prestosql/cli/QueryRunner.java +++ b/presto-cli/src/main/java/io/prestosql/cli/QueryRunner.java @@ -14,13 +14,17 @@ package io.prestosql.cli; import com.google.common.net.HostAndPort; +import io.airlift.log.Logger; import io.prestosql.client.ClientSession; import io.prestosql.client.SocketChannelSocketFactory; import io.prestosql.client.StatementClient; import okhttp3.OkHttpClient; +import org.eclipse.jetty.server.Server; +import java.awt.Desktop; import java.io.Closeable; import java.io.File; +import java.net.URI; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -47,6 +51,8 @@ public class QueryRunner private final OkHttpClient httpClient; private final Consumer sslSetup; + private static final Logger log = Logger.get(QueryRunner.class); + public QueryRunner( ClientSession session, boolean debug, @@ -65,7 +71,8 @@ public QueryRunner( Optional kerberosConfigPath, Optional kerberosKeytabPath, Optional kerberosCredentialCachePath, - boolean kerberosUseCanonicalHostname) + boolean kerberosUseCanonicalHostname, + boolean useOkta) { this.session = new AtomicReference<>(requireNonNull(session, "session is null")); this.debug = debug; @@ -82,6 +89,7 @@ public QueryRunner( setupHttpProxy(builder, httpProxy); setupBasicAuth(builder, session, user, password); setupTokenAuth(builder, session, accessToken); + setupOktaAuth(builder, session, useOkta); if (kerberosRemoteServiceName.isPresent()) { checkArgument(session.getServer().getScheme().equalsIgnoreCase("https"), @@ -165,4 +173,35 @@ private static void setupTokenAuth( clientBuilder.addInterceptor(tokenAuth(accessToken.get())); } } + + private static void setupOktaAuth( + OkHttpClient.Builder clientBuilder, + ClientSession session, + boolean useOkta) + { + if (useOkta) { + log.info("Asking for okta authentication"); + User user = new User(); + Server server = new Server(5000); + server.setHandler(new LyftOktaAuthenticationHandler(server, user)); + + try { + server.start(); + + // Open browser + Desktop desktop = java.awt.Desktop.getDesktop(); + URI loginUrl = new URI("http://localhost:5000/login"); + desktop.browse(loginUrl); + + server.join(); + log.info("Received Access Token. Exiting Jetty Server"); + log.info("Access Token: " + user.getAccessToken()); + // TODO: This should set access token + setupTokenAuth(clientBuilder, session, Optional.ofNullable(user.getAccessToken())); + } + catch (Exception e) { + e.printStackTrace(); + } + } + } } diff --git a/presto-cli/src/main/java/io/prestosql/cli/User.java b/presto-cli/src/main/java/io/prestosql/cli/User.java new file mode 100644 index 000000000000..3366f4540d0f --- /dev/null +++ b/presto-cli/src/main/java/io/prestosql/cli/User.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.cli; + +class User +{ + String accessToken; + + public String getAccessToken() + { + return accessToken; + } + + public void setAccessToken(String accessToken) + { + this.accessToken = accessToken; + } +} diff --git a/presto-cli/src/test/java/io/prestosql/cli/TestClientOptions.java b/presto-cli/src/test/java/io/prestosql/cli/TestClientOptions.java index 2b767211096f..df382641d8d4 100644 --- a/presto-cli/src/test/java/io/prestosql/cli/TestClientOptions.java +++ b/presto-cli/src/test/java/io/prestosql/cli/TestClientOptions.java @@ -30,7 +30,7 @@ public class TestClientOptions public void testDefault() { ClientSession session = new ClientOptions().toClientSession(); - assertEquals(session.getServer().toString(), "http://localhost:8080"); + assertEquals(session.getServer().toString(), "https://prestoproxy-production.lyft.net"); assertEquals(session.getSource(), "presto-cli"); } diff --git a/presto-cli/src/test/java/io/prestosql/cli/TestQueryRunner.java b/presto-cli/src/test/java/io/prestosql/cli/TestQueryRunner.java index 7d1fd117156b..023a80150229 100644 --- a/presto-cli/src/test/java/io/prestosql/cli/TestQueryRunner.java +++ b/presto-cli/src/test/java/io/prestosql/cli/TestQueryRunner.java @@ -156,6 +156,7 @@ static QueryRunner createQueryRunner(ClientSession clientSession) Optional.empty(), Optional.empty(), Optional.empty(), + false, false); }