Skip to content

Commit

Permalink
ugh just use an explicit mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
dmihalcik-virtru committed Nov 18, 2024
1 parent 13c7a74 commit 50aa168
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 76 deletions.
19 changes: 14 additions & 5 deletions lib/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lib/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"watch": "(trap 'kill 0' SIGINT; npm run build && (npm run build:watch & npm run test -- --watch))"
},
"dependencies": {
"async-mutex": "^0.5.0",
"axios": "^1.6.1",
"axios-retry": "^3.9.0",
"base64-js": "^1.5.1",
Expand Down
117 changes: 56 additions & 61 deletions lib/src/auth/oidc.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { default as dpopFn } from 'dpop';
import { Mutex } from 'async-mutex';
import { HttpRequest, withHeaders } from './auth.js';
import { base64 } from '../encodings/index.js';
import { ConfigurationError, TdfError } from '../errors.js';
Expand Down Expand Up @@ -57,10 +58,7 @@ const qstringify = (obj: Record<string, string>) => new URLSearchParams(obj).toS
export type AccessTokenResponse = {
access_token: string;
refresh_token?: string;
};

export type TimeStamp = {
when: number;
timestamp?: number;
};

/**
Expand Down Expand Up @@ -88,10 +86,9 @@ export type TimeStamp = {
export class AccessToken {
config: OIDCCredentials;

// For mocking fetch
request?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;

data?: Promise<AccessTokenResponse & TimeStamp>;
data?: AccessTokenResponse;

baseUrl: string;

Expand All @@ -101,6 +98,8 @@ export class AccessToken {

currentAccessToken?: string;

mutex: Mutex = new Mutex();

constructor(cfg: OIDCCredentials, request?: typeof fetch) {
if (!cfg.clientId) {
throw new ConfigurationError(
Expand Down Expand Up @@ -175,7 +174,7 @@ export class AccessToken {
});
}

async accessTokenLookup(cfg: OIDCCredentials): Promise<AccessTokenResponse & TimeStamp> {
async accessTokenLookup(cfg: OIDCCredentials) {
const url = `${this.baseUrl}/protocol/openid-connect/token`;
let body;
switch (cfg.exchange) {
Expand Down Expand Up @@ -211,7 +210,7 @@ export class AccessToken {
);
}
const r = await response.json();
r.when = Date.now();
r.timestamp = Date.now();
return r;
}

Expand All @@ -221,50 +220,36 @@ export class AccessToken {
* @returns
*/
async get(validate?: boolean): Promise<string> {
let isNew = false;
const now = Date.now();
let currentData = this.data;
if (!currentData) {
currentData = this.accessTokenLookup(this.config);
this.data = currentData;
isNew = true;
}
let tokenResponse: AccessTokenResponse & TimeStamp;
const release = await this.mutex.acquire();
try {
tokenResponse = await currentData;
} catch (e) {
// Failed during token exchange.
if (this.data === currentData) {
delete this.data;
}
throw e;
}
if (isNew) {
// If we just did the first token exchange, we may have a refresh token.
if (tokenResponse.refresh_token) {
// Upgrade to refresh token type, if we have one
this.config = {
...this.config,
exchange: 'refresh',
refreshToken: tokenResponse.refresh_token,
};
}
return tokenResponse.access_token;
}

// Validate if explicitly requested or, if not defined, when the token is older than 5 minutes.
if (!!validate || (validate === undefined && now - tokenResponse.when > 1000 * 60 * 5)) {
try {
await this.info(tokenResponse.access_token);
} catch (e) {
console.log('access_token fails on user_info endpoint; attempting to renew', e);
if (this.data === currentData) {
if (this.data?.access_token) {
try {
// Validation was explicitly requested OR the token is older than 5 minutes.
if (validate || (this.data.timestamp && this.data.timestamp + 1000 * 60 * 5 < now)) {
await this.info(this.data.access_token);
}
return this.data.access_token;
} catch (e) {
console.log('access_token fails on user_info endpoint; attempting to renew', e);
if (this.data.refresh_token) {
// Prefer the latest refresh_token if present over creds passed in
// to constructor
this.config = {
...this.config,
exchange: 'refresh',
refreshToken: this.data.refresh_token,
};
}
delete this.data;
}
return this.get(false);
}

const tokenResponse = (this.data = await this.accessTokenLookup(this.config));
return tokenResponse.access_token;
} finally {
release();
}
return tokenResponse.access_token;
}

/**
Expand All @@ -288,22 +273,32 @@ export class AccessToken {
/**
* Converts included refresh token or external JWT for a new one.
*/
async exchangeForRefreshToken(): Promise<void> {
const cfg = this.config;
if (cfg.exchange != 'external' && cfg.exchange != 'refresh') {
throw new ConfigurationError('no refresh token provided!');
}
const tokenResponse = await (this.data = this.accessTokenLookup(this.config));
if (!tokenResponse.refresh_token) {
return;
async exchangeForRefreshToken(): Promise<string> {
const release = await this.mutex.acquire();
try {
const cfg = this.config;
if (cfg.exchange != 'external' && cfg.exchange != 'refresh') {
throw new ConfigurationError('no refresh token provided!');
}
const tokenResponse = (this.data = await this.accessTokenLookup(this.config));
if (!tokenResponse.refresh_token) {
return (
(cfg.exchange == 'refresh' && cfg.refreshToken) ||
(cfg.exchange == 'external' && cfg.externalJwt) ||
''
);
}
// Prefer the latest refresh_token if present over creds passed in
// to constructor
this.config = {
...this.config,
exchange: 'refresh',
refreshToken: tokenResponse.refresh_token,
};
return tokenResponse.access_token;
} finally {
release();
}
// Prefer the latest refresh_token if present over creds passed in
// to constructor, for token exchange. Refresh tokens usually stay the same.
this.config = {
...this.config,
exchange: 'refresh',
refreshToken: tokenResponse.refresh_token,
};
}

async withCreds(httpReq: HttpRequest): Promise<HttpRequest> {
Expand Down
34 changes: 24 additions & 10 deletions lib/tests/web/auth/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,10 @@ describe('AccessToken', () => {
},
mf
);
accessTokenClient.data = Promise.resolve({
accessTokenClient.data = {
access_token: 'a',
refresh_token: 'r',
when: Date.now(),
});
};
// Do a refresh to cache tokenset
const res = await accessTokenClient.get();
expect(res).to.eql('a');
Expand All @@ -310,15 +309,30 @@ describe('AccessToken', () => {
});
it('should attempt to refresh token if userinfo call throws error', async () => {
const signingKey = await crypto.subtle.generateKey(algorithmSigner, true, ['sign']);
const json = fake.resolves({ access_token: 'a' });
const json = fake.resolves({ access_token: 'A' });
const mf = fake((url: RequestInfo | URL, init?: RequestInit): Promise<Response> => {
if (!init) {
return Promise.reject('No init found');
}
if (init.method === 'POST') {
if (typeof (url as unknown) == 'string') {
url = new URL(url as string);
} else if (!(url instanceof URL)) {
return Promise.reject('url is not a string or URL');
}
if (init.method === 'POST' && url.pathname.endsWith('token')) {
return Promise.resolve({ ...ok, json });
}
return Promise.reject(`yee [${url}] [${JSON.stringify(init.headers)}]`);
if (url.pathname.endsWith('userinfo')) {
if (
!init.headers ||
!('Authorization' in init.headers) ||
init.headers.Authorization !== 'Bearer A'
) {
return Promise.reject('yee');
}
return Promise.resolve({ ...ok, json: fake.resolves({ email: 'user@some.where' }) });
}
return Promise.reject('zee');
});
const accessTokenClient = new AccessToken(
{
Expand All @@ -330,11 +344,11 @@ describe('AccessToken', () => {
},
mf
);
accessTokenClient.data = Promise.resolve({
accessTokenClient.data = {
access_token: 'a',
refresh_token: 'r',
when: Date.now() - 10 * 60 * 1000,
});
timestamp: Date.now() - 1000 * 60 * 60,
};
// Do a refresh to cache tokenset
const res = await Promise.all([
accessTokenClient.get(),
Expand All @@ -343,7 +357,7 @@ describe('AccessToken', () => {
accessTokenClient.get(),
accessTokenClient.get(),
]);
expect(res).to.include('a');
expect(res).to.include('A');
expect(mf.callCount).to.eql(2);
});
});
Expand Down

0 comments on commit 50aa168

Please sign in to comment.