From ef7d395003c39bce512dbeb10adc6cd0c19fbc26 Mon Sep 17 00:00:00 2001 From: Lucian Buzzo Date: Wed, 25 Jan 2023 09:38:20 +0000 Subject: [PATCH] fix: Increase test coverage and improve setup safety This change introduces the use of advisory locks on startup to ensure safety if multiple clients are bootstrapping simultaneously. Without the lock you run the risk of concurrent updates to the same tables (mostly when using `GRANT`). Signed-off-by: Lucian Buzzo --- .../migration.sql | 25 +++ prisma/schema.prisma | 9 +- src/index.ts | 78 ++++---- test/integration/index.spec.ts | 29 ++- test/integration/middleware.spec.ts | 180 +++++++++--------- test/integration/rbac.spec.ts | 148 +++++++++++--- 6 files changed, 320 insertions(+), 149 deletions(-) create mode 100644 prisma/migrations/20230125135410_add_tags_model/migration.sql diff --git a/prisma/migrations/20230125135410_add_tags_model/migration.sql b/prisma/migrations/20230125135410_add_tags_model/migration.sql new file mode 100644 index 0000000..6f08252 --- /dev/null +++ b/prisma/migrations/20230125135410_add_tags_model/migration.sql @@ -0,0 +1,25 @@ +-- CreateTable +CREATE TABLE "Tag" ( + "id" SERIAL NOT NULL, + "label" TEXT NOT NULL, + + CONSTRAINT "Tag_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "_PostToTag" ( + "A" INTEGER NOT NULL, + "B" INTEGER NOT NULL +); + +-- CreateIndex +CREATE UNIQUE INDEX "_PostToTag_AB_unique" ON "_PostToTag"("A", "B"); + +-- CreateIndex +CREATE INDEX "_PostToTag_B_index" ON "_PostToTag"("B"); + +-- AddForeignKey +ALTER TABLE "_PostToTag" ADD CONSTRAINT "_PostToTag_A_fkey" FOREIGN KEY ("A") REFERENCES "Post"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "_PostToTag" ADD CONSTRAINT "_PostToTag_B_fkey" FOREIGN KEY ("B") REFERENCES "Tag"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index cb18893..203c55d 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -26,9 +26,16 @@ model Post { title String @db.VarChar(255) author User? @relation(fields: [authorId], references: [id]) authorId Int? + tags Tag[] +} + +model Tag { + id Int @id @default(autoincrement()) + label String + posts Post[] } enum Role { USER ADMIN -} \ No newline at end of file +} diff --git a/src/index.ts b/src/index.ts index 1237418..a4b17f2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -15,16 +15,25 @@ export type ModelAbilities = { [Model in Models]: { [op: string]: Ability } }; export type GetContextFn = () => { role: string; - context: { + context?: { [key: string]: string; }; } | null; +/** + * This function is used to take a lock that is automatically released at the end of the current transaction. + * This is very convenient for ensuring we don't hit concurrency issues when running setup code. + */ +const takeLock = (prisma: PrismaClient) => + prisma.$executeRawUnsafe("SELECT pg_advisory_xact_lock(2142616474639426746);"); + export const createAbilityName = (model: string, ability: string) => { return `${model}_${ability}_role`.toLowerCase(); }; export const createRoleName = (name: string) => { + // Esnure the role name only has lowercase alpha characters and underscores + // This also doubles as a check against SQL injection const normalized = name.toLowerCase().replace("-", "_").replace(/[^a-z_]/g, ""); return `yates_role_${normalized}`; }; @@ -52,12 +61,6 @@ export const setupMiddleware = (prisma: PrismaClient, getContext: GetContextFn) const pgRole = createRoleName(role); - // Check the role name only has lowercase alpha characters and underscores - // This also doubles as a check against SQL injection - if (pgRole.match(/[^a-z_]/)) { - throw new Error("Invalid role name."); - } - // Generate model class name from model params (PascalCase to camelCase) const modelName = params.model.charAt(0).toLowerCase() + params.model.slice(1); @@ -208,25 +211,33 @@ export const createRoles = async ({ for (const model in abilities) { const table = model; - await prisma.$queryRawUnsafe(`ALTER table "${table}" enable row level security`); + await prisma.$transaction([ + takeLock(prisma), + prisma.$queryRawUnsafe(`ALTER table "${table}" enable row level security;`), + ]); for (const slug in abilities[model as keyof typeof abilities]) { const ability = abilities[model as keyof typeof abilities]![slug]; const roleName = createAbilityName(model, slug); // Check if role already exists - await prisma.$queryRawUnsafe(` - do - $$ - begin - if not exists (select * from pg_catalog.pg_roles where rolname = '${roleName}') then - create role ${roleName}; + await prisma.$transaction([ + takeLock(prisma), + prisma.$queryRawUnsafe(` + do + $$ + begin + if not exists (select * from pg_catalog.pg_roles where rolname = '${roleName}') then + create role ${roleName}; + end if; + end + $$ + ; + `), + prisma.$queryRawUnsafe(` GRANT ${ability.operation} ON "${table}" TO ${roleName}; - end if; - end - $$ - ; - `); + `), + ]); if (ability.expression) { await setRLS(prisma, table, roleName, ability.operation, ability.expression); @@ -239,7 +250,7 @@ export const createRoles = async ({ // It's not possible to dynamically GRANT these to a shared user role, as the GRANT is not isolated per transaction and leads to broken permissions. for (const key in roles) { const role = createRoleName(key); - await prisma.$queryRawUnsafe(` + await prisma.$executeRawUnsafe(` do $$ begin @@ -251,18 +262,6 @@ export const createRoles = async ({ ; `); - // Note: We need to GRANT all on schema public so that we can resolve relation queries with prisma, as they will sometimes use a join table. - // This is not ideal, but because we are using RLS, it's not a security risk. Any table with RLS also needs a corresponding policy for the role to have access. - await prisma.$queryRawUnsafe(` - GRANT ALL ON ALL TABLES IN SCHEMA public TO ${role}; - `); - await prisma.$queryRawUnsafe(` - GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO ${role}; - `); - await prisma.$queryRawUnsafe(` - GRANT ALL ON SCHEMA public TO ${role}; - `); - const wildCardAbilities = flatMap(abilities, (model, modelName) => { return map(model, (params, slug) => { return createAbilityName(modelName, slug); @@ -273,7 +272,20 @@ export const createRoles = async ({ roleAbilities === "*" ? wildCardAbilities : roleAbilities.map((ability) => createAbilityName(ability.model!, ability.slug!)); - await prisma.$queryRawUnsafe(`GRANT ${rlsRoles.join(", ")} TO ${role}`); + + // Note: We need to GRANT all on schema public so that we can resolve relation queries with prisma, as they will sometimes use a join table. + // This is not ideal, but because we are using RLS, it's not a security risk. Any table with RLS also needs a corresponding policy for the role to have access. + await prisma.$transaction([ + takeLock(prisma), + prisma.$executeRawUnsafe(`GRANT ALL ON ALL TABLES IN SCHEMA public TO ${role};`), + prisma.$executeRawUnsafe(` + GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO ${role}; + `), + prisma.$executeRawUnsafe(` + GRANT ALL ON SCHEMA public TO ${role}; + `), + prisma.$queryRawUnsafe(`GRANT ${rlsRoles.join(", ")} TO ${role}`), + ]); } }; diff --git a/test/integration/index.spec.ts b/test/integration/index.spec.ts index 70bdd28..2679811 100644 --- a/test/integration/index.spec.ts +++ b/test/integration/index.spec.ts @@ -22,13 +22,40 @@ describe("setup", () => { expect(getRoles.mock.calls).toHaveLength(1); const abilities = getRoles.mock.calls[0][0]; - expect(Object.keys(abilities)).toStrictEqual(["User", "Post"]); + expect(Object.keys(abilities)).toStrictEqual(["User", "Post", "Tag"]); expect(Object.keys(abilities.User)).toStrictEqual(["create", "read", "update", "delete"]); expect(Object.keys(abilities.Post)).toStrictEqual(["create", "read", "update", "delete"]); + expect(Object.keys(abilities.Tag)).toStrictEqual(["create", "read", "update", "delete"]); }); }); describe("params.getContext()", () => { + it("should skip RBAC if .getContext() returns null", async () => { + const prisma = new PrismaClient(); + + const role = `USER_${uuid()}`; + + await setup({ + prisma, + getRoles(abilities) { + return { + [role]: [abilities.Post.read], + }; + }, + getContext: () => { + return null; + }, + }); + + const post = await prisma.post.create({ + data: { + title: "Test post", + }, + }); + + expect(post.id).toBeDefined(); + }); + it("should allow a custom context to be set", async () => { const prisma = new PrismaClient(); diff --git a/test/integration/middleware.spec.ts b/test/integration/middleware.spec.ts index f186c2d..5fe0d6e 100644 --- a/test/integration/middleware.spec.ts +++ b/test/integration/middleware.spec.ts @@ -6,98 +6,96 @@ import { setup } from "../../src"; let adminClient: PrismaClient; beforeAll(async () => { - adminClient = new PrismaClient(); + adminClient = new PrismaClient(); }); describe("middlewares", () => { - it("should not run twice", async () => { - const prisma = new PrismaClient(); - - const middlewareSpy = jest.fn(async (params, next) => { - return next(params); - }); - - prisma.$use(middlewareSpy); - - const role = `USER_${uuid()}`; - - await setup({ - prisma, - getRoles(abilities) { - return { - [role]: [abilities.Post.read, abilities.Post.create], - }; - }, - getContext: () => { - return { - role, - context: {}, - }; - }, - }); - - middlewareSpy.mockClear(); - - const post = await prisma.post.create({ - data: { - title: `Test post from ${role}`, - }, - }); - - expect(post.id).toBeDefined(); - expect(middlewareSpy).toHaveBeenCalledTimes(1); - }); - - it("should not be able to bypass RBAC when using cls-hooked", async () => { - const prisma = new PrismaClient(); - - const middleware: Prisma.Middleware = async (params, next) => { - if (params.model === "Post") { - const post = await next(params); - return post; - } else { - return next(params); - } - }; - - const clsSession = createNamespace("test"); - - prisma.$use(middleware); - - const roleName = `USER_${uuid()}`; - - await setup({ - prisma, - getRoles(abilities) { - return { - [roleName]: [abilities.Post.read], - }; - }, - getContext: () => { - const role = clsSession.get("role"); - return { - role, - context: {}, - }; - }, - }); - - await expect( - new Promise((res, reject) => { - clsSession.run(async () => { - try { - clsSession.set("role", roleName); - const result = await prisma.post.create({ - data: { - title: `Test post from ${roleName}`, - }, - }); - res(result); - } catch (e) { - reject(e); - } - }); - }), - ).rejects.toThrow(); - }); + it("should not run twice", async () => { + const prisma = new PrismaClient(); + + const middlewareSpy = jest.fn(async (params, next) => { + return next(params); + }); + + prisma.$use(middlewareSpy); + + const role = `USER_${uuid()}`; + + await setup({ + prisma, + getRoles(abilities) { + return { + [role]: [abilities.Post.read, abilities.Post.create], + }; + }, + getContext: () => { + return { + role, + }; + }, + }); + + middlewareSpy.mockClear(); + + const post = await prisma.post.create({ + data: { + title: `Test post from ${role}`, + }, + }); + + expect(post.id).toBeDefined(); + expect(middlewareSpy).toHaveBeenCalledTimes(1); + }); + + it("should not be able to bypass RBAC when using cls-hooked", async () => { + const prisma = new PrismaClient(); + + const middleware: Prisma.Middleware = async (params, next) => { + if (params.model === "Post") { + const post = await next(params); + return post; + } else { + return next(params); + } + }; + + const clsSession = createNamespace("test"); + + prisma.$use(middleware); + + const roleName = `USER_${uuid()}`; + + await setup({ + prisma, + getRoles(abilities) { + return { + [roleName]: [abilities.Post.read], + }; + }, + getContext: () => { + const role = clsSession.get("role"); + return { + role, + }; + }, + }); + + await expect( + new Promise((res, reject) => { + clsSession.run(async () => { + try { + clsSession.set("role", roleName); + const result = await prisma.post.create({ + data: { + title: `Test post from ${roleName}`, + }, + }); + res(result); + } catch (e) { + reject(e); + } + }); + }) + ).rejects.toThrow(); + }); }); diff --git a/test/integration/rbac.spec.ts b/test/integration/rbac.spec.ts index c9a66ec..bf49a60 100644 --- a/test/integration/rbac.spec.ts +++ b/test/integration/rbac.spec.ts @@ -1,4 +1,4 @@ -import { PrismaClient } from "@prisma/client"; +import { PrismaClient, User } from "@prisma/client"; import { setup } from "../../src"; import { v4 as uuid } from "uuid"; @@ -9,25 +9,52 @@ beforeAll(async () => { }); describe("rbac", () => { - describe("CREATE", () => { - it("should be able to allow a role to create a resource", async () => { + describe("raw", () => { + it("should skip RBAC when using prisma.$queryRaw()", async () => { const prisma = new PrismaClient(); + const user = await prisma.user.create({ + data: { + email: `test-${uuid()}@test.com`, + }, + }); + const role = `USER_${uuid()}`; await setup({ prisma, getRoles(abilities) { return { - [role]: [abilities.Post.read, abilities.Post.create], + [role]: [abilities.Post.read], }; }, - getContext: () => { + getContext: () => ({ + role, + }), + }); + + const users: User[] = await prisma.$queryRaw`SELECT * FROM "User" WHERE "id" = ${user.id}`; + + expect(users).toHaveLength(1); + expect(users[0].id).toBe(user.id); + }); + }); + describe("CREATE", () => { + it("should be able to allow a role to create a resource", async () => { + const prisma = new PrismaClient(); + + const role = `USER_${uuid()}`; + + await setup({ + prisma, + getRoles(abilities) { return { - role, - context: {}, + [role]: [abilities.Post.read, abilities.Post.create], }; }, + getContext: () => ({ + role, + }), }); const post = await prisma.post.create({ @@ -51,12 +78,9 @@ describe("rbac", () => { [role]: [abilities.Post.read], }; }, - getContext: () => { - return { - role, - context: {}, - }; - }, + getContext: () => ({ + role, + }), }); await expect( @@ -82,12 +106,9 @@ describe("rbac", () => { [role]: [abilities.Post.read], }; }, - getContext: () => { - return { - role, - context: {}, - }; - }, + getContext: () => ({ + role, + }), }); const { id: postId } = await adminClient.post.create({ @@ -115,25 +136,106 @@ describe("rbac", () => { [role]: [abilities.User.read], }; }, - getContext: () => { + getContext: () => ({ + role, + }), + }); + + const { id: postId } = await adminClient.post.create({ + data: { + title: `Test post from ${role}`, + }, + }); + + const post = await prisma.post.findUnique({ + where: { id: postId }, + }); + + expect(post).toBeNull(); + }); + + it("should be able to allow a role to read a resource using 1:1 relation queries", async () => { + const prisma = new PrismaClient(); + + const role = `USER_${uuid()}`; + + await setup({ + prisma, + getRoles(abilities) { + return { + [role]: [abilities.Post.read, abilities.User.read], + }; + }, + getContext: () => ({ + role, + }), + }); + + const { id: postId } = await adminClient.post.create({ + data: { + title: `Test post from ${role}`, + author: { + create: { + email: `test-${uuid()}@test.com`, + }, + }, + }, + }); + + const post = await prisma.post.findUnique({ + where: { id: postId }, + include: { + author: true, + }, + }); + + expect(post?.id).toBe(postId); + expect(post?.author).toBeDefined(); + }); + + it("should be able to allow a role to read a resource using many-to-many relation queries", async () => { + const prisma = new PrismaClient(); + + const role = `USER_${uuid()}`; + + await setup({ + prisma, + getRoles(abilities) { return { - role, - context: {}, + [role]: [abilities.Post.read, abilities.User.read, abilities.Tag.read], }; }, + getContext: () => ({ + role, + }), }); const { id: postId } = await adminClient.post.create({ data: { title: `Test post from ${role}`, + tags: { + create: { + label: "engineering", + }, + }, }, }); const post = await prisma.post.findUnique({ where: { id: postId }, + select: { + id: true, + title: true, + tags: { + select: { + label: true, + }, + }, + }, }); - expect(post).toBeNull(); + expect(post?.id).toBe(postId); + expect(post?.tags).toBeDefined(); }); }); });