From 8486f6447bfb89279118998bd95a7156c244cb72 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 30 Dec 2024 15:22:48 +0800 Subject: [PATCH] feat: improvements to "encryption" enhancement (#1927) --- .../enhancements/node/create-enhancement.ts | 4 +- .../src/enhancements/node/encrypted.ts | 21 ++- packages/runtime/src/types.ts | 2 +- .../attribute-application-validator.ts | 16 +- .../with-encrypted/with-encrypted.test.ts | 164 +++++++++++++++++- 5 files changed, 186 insertions(+), 21 deletions(-) diff --git a/packages/runtime/src/enhancements/node/create-enhancement.ts b/packages/runtime/src/enhancements/node/create-enhancement.ts index 871f8a1b4..6090f523f 100644 --- a/packages/runtime/src/enhancements/node/create-enhancement.ts +++ b/packages/runtime/src/enhancements/node/create-enhancement.ts @@ -21,7 +21,7 @@ import type { PolicyDef } from './types'; /** * All enhancement kinds */ -const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encrypted']; +const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encryption']; /** * Options for {@link createEnhancement} @@ -129,7 +129,7 @@ export function createEnhancement( result = withPassword(result, options); } - if (hasEncrypted && kinds.includes('encrypted')) { + if (hasEncrypted && kinds.includes('encryption')) { if (!options.encryption) { throw new Error('Encryption options are required for @encrypted enhancement'); } diff --git a/packages/runtime/src/enhancements/node/encrypted.ts b/packages/runtime/src/enhancements/node/encrypted.ts index c6d6fc873..d5db66690 100644 --- a/packages/runtime/src/enhancements/node/encrypted.ts +++ b/packages/runtime/src/enhancements/node/encrypted.ts @@ -9,8 +9,9 @@ import { resolveField, type PrismaWriteActionType, } from '../../cross'; -import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types'; +import { CustomEncryption, DbClientContract, SimpleEncryption } from '../../types'; import { InternalEnhancementOptions } from './create-enhancement'; +import { Logger } from './logger'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; import { QueryUtils } from './query-utils'; @@ -27,7 +28,7 @@ export function withEncrypted( prisma, options.modelMeta, (_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options), - 'encrypted' + 'encryption' ); } @@ -35,20 +36,24 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { private queryUtils: QueryUtils; private encoder = new TextEncoder(); private decoder = new TextDecoder(); + private logger: Logger; constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { super(prisma, model, options); this.queryUtils = new QueryUtils(prisma, options); + this.logger = new Logger(prisma); - if (!options.encryption) throw new Error('Encryption options must be provided'); + if (!options.encryption) throw this.queryUtils.unknownError('Encryption options must be provided'); if (this.isCustomEncryption(options.encryption!)) { if (!options.encryption.encrypt || !options.encryption.decrypt) - throw new Error('Custom encryption must provide encrypt and decrypt functions'); + throw this.queryUtils.unknownError('Custom encryption must provide encrypt and decrypt functions'); } else { - if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided'); - if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes'); + if (!options.encryption.encryptionKey) + throw this.queryUtils.unknownError('Encryption key must be provided'); + if (options.encryption.encryptionKey.length !== 32) + throw this.queryUtils.unknownError('Encryption key must be 32 bytes'); } } @@ -147,7 +152,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { try { entityData[field] = await this.decrypt(fieldInfo, entityData[field]); } catch (error) { - console.warn('Decryption failed, keeping original value:', error); + this.logger.warn(`Decryption failed, keeping original value: ${error}`); } } } @@ -164,7 +169,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { try { context.parent[field.name] = await this.encrypt(field, data); } catch (error) { - throw new Error(`Encryption failed for field ${field.name}: ${error}`); + this.queryUtils.unknownError(`Encryption failed for field ${field.name}: ${error}`); } } }, diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index e691fc32c..012c94699 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -151,7 +151,7 @@ export type EnhancementContext = { /** * Kinds of enhancements to `PrismaClient` */ -export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted'; +export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encryption'; /** * Function for transforming errors. diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts index a7c0fef9a..0e1d8e885 100644 --- a/packages/schema/src/language-server/validator/attribute-application-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -25,7 +25,7 @@ import { isRelationshipField, resolved, } from '@zenstackhq/sdk'; -import { ValidationAcceptor, streamAst } from 'langium'; +import { ValidationAcceptor, streamAllContents, streamAst } from 'langium'; import pluralize from 'pluralize'; import { AstValidator } from '../types'; import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils'; @@ -138,6 +138,9 @@ export default class AttributeApplicationValidator implements AstValidator { + if (isDataModelFieldReference(node) && hasAttribute(node.target.ref as DataModelField, '@encrypted')) { + accept('error', `Encrypted fields cannot be used in policy rules`, { node }); + } + }); + } + private validatePolicyKinds( kind: string, candidates: string[], diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts index 1e0544c0b..9b6307822 100644 --- a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -1,9 +1,10 @@ import { FieldInfo } from '@zenstackhq/runtime'; -import { loadSchema } from '@zenstackhq/testtools'; +import { loadSchema, loadModelWithError } from '@zenstackhq/testtools'; import path from 'path'; describe('Encrypted test', () => { let origDir: string; + const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64')); beforeAll(async () => { origDir = path.resolve('.'); @@ -14,21 +15,25 @@ describe('Encrypted test', () => { }); it('Simple encryption test', async () => { - const { enhance } = await loadSchema(` + const { enhance, prisma } = await loadSchema( + ` model User { id String @id @default(cuid()) encrypted_value String @encrypted() @@allow('all', true) - }`); + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); const sudoDb = enhance(undefined, { kinds: [] }); - const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64')); - const db = enhance(undefined, { - kinds: ['encrypted'], - encryption: { encryptionKey }, - }); + const db = enhance(); const create = await db.user.create({ data: { @@ -49,9 +54,50 @@ describe('Encrypted test', () => { }, }); + const rawRead = await prisma.user.findUnique({ where: { id: '1' } }); + expect(create.encrypted_value).toBe('abc123'); expect(read.encrypted_value).toBe('abc123'); expect(sudoRead.encrypted_value).not.toBe('abc123'); + expect(rawRead.encrypted_value).not.toBe('abc123'); + }); + + it('Multi-field encryption test', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + x1 String @encrypted() + x2 String @encrypted() + + @@allow('all', true) + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + x1: 'abc123', + x2: '123abc', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + expect(create).toMatchObject({ x1: 'abc123', x2: '123abc' }); + expect(read).toMatchObject({ x1: 'abc123', x2: '123abc' }); }); it('Custom encryption test', async () => { @@ -65,7 +111,7 @@ describe('Encrypted test', () => { const sudoDb = enhance(undefined, { kinds: [] }); const db = enhance(undefined, { - kinds: ['encrypted'], + kinds: ['encryption'], encryption: { encrypt: async (model: string, field: FieldInfo, data: string) => { // Add _enc to the end of the input @@ -105,4 +151,104 @@ describe('Encrypted test', () => { expect(read.encrypted_value).toBe('abc123'); expect(sudoRead.encrypted_value).toBe('abc123_enc'); }); + + it('Only supports string fields', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id @default(cuid()) + encrypted_value Bytes @encrypted() + }` + ) + ).resolves.toContain(`attribute \"@encrypted\" cannot be used on this type of field`); + }); + + it('Returns cipher text when decryption fails', async () => { + const { enhance, enhanceRaw, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + + @@allow('all', true) + }`, + { enhancements: ['encryption'] } + ); + + const db = enhance(undefined, { + kinds: ['encryption'], + encryption: { encryptionKey }, + }); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + expect(create.encrypted_value).toBe('abc123'); + + const db1 = enhanceRaw(prisma, undefined, { + encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)) }, + }); + const read = await db1.user.findUnique({ where: { id: '1' } }); + expect(read.encrypted_value).toBeTruthy(); + expect(read.encrypted_value).not.toBe('abc123'); + }); + + it('Works with length validation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() @length(0, 6) + + @@allow('all', true) + }`, + { + enhanceOptions: { encryption: { encryptionKey } }, + } + ); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + expect(create.encrypted_value).toBe('abc123'); + + await expect( + db.user.create({ + data: { id: '2', encrypted_value: 'abc1234' }, + }) + ).toBeRejectedByPolicy(); + }); + + it('Complains when encrypted fields are used in model-level policy rules', async () => { + await expect( + loadModelWithError(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + @@allow('all', encrypted_value != 'abc123') + } + `) + ).resolves.toContain(`Encrypted fields cannot be used in policy rules`); + }); + + it('Complains when encrypted fields are used in field-level policy rules', async () => { + await expect( + loadModelWithError(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + value Int @allow('all', encrypted_value != 'abc123') + } + `) + ).resolves.toContain(`Encrypted fields cannot be used in policy rules`); + }); });