Skip to content

Commit

Permalink
Improve rpc method validation (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
macjuul authored Sep 4, 2024
1 parent fba382a commit 48c648f
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 15 deletions.
49 changes: 37 additions & 12 deletions src/engines/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@ import {
type EngineEvents,
} from "./abstract";

const ALWAYS_ALLOW = new Set([
"signin",
"signup",
"authenticate",
"invalidate",
"version",
"use",
"let",
"unset",
"query",
]);

export class HttpEngine extends AbstractEngine {
connection: {
url: URL | undefined;
Expand Down Expand Up @@ -67,10 +79,18 @@ export class HttpEngine extends AbstractEngine {
Result,
>(request: RpcRequest<Method, Params>): Promise<RpcResponse<Result>> {
await this.ready;

if (!this.connection.url) {
throw new ConnectionUnavailable();
}

if (
(!this.connection.namespace || !this.connection.database) &&
!ALWAYS_ALLOW.has(request.method)
) {
throw new MissingNamespaceDatabase();
}

if (request.method === "use") {
const [ns, db] = request.params as [
string | null | undefined,
Expand Down Expand Up @@ -112,22 +132,27 @@ export class HttpEngine extends AbstractEngine {
] as Params;
}

if (!this.connection.namespace || !this.connection.database) {
throw new MissingNamespaceDatabase();
const id = getIncrementalID();
const headers: Record<string, string> = {
"Content-Type": "application/cbor",
Accept: "application/cbor",
};

if (this.connection.namespace) {
headers["Surreal-NS"] = this.connection.namespace;
}

if (this.connection.database) {
headers["Surreal-DB"] = this.connection.database;
}

if (this.connection.token) {
headers.Authorization = `Bearer ${this.connection.token}`;
}

const id = getIncrementalID();
const raw = await fetch(`${this.connection.url}`, {
method: "POST",
headers: {
"Content-Type": "application/cbor",
Accept: "application/cbor",
"Surreal-NS": this.connection.namespace,
"Surreal-DB": this.connection.database,
...(this.connection.token
? { Authorization: `Bearer ${this.connection.token}` }
: {}),
},
headers,
body: this.encodeCbor({ id, ...request }),
});

Expand Down
2 changes: 1 addition & 1 deletion src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ export class ConnectionUnavailable extends SurrealDbError {

export class MissingNamespaceDatabase extends SurrealDbError {
name = "MissingNamespaceDatabase";
message = "There are no namespace and/or database configured.";
message = "There is no namespace and/or database selected.";
}

export class HttpConnectionError extends SurrealDbError {
Expand Down
6 changes: 4 additions & 2 deletions tests/integration/surreal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type CreateSurrealOptions = {
protocol?: Protocol;
auth?: PremadeAuth;
reachable?: boolean;
unselected?: boolean;
};

export async function setupServer(): Promise<{
Expand All @@ -61,13 +62,14 @@ export async function setupServer(): Promise<{
protocol,
auth,
reachable,
unselected,
}: CreateSurrealOptions = {}) {
protocol = protocol ? protocol : PROTOCOL;
const surreal = new Surreal();
const port = reachable === false ? SURREAL_PORT_UNREACHABLE : SURREAL_PORT;
await surreal.connect(`${protocol}://127.0.0.1:${port}/rpc`, {
namespace: SURREAL_NS,
database: SURREAL_DB,
namespace: unselected ? undefined : SURREAL_NS,
database: unselected ? undefined : SURREAL_DB,
auth: createAuth(auth ?? "root"),
});

Expand Down
23 changes: 23 additions & 0 deletions tests/integration/tests/connection.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,26 @@ describe("version check", async () => {
expect(diff).toBeLessThanOrEqual(defaultVersionCheckTimeout + 100); // 100ms margin
});
});

describe("rpc", async () => {
test("allowed rpcs without namespace or database", async () => {
const surreal = await createSurreal({
unselected: true,
protocol: "http",
});

await surreal.version();
await surreal.invalidate();
});

test("disallowed rpcs without namespace or database", async () => {
const surreal = await createSurreal({
unselected: true,
protocol: "http",
});

expect(async () => {
await surreal.query("SELECT * FROM test");
}).toThrow();
});
});

0 comments on commit 48c648f

Please sign in to comment.