Skip to content

Commit

Permalink
feat: add support for WebTransport
Browse files Browse the repository at this point in the history
  • Loading branch information
darrachequesne committed Jun 11, 2023
1 parent 3144d27 commit 123b68c
Show file tree
Hide file tree
Showing 11 changed files with 1,853 additions and 35 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ jobs:
strategy:
matrix:
node-version:
- 10
- 18

steps:
Expand Down
108 changes: 105 additions & 3 deletions lib/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ import type {
import type { CookieSerializeOptions } from "cookie";
import type { CorsOptions, CorsOptionsDelegate } from "cors";
import type { Duplex } from "stream";
import { WebTransport } from "./transports/webtransport";

const debug = debugModule("engine");

const kResponseHeaders = Symbol("responseHeaders");
const TEXT_DECODER = new TextDecoder();

type Transport = "polling" | "websocket";

Expand Down Expand Up @@ -78,7 +80,13 @@ export interface ServerOptions {
fn: (err: string | null | undefined, success: boolean) => void
) => void;
/**
* the low-level transports that are enabled
* The low-level transports that are enabled. WebTransport is disabled by default and must be manually enabled:
*
* @example
* new Server({
* transports: ["polling", "websocket", "webtransport"]
* });
*
* @default ["polling", "websocket"]
*/
transports?: Transport[];
Expand Down Expand Up @@ -140,6 +148,17 @@ type Middleware = (
next: (err?: any) => void
) => void;

function parseSessionId(handshake: string) {
if (handshake.startsWith("0{")) {
try {
const parsed = JSON.parse(handshake.substring(1));
if (typeof parsed.sid === "string") {
return parsed.sid;
}
} catch (e) {}
}
}

export abstract class BaseServer extends EventEmitter {
public opts: ServerOptions;

Expand All @@ -166,7 +185,7 @@ export abstract class BaseServer extends EventEmitter {
pingInterval: 25000,
upgradeTimeout: 10000,
maxHttpBufferSize: 1e6,
transports: Object.keys(transports),
transports: ["polling", "websocket"], // WebTransport is disabled by default
allowUpgrades: true,
httpCompression: {
threshold: 1024,
Expand Down Expand Up @@ -245,7 +264,11 @@ export abstract class BaseServer extends EventEmitter {
protected verify(req, upgrade, fn) {
// transport check
const transport = req._query.transport;
if (!~this.opts.transports.indexOf(transport)) {
// WebTransport does not go through the verify() method, see the onWebTransportSession() method
if (
!~this.opts.transports.indexOf(transport) ||
transport === "webtransport"
) {
debug('unknown transport "%s"', transport);
return fn(Server.errors.UNKNOWN_TRANSPORT, { transport });
}
Expand Down Expand Up @@ -495,6 +518,85 @@ export abstract class BaseServer extends EventEmitter {
return transport;
}

public async onWebTransportSession(session: any) {
const timeout = setTimeout(() => {
debug(
"the client failed to establish a bidirectional stream in the given period"
);
session.close();
}, this.opts.upgradeTimeout);

const streamReader = session.incomingBidirectionalStreams.getReader();
const result = await streamReader.read();

if (result.done) {
debug("session is closed");
return;
}

const stream = result.value;
const reader = stream.readable.getReader();

// reading the first packet of the stream
const { value, done } = await reader.read();
if (done) {
debug("stream is closed");
return;
}

clearTimeout(timeout);
const handshake = TEXT_DECODER.decode(value);

// handshake is either
// "0" => new session
// '0{"sid":"xxxx"}' => upgrade
if (handshake === "0") {
const transport = new WebTransport(session, stream, reader);

// note: we cannot use "this.generateId()", because there is no "req" argument
const id = base64id.generateId();
debug('handshaking client "%s" (WebTransport)', id);

const socket = new Socket(id, this, transport, null, 4);

this.clients[id] = socket;
this.clientsCount++;

socket.once("close", () => {
delete this.clients[id];
this.clientsCount--;
});

this.emit("connection", socket);
return;
}

const sid = parseSessionId(handshake);

if (!sid) {
debug("invalid WebTransport handshake");
return session.close();
}

const client = this.clients[sid];

if (!client) {
debug("upgrade attempt for closed client");
session.close();
} else if (client.upgrading) {
debug("transport has already been trying to upgrade");
session.close();
} else if (client.upgraded) {
debug("transport had already been upgraded");
session.close();
} else {
debug("upgrading existing transport");

const transport = new WebTransport(session, stream, reader);
client.maybeUpgrade(transport);
}
}

protected abstract createTransport(transportName, req);

/**
Expand Down
11 changes: 8 additions & 3 deletions lib/socket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,15 @@ export class Socket extends EventEmitter {
this.protocol = protocol;

// Cache IP since it might not be in the req later
if (req.websocket && req.websocket._socket) {
this.remoteAddress = req.websocket._socket.remoteAddress;
if (req) {
if (req.websocket && req.websocket._socket) {
this.remoteAddress = req.websocket._socket.remoteAddress;
} else {
this.remoteAddress = req.connection.remoteAddress;
}
} else {
this.remoteAddress = req.connection.remoteAddress;
// TODO there is currently no way to get the IP address of the client when it connects with WebTransport
// see https://github.com/fails-components/webtransport/issues/114
}

this.checkIntervalTimer = null;
Expand Down
18 changes: 18 additions & 0 deletions lib/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,26 @@ export abstract class Transport extends EventEmitter {
this.emit("close");
}

/**
* Advertise framing support.
*/
abstract get supportsFraming();

/**
* The name of the transport.
*/
abstract get name();

/**
* Sends an array of packets.
*
* @param {Array} packets
* @package
*/
abstract send(packets);

/**
* Closes the transport.
*/
abstract doClose(fn?);
}
4 changes: 3 additions & 1 deletion lib/transports/index.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { Polling as XHR } from "./polling";
import { JSONP } from "./polling-jsonp";
import { WebSocket } from "./websocket";
import { WebTransport } from "./webtransport";

export default {
polling: polling,
websocket: WebSocket,
webtransport: WebTransport,
};

/**
Expand All @@ -21,4 +23,4 @@ function polling(req) {
}
}

polling.upgradesTo = ["websocket"];
polling.upgradesTo = ["websocket", "webtransport"];
88 changes: 88 additions & 0 deletions lib/transports/webtransport.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import { Transport } from "../transport";
import debugModule from "debug";

const debug = debugModule("engine:webtransport");

const BINARY_HEADER = Buffer.of(54);

function shouldIncludeBinaryHeader(packet, encoded) {
// 48 === "0".charCodeAt(0) (OPEN packet type)
// 54 === "6".charCodeAt(0) (NOOP packet type)
return (
packet.type === "message" &&
typeof packet.data !== "string" &&
encoded[0] >= 48 &&
encoded[0] <= 54
);
}

/**
* Reference: https://developer.mozilla.org/en-US/docs/Web/API/WebTransport_API
*/
export class WebTransport extends Transport {
private readonly writer;

constructor(private readonly session, stream, reader) {
super({ _query: { EIO: "4" } });
this.writer = stream.writable.getWriter();
(async () => {
let binaryFlag = false;
while (true) {
const { value, done } = await reader.read();
if (done) {
debug("session is closed");
break;
}
debug("received chunk: %o", value);
if (!binaryFlag && value.byteLength === 1 && value[0] === 54) {
binaryFlag = true;
continue;
}
this.onPacket(
this.parser.decodePacketFromBinary(value, binaryFlag, "nodebuffer")
);
binaryFlag = false;
}
})();

session.closed.then(() => this.onClose());

this.writable = true;
}

get name() {
return "webtransport";
}

get supportsFraming() {
return true;
}

send(packets) {
this.writable = false;

for (let i = 0; i < packets.length; i++) {
const packet = packets[i];
const isLast = i + 1 === packets.length;

this.parser.encodePacketToBinary(packet, (data) => {
if (shouldIncludeBinaryHeader(packet, data)) {
debug("writing binary header");
this.writer.write(BINARY_HEADER);
}
debug("writing chunk: %o", data);
this.writer.write(data);
if (isLast) {
this.writable = true;
this.emit("drain");
}
});
}
}

doClose(fn) {
debug("closing WebTransport session");
this.session.close();
fn && fn();
}
}
Loading

0 comments on commit 123b68c

Please sign in to comment.