Skip to content

Commit

Permalink
[#23034] YSQL: Add Support for OIDC IDP URL (jwt_jwks_url) to fetch a…
Browse files Browse the repository at this point in the history
…nd refresh JKWS

Summary:
Added support for jwt_jwks_url in JWT Authentication to fetch and refresh JWKS from url.

**Configuration**
In ysql_hba_conf_csv with method jwt can provide jwt_jwks_url to pass the url. Atleast one of the
jwt_jwks_path or jwt_jwks_url must be provided, if both are provided then jwt_jwks_url will be used.

**Fetching**
Implementation is as follows:

1. Fetches the keys from url using libcurl (EasyCurl implentation)
2. The keys are used for authentication in a similar fashion as with jwt_jwks_path

JIRA: DB-11962

Test Plan:
**Automated**
./yb_build.sh --java-test 'org.yb.pgsql.TestJWTAuth'

Added similar tests as jwt_jwks_path, to check for valid url, json, and authentication. Also added
checks for invalid url and invalid json.

Reviewers: stiwary, skumar

Reviewed By: stiwary

Subscribers: yql

Differential Revision: https://phorge.dev.yugabyte.com/D36468
  • Loading branch information
utkarsh-um-yb committed Jul 16, 2024
1 parent 6ec058d commit 18bb9b8
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 13 deletions.
12 changes: 12 additions & 0 deletions java/yb-pgsql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,17 @@
<version>9.31</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>4.10.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
<version>4.1.0</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
159 changes: 150 additions & 9 deletions java/yb-pgsql/src/test/java/org/yb/pgsql/TestJWTAuth.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.yb.util.Pair;

import com.google.common.base.Strings;
import com.google.common.util.concurrent.ExecutionError;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
Expand All @@ -76,6 +77,11 @@
import com.nimbusds.jwt.SignedJWT;
import com.yugabyte.util.PSQLException;

import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;

@RunWith(value = YBTestRunner.class)
public class TestJWTAuth extends BasePgSQLTest {
private static final Logger LOG = LoggerFactory.getLogger(TestJWTAuth.class);
Expand Down Expand Up @@ -216,11 +222,9 @@ private static JWKSet createJwks() throws Exception {
});
}

