Skip to content

Commit

Permalink
fix: include error handling for Express middlewares (#674)
Browse files Browse the repository at this point in the history
  • Loading branch information
cieldeville authored May 1, 2023
1 parent 911d0e3 commit 9395782
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 104 deletions.
67 changes: 40 additions & 27 deletions lib/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ export interface ServerOptions {
type Middleware = (
req: IncomingMessage,
res: ServerResponse,
next: () => void
next: (err?: any) => void
) => void;

export abstract class BaseServer extends EventEmitter {
Expand Down Expand Up @@ -335,7 +335,7 @@ export abstract class BaseServer extends EventEmitter {
protected _applyMiddlewares(
req: IncomingMessage,
res: ServerResponse,
callback: () => void
callback: (err?: any) => void
) {
if (this.middlewares.length === 0) {
debug("no middleware to apply, skipping");
Expand All @@ -344,7 +344,11 @@ export abstract class BaseServer extends EventEmitter {

const apply = (i) => {
debug("applying middleware n°%d", i + 1);
this.middlewares[i](req, res, () => {
this.middlewares[i](req, res, (err?: any) => {
if (err) {
return callback(err);
}

if (i + 1 < this.middlewares.length) {
apply(i + 1);
} else {
Expand Down Expand Up @@ -655,8 +659,12 @@ export class Server extends BaseServer {
}
};

this._applyMiddlewares(req, res, () => {
this.verify(req, false, callback);
this._applyMiddlewares(req, res, (err) => {
if (err) {
callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" });
} else {
this.verify(req, false, callback);
}
});
}

Expand All @@ -673,32 +681,37 @@ export class Server extends BaseServer {
this.prepare(req);

const res = new WebSocketResponse(req, socket);
const callback = (errorCode, errorContext) => {
if (errorCode) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
abortUpgrade(socket, errorCode, errorContext);
return;
}

this._applyMiddlewares(req, res as unknown as ServerResponse, () => {
this.verify(req, true, (errorCode, errorContext) => {
if (errorCode) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
abortUpgrade(socket, errorCode, errorContext);
return;
}

const head = Buffer.from(upgradeHead);
upgradeHead = null;
const head = Buffer.from(upgradeHead);
upgradeHead = null;

// some middlewares (like express-session) wait for the writeHead() call to flush their headers
// see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244
res.writeHead();
// some middlewares (like express-session) wait for the writeHead() call to flush their headers
// see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244
res.writeHead();

// delegate to ws
this.ws.handleUpgrade(req, socket, head, (websocket) => {
this.onWebSocket(req, socket, websocket);
});
// delegate to ws
this.ws.handleUpgrade(req, socket, head, (websocket) => {
this.onWebSocket(req, socket, websocket);
});
};

this._applyMiddlewares(req, res as unknown as ServerResponse, (err) => {
if (err) {
callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" });
} else {
this.verify(req, true, callback);
}
});
}

Expand Down
170 changes: 93 additions & 77 deletions lib/userver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,24 @@ export class uServer extends BaseServer {
});
}

override _applyMiddlewares(req: any, res: any, callback: () => void): void {
override _applyMiddlewares(
req: any,
res: any,
callback: (err?: any) => void
): void {
if (this.middlewares.length === 0) {
return callback();
}

// needed to buffer headers until the status is computed
req.res = new ResponseWrapper(res);

super._applyMiddlewares(req, req.res, () => {
super._applyMiddlewares(req, req.res, (err) => {
// some middlewares (like express-session) wait for the writeHead() call to flush their headers
// see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244
req.res.writeHead();

callback();
callback(err);
});
}

Expand All @@ -118,28 +122,34 @@ export class uServer extends BaseServer {

req.res = res;

this._applyMiddlewares(req, res, () => {
this.verify(req, false, (errorCode, errorContext) => {
if (errorCode !== undefined) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
this.abortRequest(req.res, errorCode, errorContext);
return;
}
const callback = (errorCode, errorContext) => {
if (errorCode !== undefined) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
this.abortRequest(req.res, errorCode, errorContext);
return;
}

if (req._query.sid) {
debug("setting new request for existing client");
this.clients[req._query.sid].transport.onRequest(req);
} else {
const closeConnection = (errorCode, errorContext) =>
this.abortRequest(res, errorCode, errorContext);
this.handshake(req._query.transport, req, closeConnection);
}
};

if (req._query.sid) {
debug("setting new request for existing client");
this.clients[req._query.sid].transport.onRequest(req);
} else {
const closeConnection = (errorCode, errorContext) =>
this.abortRequest(res, errorCode, errorContext);
this.handshake(req._query.transport, req, closeConnection);
}
});
this._applyMiddlewares(req, res, (err) => {
if (err) {
callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" });
} else {
this.verify(req, false, callback);
}
});
}

Expand All @@ -154,63 +164,69 @@ export class uServer extends BaseServer {

req.res = res;

this._applyMiddlewares(req, res, () => {
this.verify(req, true, async (errorCode, errorContext) => {
if (errorCode) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
this.abortRequest(res, errorCode, errorContext);
const callback = async (errorCode, errorContext) => {
if (errorCode) {
this.emit("connection_error", {
req,
code: errorCode,
message: Server.errorMessages[errorCode],
context: errorContext,
});
this.abortRequest(res, errorCode, errorContext);
return;
}

const id = req._query.sid;
let transport;

if (id) {
const client = this.clients[id];
if (!client) {
debug("upgrade attempt for closed client");
res.close();
} else if (client.upgrading) {
debug("transport has already been trying to upgrade");
res.close();
} else if (client.upgraded) {
debug("transport had already been upgraded");
res.close();
} else {
debug("upgrading existing transport");
transport = this.createTransport(req._query.transport, req);
client.maybeUpgrade(transport);
}
} else {
transport = await this.handshake(
req._query.transport,
req,
(errorCode, errorContext) =>
this.abortRequest(res, errorCode, errorContext)
);
if (!transport) {
return;
}
}

const id = req._query.sid;
let transport;

if (id) {
const client = this.clients[id];
if (!client) {
debug("upgrade attempt for closed client");
res.close();
} else if (client.upgrading) {
debug("transport has already been trying to upgrade");
res.close();
} else if (client.upgraded) {
debug("transport had already been upgraded");
res.close();
} else {
debug("upgrading existing transport");
transport = this.createTransport(req._query.transport, req);
client.maybeUpgrade(transport);
}
} else {
transport = await this.handshake(
req._query.transport,
req,
(errorCode, errorContext) =>
this.abortRequest(res, errorCode, errorContext)
);
if (!transport) {
return;
}
}
// calling writeStatus() triggers the flushing of any header added in a middleware
req.res.writeStatus("101 Switching Protocols");

// calling writeStatus() triggers the flushing of any header added in a middleware
req.res.writeStatus("101 Switching Protocols");

res.upgrade(
{
transport,
},
req.getHeader("sec-websocket-key"),
req.getHeader("sec-websocket-protocol"),
req.getHeader("sec-websocket-extensions"),
context
);
});
res.upgrade(
{
transport,
},
req.getHeader("sec-websocket-key"),
req.getHeader("sec-websocket-protocol"),
req.getHeader("sec-websocket-extensions"),
context
);
};

this._applyMiddlewares(req, res, (err) => {
if (err) {
callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" });
} else {
this.verify(req, true, callback);
}
});
}

Expand Down
44 changes: 44 additions & 0 deletions test/middlewares.js
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,48 @@ describe("middlewares", () => {
});
});
});

it("should fail on errors (polling)", (done) => {
const engine = listen((port) => {
engine.use((req, res, next) => {
next(new Error("will always fail"));
});

request
.get(`http://localhost:${port}/engine.io/`)
.query({ EIO: 4, transport: "polling" })
.end((err, res) => {
expect(err).to.be.an(Error);
expect(res.status).to.eql(400);

if (engine.httpServer) {
engine.httpServer.close();
}
done();
});
});

it("should fail on errors (websocket)", (done) => {
const engine = listen((port) => {
engine.use((req, res, next) => {
next(new Error("will always fail"));
});

engine.on("connection", () => {
done(new Error("should not connect"));
});

const socket = new WebSocket(
`ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`
);

socket.addEventListener("error", () => {
if (engine.httpServer) {
engine.httpServer.close();
}
done();
});
});
});
});
});

0 comments on commit 9395782

Please sign in to comment.