diff --git a/src/engines/http.ts b/src/engines/http.ts index c611c121..d3aa54f5 100644 --- a/src/engines/http.ts +++ b/src/engines/http.ts @@ -12,6 +12,18 @@ import { type EngineEvents, } from "./abstract"; +const ALWAYS_ALLOW = new Set([ + "signin", + "signup", + "authenticate", + "invalidate", + "version", + "use", + "let", + "unset", + "query", +]); + export class HttpEngine extends AbstractEngine { connection: { url: URL | undefined; @@ -67,10 +79,18 @@ export class HttpEngine extends AbstractEngine { Result, >(request: RpcRequest): Promise> { await this.ready; + if (!this.connection.url) { throw new ConnectionUnavailable(); } + if ( + (!this.connection.namespace || !this.connection.database) && + !ALWAYS_ALLOW.has(request.method) + ) { + throw new MissingNamespaceDatabase(); + } + if (request.method === "use") { const [ns, db] = request.params as [ string | null | undefined, @@ -112,22 +132,27 @@ export class HttpEngine extends AbstractEngine { ] as Params; } - if (!this.connection.namespace || !this.connection.database) { - throw new MissingNamespaceDatabase(); + const id = getIncrementalID(); + const headers: Record = { + "Content-Type": "application/cbor", + Accept: "application/cbor", + }; + + if (this.connection.namespace) { + headers["Surreal-NS"] = this.connection.namespace; + } + + if (this.connection.database) { + headers["Surreal-DB"] = this.connection.database; + } + + if (this.connection.token) { + headers.Authorization = `Bearer ${this.connection.token}`; } - const id = getIncrementalID(); const raw = await fetch(`${this.connection.url}`, { method: "POST", - headers: { - "Content-Type": "application/cbor", - Accept: "application/cbor", - "Surreal-NS": this.connection.namespace, - "Surreal-DB": this.connection.database, - ...(this.connection.token - ? { Authorization: `Bearer ${this.connection.token}` } - : {}), - }, + headers, body: this.encodeCbor({ id, ...request }), }); diff --git a/src/errors.ts b/src/errors.ts index 53e77b03..9abe316e 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -64,7 +64,7 @@ export class ConnectionUnavailable extends SurrealDbError { export class MissingNamespaceDatabase extends SurrealDbError { name = "MissingNamespaceDatabase"; - message = "There are no namespace and/or database configured."; + message = "There is no namespace and/or database selected."; } export class HttpConnectionError extends SurrealDbError { diff --git a/tests/integration/surreal.ts b/tests/integration/surreal.ts index c6bee0d7..69e43416 100644 --- a/tests/integration/surreal.ts +++ b/tests/integration/surreal.ts @@ -38,6 +38,7 @@ type CreateSurrealOptions = { protocol?: Protocol; auth?: PremadeAuth; reachable?: boolean; + unselected?: boolean; }; export async function setupServer(): Promise<{ @@ -61,13 +62,14 @@ export async function setupServer(): Promise<{ protocol, auth, reachable, + unselected, }: CreateSurrealOptions = {}) { protocol = protocol ? protocol : PROTOCOL; const surreal = new Surreal(); const port = reachable === false ? SURREAL_PORT_UNREACHABLE : SURREAL_PORT; await surreal.connect(`${protocol}://127.0.0.1:${port}/rpc`, { - namespace: SURREAL_NS, - database: SURREAL_DB, + namespace: unselected ? undefined : SURREAL_NS, + database: unselected ? undefined : SURREAL_DB, auth: createAuth(auth ?? "root"), }); diff --git a/tests/integration/tests/connection.test.ts b/tests/integration/tests/connection.test.ts index b3d38bd0..899e8717 100644 --- a/tests/integration/tests/connection.test.ts +++ b/tests/integration/tests/connection.test.ts @@ -25,3 +25,26 @@ describe("version check", async () => { expect(diff).toBeLessThanOrEqual(defaultVersionCheckTimeout + 100); // 100ms margin }); }); + +describe("rpc", async () => { + test("allowed rpcs without namespace or database", async () => { + const surreal = await createSurreal({ + unselected: true, + protocol: "http", + }); + + await surreal.version(); + await surreal.invalidate(); + }); + + test("disallowed rpcs without namespace or database", async () => { + const surreal = await createSurreal({ + unselected: true, + protocol: "http", + }); + + expect(async () => { + await surreal.query("SELECT * FROM test"); + }).toThrow(); + }); +});