// Sets up JWT authentication with the provided configuration params.
// Enables JWT auth on "testuser1" while the remaining users authenticate via trust.
private void setJWTConfigAndRestartCluster(List<String> allowedIssuers,
List<String> allowedAudiences, String jwksPath, String matchingClaimKey, String mapName,
String identFileContents) throws Exception {
private void setJWTConfigAndRestartClusterWithUrl(List<String> allowedIssuers,
List<String> allowedAudiences, String jwksPath, String matchingClaimKey, String mapName,
String identFileContents, String jwksUrl) throws Exception {
String issuersCsv = String.join(",", allowedIssuers);
String audiencesCsv = String.join(",", allowedAudiences);

Expand All @@ -231,22 +235,32 @@ private void setJWTConfigAndRestartCluster(List<String> allowedIssuers,

Map<String, String> flagMap = super.getTServerFlags();
String hba_conf_value = "";
String jwksPathConfString = "";
if(!Strings.isNullOrEmpty(jwksPath)){
jwksPathConfString = String.format("jwt_jwks_path=%s ", jwksPath);
}
String jwksUrlConfString = "";
if(!Strings.isNullOrEmpty(jwksUrl)){
jwksUrlConfString = String.format("jwt_jwks_url=%s ", jwksUrl);
}
if (Strings.isNullOrEmpty(mapName)) {
hba_conf_value = String.format("\"host all yugabyte 0.0.0.0/0 trust\","
+ "\"host all yugabyte_test 0.0.0.0/0 trust\","
+ "\"host all all 0.0.0.0/0 jwt "
+ "jwt_jwks_path=%s "
+ jwksPathConfString
+ jwksUrlConfString
+ "jwt_issuers=\"\"%s\"\" "
+ "jwt_audiences=\"\"%s\"\" %s\"",
jwksPath, issuersCsv, audiencesCsv, matchingClaimKeyValues);
issuersCsv, audiencesCsv, matchingClaimKeyValues);
} else {
hba_conf_value = String.format("\"host all yugabyte 0.0.0.0/0 trust\","
+ "\"host all yugabyte_test 0.0.0.0/0 trust\","
+ "\"host all all 0.0.0.0/0 jwt "
+ "jwt_jwks_path=%s "
+ jwksPathConfString
+ jwksUrlConfString
+ "jwt_issuers=\"\"%s\"\" "
+ "jwt_audiences=\"\"%s\"\" %s map=%s\"",
jwksPath, issuersCsv, audiencesCsv, matchingClaimKeyValues, mapName);
issuersCsv, audiencesCsv, matchingClaimKeyValues, mapName);
}

flagMap.put("ysql_hba_conf_csv", hba_conf_value);
Expand All @@ -261,6 +275,15 @@ private void setJWTConfigAndRestartCluster(List<String> allowedIssuers,
LOG.info("Cluster restart finished");
}

// Sets up JWT authentication with the provided configuration params.
// Enables JWT auth on "testuser1" while the remaining users authenticate via trust.
private void setJWTConfigAndRestartCluster(List<String> allowedIssuers,
List<String> allowedAudiences, String jwksPath, String matchingClaimKey, String mapName,
String identFileContents) throws Exception {
setJWTConfigAndRestartClusterWithUrl(allowedIssuers, allowedAudiences, jwksPath,
matchingClaimKey, mapName, identFileContents, /* jwksUrl */ "");
}

// groupsOrRoles needs to be passed separately since Nimbus is not able to serialize the List when
// it receives it as a object.
private static String createJWT(JWSAlgorithm algorithm, JWKSet jwks, String keyId, String sub,
Expand Down Expand Up @@ -362,6 +385,64 @@ public void authWithSubjectWithoutIdent() throws Exception {
assertFailedAuthentication(passRoleUserConnBldr, "123");
}

@Test
public void authWithSubjectWithoutIdentFromUrl() throws Exception {
try (MockWebServer server = new MockWebServer()) {
Dispatcher mDispatcher = new Dispatcher() {
@Override
public MockResponse dispatch(RecordedRequest request) {
if (request.getPath().contains("/jwks_keys")) {
return new MockResponse().setResponseCode(200)
.setBody(jwks.toString(true));
}
if (request.getPath().contains("/invalid_json")) {
return new MockResponse().setResponseCode(200)
.setBody("invalid json");
}
return new MockResponse().setResponseCode(404);
}
};
server.setDispatcher(mDispatcher);
server.start();
String serverUrl = String.format("\"\"http://%s:%s/jwks_keys\"\"",
server.getHostName(), server.getPort());

setJWTConfigAndRestartClusterWithUrl(ALLOWED_ISSUERS, ALLOWED_AUDIENCES, /* jwksPath */ "",
/* matchingClaimKey */ "", /* mapName */ "", /* identFileContents */ "", serverUrl);

List<Pair<JWSAlgorithm, String>> keysWithAlgorithms =
new ArrayList<Pair<JWSAlgorithm, String>>() {
{
add(new Pair<JWSAlgorithm, String>(JWSAlgorithm.RS256, RS256_KEYID));
add(new Pair<JWSAlgorithm, String>(JWSAlgorithm.PS256, PS256_KEYID));
add(new Pair<JWSAlgorithm, String>(JWSAlgorithm.ES256, ES256_KEYID));
add(new Pair<JWSAlgorithm, String>(JWSAlgorithm.RS256, RS256_KEYID_WITH_X5C));
add(new Pair<JWSAlgorithm, String>(JWSAlgorithm.PS256, PS256_KEYID_WITH_X5C));
add(new Pair<JWSAlgorithm, String>(JWSAlgorithm.ES256, ES256_KEYID_WITH_X5C));
}
};

try (Statement statement = connection.createStatement()) {
statement.execute("CREATE ROLE testuser1 LOGIN");
}

ConnectionBuilder passRoleUserConnBldr = getConnectionBuilder().withUser("testuser1");

// Ensure that login works with each key type.
for (Pair<JWSAlgorithm, String> key : keysWithAlgorithms) {
String jwt = createJWT(key.getFirst(), jwks, key.getSecond(), "testuser1",
"login.issuer1.secured.example.com/2ac843f8-2156-11ee-be56-0242ac120002/v2.0",
"795c2b42-2156-11ee-be56-0242ac120002", ISSUED_AT_TIME, EXPIRATION_TIME, null);
assertSuccessfulAuthentication(passRoleUserConnBldr, jwt);
}

// Basic JWT login with incorrect password.
assertFailedAuthentication(passRoleUserConnBldr, "123");

server.shutdown();
}
}

@Test
public void authWithSubjectWithIdent() throws Exception {
// Map IDP {name}@example.com to YSQL {name}.
Expand Down Expand Up @@ -664,6 +745,36 @@ public void invalidJWTJwksPath() throws Exception {
testFailedAuthentication(jwks);
}

@Test
public void invalidJWTJwksUrl() throws Exception {
try (MockWebServer server = new MockWebServer()) {
Dispatcher mDispatcher = new Dispatcher() {
@Override
public MockResponse dispatch(RecordedRequest request) {
if (request.getPath().contains("/jwks_keys")) {
return new MockResponse().setResponseCode(200)
.setBody(jwks.toString(true));
}
if (request.getPath().contains("/invalid_json")) {
return new MockResponse().setResponseCode(200)
.setBody("invalid json");
}
return new MockResponse().setResponseCode(404);
}
};
server.setDispatcher(mDispatcher);
server.start();
String serverUrl = String.format("\"\"http://%s:%s/random_url\"\"",
server.getHostName(), server.getPort());

setJWTConfigAndRestartClusterWithUrl(ALLOWED_ISSUERS, ALLOWED_AUDIENCES, /* jwksPath */ "",
/* matchingClaimKey */ "", /* mapName */ "", /* identFileContents */ "", serverUrl);

testFailedAuthentication(jwks);
server.shutdown();
}
}

@Test
public void invalidJWKSJson() throws Exception {
String jwksPath = populateJWKSFile("some_invalid_json");
Expand All @@ -673,6 +784,36 @@ public void invalidJWKSJson() throws Exception {
testFailedAuthentication(jwks);
}

@Test
public void invalidJWKSJsonFromUrl() throws Exception {
try (MockWebServer server = new MockWebServer()) {
Dispatcher mDispatcher = new Dispatcher() {
@Override
public MockResponse dispatch(RecordedRequest request) {
if (request.getPath().contains("/jwks_keys")) {
return new MockResponse().setResponseCode(200)
.setBody(jwks.toString(true));
}
if (request.getPath().contains("/invalid_json")) {
return new MockResponse().setResponseCode(200)
.setBody("invalid json");
}
return new MockResponse().setResponseCode(404);
}
};
server.setDispatcher(mDispatcher);
server.start();
String serverUrl = String.format("\"\"http://%s:%s/invalid_json\"\"",
server.getHostName(), server.getPort());

setJWTConfigAndRestartClusterWithUrl(ALLOWED_ISSUERS, ALLOWED_AUDIENCES, /* jwksPath */ "",
/* matchingClaimKey */ "", /* mapName */ "", /* identFileContents */ "", serverUrl);

testFailedAuthentication(jwks);
server.shutdown();
}
}

// Asserts that the cluster restart failed by expecting an exception. There doesn't seem to be a
// more accurate way of checking that.
private void assertClusterRestartFailure(List<String> allowedIssuers,
Expand Down
38 changes: 37 additions & 1 deletion src/postgres/src/backend/libpq/auth.c
Original file line number Diff line number Diff line change
Expand Up @@ -3546,6 +3546,8 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
static char *ybReadFile(const char *outer_filename, const char *inc_filename,
int elevel);

static char *ybReadFromUrl(const char *url);

static void
ybGetJwtAuthOptionsFromPortAndJwks(Port *port, char *jwks,
YBCPgJwtAuthOptions *opt)
Expand Down Expand Up @@ -3576,8 +3578,12 @@ YbCheckJwtAuth(Port *port)
/*
* Read the jwks file before the password prompt so that we fail fast if we
* fail to read the jwks file or the content is invalid.
* Check if jwt_jwks_url is provided then use that otherwise use jwt_jwks_path
*/
jwks = ybReadFile(HbaFileName, port->hba->yb_jwt_jwks_path, LOG);
if(port->hba->yb_jwt_jwks_url)
jwks = ybReadFromUrl(port->hba->yb_jwt_jwks_url);
else
jwks = ybReadFile(HbaFileName, port->hba->yb_jwt_jwks_path, LOG);
if (jwks == NULL)
return STATUS_ERROR;

Expand Down Expand Up @@ -3677,3 +3683,33 @@ ybReadFile(const char *outer_filename, const char *inc_filename, int elevel)
pfree(file_fullname);
return file_contents;
}

static char *
ybReadFromUrl(const char *url)
{
char *url_contents = NULL;
int len;
YBCStatus status;

status = YBCFetchFromUrl(url, &url_contents);
if (status) /* !ok */
{
ereport(LOG,
(errmsg("Fetching from JWT_JWKS_URL failed with error: %s",
YBCStatusMessageBegin(status))));
YBCFreeStatus(status);
return NULL;
}
if(!url_contents)
return NULL;

len = strlen(url_contents);
if(!pg_verifymbstr(url_contents, len, true))
{
ereport(LOG,
(errcode(ERRCODE_CHARACTER_NOT_IN_REPERTOIRE),
errmsg("invalid encoding of contents at \"%s\"", url)));
return NULL;
}
return url_contents;
}
26 changes: 23 additions & 3 deletions src/postgres/src/backend/libpq/hba.c
Original file line number Diff line number Diff line change
Expand Up @@ -1704,9 +1704,18 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
return NULL;
}

if (parsedline->auth_method == uaYbJWT) {
MANDATORY_AUTH_ARG(parsedline->yb_jwt_jwks_path, "jwt_jwks_path",
"jwt");
if (parsedline->auth_method == uaYbJWT)
{
if(!(parsedline->yb_jwt_jwks_url || parsedline->yb_jwt_jwks_path))
{
ereport(elevel,
(errcode(ERRCODE_CONFIG_FILE_ERROR),
errmsg("atleast one of jwt_jwks_url or jwt_jwks_path must be given"),
errcontext("line %d of configuration file \"%s\"",
line_num, HbaFileName)));
*err_msg = "atleast one of jwt_jwks_url or jwt_jwks_path must be given";
return NULL;
}

if (list_length(parsedline->yb_jwt_audiences) < 1)
{
Expand Down Expand Up @@ -2181,6 +2190,12 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,

hbaline->yb_jwt_jwks_path = pstrdup(val);
}
else if (strcmp(name, "jwt_jwks_url") == 0)
{
REQUIRE_AUTH_OPTION(uaYbJWT, "jwt_jwks_url", "jwt");

hbaline->yb_jwt_jwks_url = pstrdup(val);
}
else if (strcmp(name, "jwt_audiences") == 0)
{
List *parsed_audiences;
Expand Down Expand Up @@ -2585,6 +2600,11 @@ gethba_options(HbaLine *hba)
options[noptions++] =
CStringGetTextDatum(psprintf("jwt_jwks_path=%s",
hba->yb_jwt_jwks_path));

if (hba->yb_jwt_jwks_url)
options[noptions++] =
CStringGetTextDatum(psprintf("jwt_jwks_url=%s",
hba->yb_jwt_jwks_url));

if (hba->yb_jwt_audiences_s)
options[noptions++] =
Expand Down
1 change: 1 addition & 0 deletions src/postgres/src/include/libpq/hba.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ typedef struct HbaLine
char *radiusports_s;

char *yb_jwt_jwks_path;
char *yb_jwt_jwks_url;
List *yb_jwt_audiences;
char *yb_jwt_audiences_s;
List *yb_jwt_issuers;
Expand Down
15 changes: 15 additions & 0 deletions src/yb/yql/pggate/ybc_pggate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "yb/server/skewed_clock.h"

#include "yb/util/atomic.h"
#include "yb/util/curl_util.h"
#include "yb/util/flags.h"
#include "yb/util/jwt_util.h"
#include "yb/util/result.h"
Expand Down Expand Up @@ -499,6 +500,20 @@ YBCStatus YBCValidateJWT(const char *token, const YBCPgJwtAuthOptions *options)
return ToYBCStatus(STATUS(InvalidArgument, "Identity match failed"));
}

YBCStatus YBCFetchFromUrl(const char *url, char **buf) {
const std::string url_value(DCHECK_NOTNULL(url));
EasyCurl curl;
faststring buf_ret;
auto status = curl.FetchURL(url_value, &buf_ret);
if (!status.ok()) {
return ToYBCStatus(status);
}

*DCHECK_NOTNULL(buf) = static_cast<char*>(YBCPAlloc(buf_ret.size()+1));
snprintf(*buf, buf_ret.size()+1, "%s", buf_ret.ToString().c_str());
return YBCStatusOK();
}

bool YBCGetCurrentPgSessionParallelData(YBCPgSessionParallelData* session_data) {
if (pgapi) {
session_data->session_id = pgapi->GetSessionId();
Expand Down
1 change: 1 addition & 0 deletions src/yb/yql/pggate/ybc_pggate.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ YBCStatus YBCGetHeapConsumption(YbTcmallocStats *desc);

// Validate the JWT based on the options including the identity matching based on the identity map.
YBCStatus YBCValidateJWT(const char *token, const YBCPgJwtAuthOptions *options);
YBCStatus YBCFetchFromUrl(const char *url, char **buf);

// Is this node acting as the pg_cron leader?
bool YBCIsCronLeader();
Expand Down

0 comments on commit 18bb9b8

Please sign in to comment.