Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend support for Express middlewares to include error handling #674

Merged
merged 3 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 40 additions & 27 deletions lib/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ export interface ServerOptions {
type Middleware = (
req: IncomingMessage,
res: ServerResponse,
next: () => void
next: (err?: any) => void
) => void;

export abstract class BaseServer extends EventEmitter {
Expand Down Expand Up @@ -335,7 +335,7 @@ export abstract class BaseServer extends EventEmitter {
protected _applyMiddlewares(
req: IncomingMessage,
res: ServerResponse,
callback: () => void
callback: (err?: any) => void
) {
if (this.middlewares.length === 0) {
debug("no middleware to apply, skipping");
Expand All @@ -344,7 +344,11 @@ export abstract class BaseServer extends EventEmitter {

const apply = (i) => {
debug("applying middleware n°%d", i + 1);
this.middlewares[i](req, res, () => {
this.middlewares[i](req, res, (err?: any) => {
if (err) {
return callback(err);
}

if (i + 1 < this.middlewares.length) {
apply(i + 1);
} else {
Expand Down Expand Up @@ -655,8 +659,12 @@ export class Server extends BaseServer {
}
};

this._applyMiddlewares(req, res, () => {
this.verify(req, false, callback);
this._applyMiddlewares(req, res, (err) => {
if (err) {
callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" });
} else {
this.verify(req, false, callback);
}
});
}

Expand All @@ -673,32 +681,37 @@ export class Server extends BaseServer {
this.prepare(req);

const res = new WebSocketResponse(req, socket);
const callback = (errorCode, errorContext) => {
if (errorCode) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
abortUpgrade(socket, errorCode, errorContext);
return;
}

this._applyMiddlewares(req, res as unknown as ServerResponse, () => {
this.verify(req, true, (errorCode, errorContext) => {
if (errorCode) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
abortUpgrade(socket, errorCode, errorContext);
return;
}

const head = Buffer.from(upgradeHead);
upgradeHead = null;
const head = Buffer.from(upgradeHead);
upgradeHead = null;

// some middlewares (like express-session) wait for the writeHead() call to flush their headers
// see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244
res.writeHead();
// some middlewares (like express-session) wait for the writeHead() call to flush their headers
// see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244
res.writeHead();

// delegate to ws
this.ws.handleUpgrade(req, socket, head, (websocket) => {
this.onWebSocket(req, socket, websocket);
});
// delegate to ws
this.ws.handleUpgrade(req, socket, head, (websocket) => {
this.onWebSocket(req, socket, websocket);
});
};

this._applyMiddlewares(req, res as unknown as ServerResponse, (err) => {
if (err) {
callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" });
} else {
this.verify(req, true, callback);
}
});
}

Expand Down
170 changes: 93 additions & 77 deletions lib/userver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,24 @@ export class uServer extends BaseServer {
});
}

override _applyMiddlewares(req: any, res: any, callback: () => void): void {
override _applyMiddlewares(
req: any,
res: any,
callback: (err?: any) => void
): void {
if (this.middlewares.length === 0) {
return callback();
}

// needed to buffer headers until the status is computed
req.res = new ResponseWrapper(res);

super._applyMiddlewares(req, req.res, () => {
super._applyMiddlewares(req, req.res, (err) => {
// some middlewares (like express-session) wait for the writeHead() call to flush their headers
// see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244
req.res.writeHead();

callback();
callback(err);
});
}

Expand All @@ -118,28 +122,34 @@ export class uServer extends BaseServer {

req.res = res;

this._applyMiddlewares(req, res, () => {
this.verify(req, false, (errorCode, errorContext) => {
if (errorCode !== undefined) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
this.abortRequest(req.res, errorCode, errorContext);
return;
}
const callback = (errorCode, errorContext) => {
if (errorCode !== undefined) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
this.abortRequest(req.res, errorCode, errorContext);
return;
}

if (req._query.sid) {
debug("setting new request for existing client");
this.clients[req._query.sid].transport.onRequest(req);
} else {
const closeConnection = (errorCode, errorContext) =>
this.abortRequest(res, errorCode, errorContext);
this.handshake(req._query.transport, req, closeConnection);
}
};

if (req._query.sid) {
debug("setting new request for existing client");
this.clients[req._query.sid].transport.onRequest(req);
} else {
const closeConnection = (errorCode, errorContext) =>
this.abortRequest(res, errorCode, errorContext);
this.handshake(req._query.transport, req, closeConnection);
}
});
this._applyMiddlewares(req, res, (err) => {
if (err) {
callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" });
} else {
this.verify(req, false, callback);
}
});
}

Expand All @@ -154,63 +164,69 @@ export class uServer extends BaseServer {

req.res = res;

this._applyMiddlewares(req, res, () => {
this.verify(req, true, async (errorCode, errorContext) => {
if (errorCode) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
this.abortRequest(res, errorCode, errorContext);
const callback = async (errorCode, errorContext) => {
if (errorCode) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
this.abortRequest(res, errorCode, errorContext);
return;
}

const id = req._query.sid;
let transport;

if (id) {
const client = this.clients[id];
if (!client) {
debug("upgrade attempt for closed client");
res.close();
} else if (client.upgrading) {
debug("transport has already been trying to upgrade");
res.close();
} else if (client.upgraded) {
debug("transport had already been upgraded");
res.close();
} else {
debug("upgrading existing transport");
transport = this.createTransport(req._query.transport, req);
client.maybeUpgrade(transport);
}
} else {
transport = await this.handshake(
req._query.transport,
req,
(errorCode, errorContext) =>
this.abortRequest(res, errorCode, errorContext)
);
if (!transport) {
return;
}
}

const id = req._query.sid;
let transport;

if (id) {
const client = this.clients[id];
if (!client) {
debug("upgrade attempt for closed client");
res.close();
} else if (client.upgrading) {
debug("transport has already been trying to upgrade");
res.close();
} else if (client.upgraded) {
debug("transport had already been upgraded");
res.close();
} else {
debug("upgrading existing transport");
transport = this.createTransport(req._query.transport, req);
client.maybeUpgrade(transport);
}
} else {
transport = await this.handshake(
req._query.transport,
req,
(errorCode, errorContext) =>
this.abortRequest(res, errorCode, errorContext)
);
if (!transport) {
return;
}
}
// calling writeStatus() triggers the flushing of any header added in a middleware
req.res.writeStatus("101 Switching Protocols");

// calling writeStatus() triggers the flushing of any header added in a middleware
req.res.writeStatus("101 Switching Protocols");

res.upgrade(
{
transport,
},
req.getHeader("sec-websocket-key"),
req.getHeader("sec-websocket-protocol"),
req.getHeader("sec-websocket-extensions"),
context
);
});
res.upgrade(
{
transport,
},
req.getHeader("sec-websocket-key"),
req.getHeader("sec-websocket-protocol"),
req.getHeader("sec-websocket-extensions"),
context
);
};

this._applyMiddlewares(req, res, (err) => {
if (err) {
callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" });
} else {
this.verify(req, true, callback);
}
});
}

Expand Down
44 changes: 44 additions & 0 deletions test/middlewares.js
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,48 @@ describe("middlewares", () => {
});
});
});

it("should fail on errors (polling)", (done) => {
const engine = listen((port) => {
engine.use((req, res, next) => {
next(new Error("will always fail"));
});

request
.get(`http://localhost:${port}/engine.io/`)
.query({ EIO: 4, transport: "polling" })
.end((err, res) => {
expect(err).to.be.an(Error);
expect(res.status).to.eql(400);

if (engine.httpServer) {
engine.httpServer.close();
}
done();
});
});

it("should fail on errors (websocket)", (done) => {
const engine = listen((port) => {
engine.use((req, res, next) => {
next(new Error("will always fail"));
});

engine.on("connection", () => {
done(new Error("should not connect"));
});

const socket = new WebSocket(
`ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`
);

socket.addEventListener("error", () => {
if (engine.httpServer) {
engine.httpServer.close();
}
done();
});
});
});
});
});