Skip to content

Commit

Permalink
Improve streaming server security (mastodon#10818)
Browse files Browse the repository at this point in the history
* Check OAuth token scopes in the streaming API

* Use Sec-WebSocket-Protocol instead of query string to pass WebSocket token

Inspired by kubevirt/kubevirt#1242
  • Loading branch information
ClearlyClaire authored and multiple creatures committed Nov 19, 2019
1 parent 7952281 commit 64a68bf
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 20 deletions.
6 changes: 1 addition & 5 deletions app/javascript/mastodon/stream.js
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,7 @@ export function connectStream(path, pollingRefresh = null, callbacks = () => ({
export default function getStream(streamingAPIBaseURL, accessToken, stream, { connected, received, disconnected, reconnected }) {
const params = [ `stream=${stream}` ];

if (accessToken !== null) {
params.push(`access_token=${accessToken}`);
}

const ws = new WebSocketClient(`${streamingAPIBaseURL}/api/v1/streaming/?${params.join('&')}`);
const ws = new WebSocketClient(`${streamingAPIBaseURL}/api/v1/streaming/?${params.join('&')}`, accessToken);

ws.onopen = connected;
ws.onmessage = e => received(JSON.parse(e.data));
Expand Down
70 changes: 55 additions & 15 deletions streaming/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,14 @@ const startWorker = (workerId) => {
next();
};

const accountFromToken = (token, req, next) => {
const accountFromToken = (token, allowedScopes, req, next) => {
pgPool.connect((err, client, done) => {
if (err) {
next(err);
return;
}

client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
done();

if (err) {
Expand All @@ -218,18 +218,29 @@ const startWorker = (workerId) => {
return;
}

const scopes = result.rows[0].scopes.split(' ');

if (allowedScopes.size > 0 && !scopes.some(scope => allowedScopes.includes(scope))) {
err = new Error('Access token does not cover required scopes');
err.statusCode = 401;

next(err);
return;
}

req.accountId = result.rows[0].account_id;
req.chosenLanguages = result.rows[0].chosen_languages;
req.allowNotifications = scopes.some(scope => ['read', 'read:notifications'].includes(scope));

next();
});
});
};

const accountFromRequest = (req, next, required = true) => {
const accountFromRequest = (req, next, required = true, allowedScopes = ['read']) => {
const authorization = req.headers.authorization;
const location = url.parse(req.url, true);
const accessToken = location.query.access_token;
const accessToken = location.query.access_token || req.headers['sec-websocket-protocol'];

if (!authorization && !accessToken) {
if (required) {
Expand All @@ -246,17 +257,29 @@ const startWorker = (workerId) => {

const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken;

accountFromToken(token, req, next);
accountFromToken(token, allowedScopes, req, next);
};

const PUBLIC_STREAMS = [
'public',
'public:media',
'hashtag',
'hashtag:local',
];

const wsVerifyClient = (info, cb) => {
const location = url.parse(info.req.url, true);
const authRequired = !PUBLIC_STREAMS.some(stream => stream === location.query.stream);
const allowedScopes = [];

if (authRequired) {
allowedScopes.push('read');
if (location.query.stream === 'user:notification') {
allowedScopes.push('read:notifications');
} else {
allowedScopes.push('read:statuses');
}
}

accountFromRequest(info.req, err => {
if (!err) {
Expand All @@ -265,10 +288,11 @@ const startWorker = (workerId) => {
log.error(info.req.requestId, err.toString());
cb(false, 401, 'Unauthorized');
}
}, authRequired);
}, authRequired, allowedScopes);
};

const PUBLIC_ENDPOINTS = [
'/api/v1/streaming/public',
'/api/v1/streaming/hashtag',
];

Expand All @@ -279,7 +303,18 @@ const startWorker = (workerId) => {
}

const authRequired = !PUBLIC_ENDPOINTS.some(endpoint => endpoint === req.path);
accountFromRequest(req, next, authRequired);
const allowedScopes = [];

if (authRequired) {
allowedScopes.push('read');
if (req.path === '/api/v1/streaming/user/notification') {
allowedScopes.push('read:notifications');
} else {
allowedScopes.push('read:statuses');
}
}

accountFromRequest(req, next, authRequired, allowedScopes);
};

const errorMiddleware = (err, req, res, {}) => {
Expand Down Expand Up @@ -328,12 +363,11 @@ const startWorker = (workerId) => {
output(event, encodedPayload);
};

if (!req.accountId) {
log.error(req.requestId, `Unauthorized: ${accountId} is not logged in.`)
if (notificationOnly && event !== 'notification') {
return;
}

if (notificationOnly && event !== 'notification') {
if (event === 'notification' && !req.allowNotifications) {
return;
}

Expand All @@ -359,17 +393,23 @@ const startWorker = (workerId) => {
const targetAccountIds = [unpackedPayload.account.id].concat(unpackedPayload.mentions.map(item => item.id));
const accountDomain = unpackedPayload.account.acct.split('@')[1];

if (Array.isArray(req.chosenLanguages) && unpackedPayload.language !== null && req.chosenLanguages.indexOf(unpackedPayload.language) === -1) {
log.silly(req.requestId, `Message ${unpackedPayload.id} filtered by language (${unpackedPayload.language})`);
return;
}

// When the account is not logged in, it is not necessary to confirm the block or mute
if (!req.accountId) {
transmit();
return;
}

// Don't filter user's own events.
if (req.accountId === unpackedPayload.account.id) {
transmit();
return
}

if (Array.isArray(req.chosenLanguages) && unpackedPayload.language !== null && req.chosenLanguages.indexOf(unpackedPayload.language) === -1) {
log.silly(req.requestId, `Message ${unpackedPayload.id} filtered by language (${unpackedPayload.language})`);
return;
}

pgPool.connect((err, client, done) => {
if (err) {
log.error(err);
Expand Down

0 comments on commit 64a68bf

Please sign in to comment.