Skip to content

Commit

Permalink
feat: improvements to "encryption" enhancement (#1927)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Dec 30, 2024
1 parent dcef942 commit 8486f64
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 21 deletions.
4 changes: 2 additions & 2 deletions packages/runtime/src/enhancements/node/create-enhancement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import type { PolicyDef } from './types';
/**
* All enhancement kinds
*/
const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encrypted'];
const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encryption'];

/**
* Options for {@link createEnhancement}
Expand Down Expand Up @@ -129,7 +129,7 @@ export function createEnhancement<DbClient extends object>(
result = withPassword(result, options);
}

if (hasEncrypted && kinds.includes('encrypted')) {
if (hasEncrypted && kinds.includes('encryption')) {
if (!options.encryption) {
throw new Error('Encryption options are required for @encrypted enhancement');
}
Expand Down
21 changes: 13 additions & 8 deletions packages/runtime/src/enhancements/node/encrypted.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import {
resolveField,
type PrismaWriteActionType,
} from '../../cross';
import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types';
import { CustomEncryption, DbClientContract, SimpleEncryption } from '../../types';
import { InternalEnhancementOptions } from './create-enhancement';
import { Logger } from './logger';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
import { QueryUtils } from './query-utils';

Expand All @@ -27,28 +28,32 @@ export function withEncrypted<DbClient extends object = any>(
prisma,
options.modelMeta,
(_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options),
'encrypted'
'encryption'
);
}

class EncryptedHandler extends DefaultPrismaProxyHandler {
private queryUtils: QueryUtils;
private encoder = new TextEncoder();
private decoder = new TextDecoder();
private logger: Logger;

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
super(prisma, model, options);

this.queryUtils = new QueryUtils(prisma, options);
this.logger = new Logger(prisma);

if (!options.encryption) throw new Error('Encryption options must be provided');
if (!options.encryption) throw this.queryUtils.unknownError('Encryption options must be provided');

if (this.isCustomEncryption(options.encryption!)) {
if (!options.encryption.encrypt || !options.encryption.decrypt)
throw new Error('Custom encryption must provide encrypt and decrypt functions');
throw this.queryUtils.unknownError('Custom encryption must provide encrypt and decrypt functions');
} else {
if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided');
if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes');
if (!options.encryption.encryptionKey)
throw this.queryUtils.unknownError('Encryption key must be provided');
if (options.encryption.encryptionKey.length !== 32)
throw this.queryUtils.unknownError('Encryption key must be 32 bytes');
}
}

Expand Down Expand Up @@ -147,7 +152,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
try {
entityData[field] = await this.decrypt(fieldInfo, entityData[field]);
} catch (error) {
console.warn('Decryption failed, keeping original value:', error);
this.logger.warn(`Decryption failed, keeping original value: ${error}`);
}
}
}
Expand All @@ -164,7 +169,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
try {
context.parent[field.name] = await this.encrypt(field, data);
} catch (error) {
throw new Error(`Encryption failed for field ${field.name}: ${error}`);
this.queryUtils.unknownError(`Encryption failed for field ${field.name}: ${error}`);
}
}
},
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ export type EnhancementContext<User extends AuthUser = AuthUser> = {
/**
* Kinds of enhancements to `PrismaClient`
*/
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted';
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encryption';

/**
* Function for transforming errors.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {
isRelationshipField,
resolved,
} from '@zenstackhq/sdk';
import { ValidationAcceptor, streamAst } from 'langium';
import { ValidationAcceptor, streamAllContents, streamAst } from 'langium';
import pluralize from 'pluralize';
import { AstValidator } from '../types';
import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils';
Expand Down Expand Up @@ -138,6 +138,9 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
return;
}
this.validatePolicyKinds(kind, ['create', 'read', 'update', 'delete', 'all'], attr, accept);

// @encrypted fields cannot be used in policy rules
this.rejectEncryptedFields(attr, accept);
}

@check('@allow')
Expand Down Expand Up @@ -166,6 +169,9 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
);
}
}

// @encrypted fields cannot be used in policy rules
this.rejectEncryptedFields(attr, accept);
}

@check('@@validate')
Expand Down Expand Up @@ -206,6 +212,14 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
}
}

private rejectEncryptedFields(attr: AttributeApplication, accept: ValidationAcceptor) {
streamAllContents(attr).forEach((node) => {
if (isDataModelFieldReference(node) && hasAttribute(node.target.ref as DataModelField, '@encrypted')) {
accept('error', `Encrypted fields cannot be used in policy rules`, { node });
}
});
}

private validatePolicyKinds(
kind: string,
candidates: string[],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { FieldInfo } from '@zenstackhq/runtime';
import { loadSchema } from '@zenstackhq/testtools';
import { loadSchema, loadModelWithError } from '@zenstackhq/testtools';
import path from 'path';

describe('Encrypted test', () => {
let origDir: string;
const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64'));

beforeAll(async () => {
origDir = path.resolve('.');
Expand All @@ -14,21 +15,25 @@ describe('Encrypted test', () => {
});

it('Simple encryption test', async () => {
const { enhance } = await loadSchema(`
const { enhance, prisma } = await loadSchema(
`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()
@@allow('all', true)
}`);
}`,
{
enhancements: ['encryption'],
enhanceOptions: {
encryption: { encryptionKey },
},
}
);

const sudoDb = enhance(undefined, { kinds: [] });
const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64'));

const db = enhance(undefined, {
kinds: ['encrypted'],
encryption: { encryptionKey },
});
const db = enhance();

const create = await db.user.create({
data: {
Expand All @@ -49,9 +54,50 @@ describe('Encrypted test', () => {
},
});

const rawRead = await prisma.user.findUnique({ where: { id: '1' } });

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).not.toBe('abc123');
expect(rawRead.encrypted_value).not.toBe('abc123');
});

it('Multi-field encryption test', async () => {
const { enhance } = await loadSchema(
`
model User {
id String @id @default(cuid())
x1 String @encrypted()
x2 String @encrypted()
@@allow('all', true)
}`,
{
enhancements: ['encryption'],
enhanceOptions: {
encryption: { encryptionKey },
},
}
);

const db = enhance();

const create = await db.user.create({
data: {
id: '1',
x1: 'abc123',
x2: '123abc',
},
});

const read = await db.user.findUnique({
where: {
id: '1',
},
});

expect(create).toMatchObject({ x1: 'abc123', x2: '123abc' });
expect(read).toMatchObject({ x1: 'abc123', x2: '123abc' });
});

it('Custom encryption test', async () => {
Expand All @@ -65,7 +111,7 @@ describe('Encrypted test', () => {

const sudoDb = enhance(undefined, { kinds: [] });
const db = enhance(undefined, {
kinds: ['encrypted'],
kinds: ['encryption'],
encryption: {
encrypt: async (model: string, field: FieldInfo, data: string) => {
// Add _enc to the end of the input
Expand Down Expand Up @@ -105,4 +151,104 @@ describe('Encrypted test', () => {
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).toBe('abc123_enc');
});

it('Only supports string fields', async () => {
await expect(
loadModelWithError(
`
model User {
id String @id @default(cuid())
encrypted_value Bytes @encrypted()
}`
)
).resolves.toContain(`attribute \"@encrypted\" cannot be used on this type of field`);
});

it('Returns cipher text when decryption fails', async () => {
const { enhance, enhanceRaw, prisma } = await loadSchema(
`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()
@@allow('all', true)
}`,
{ enhancements: ['encryption'] }
);

const db = enhance(undefined, {
kinds: ['encryption'],
encryption: { encryptionKey },
});

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});
expect(create.encrypted_value).toBe('abc123');

const db1 = enhanceRaw(prisma, undefined, {
encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)) },
});
const read = await db1.user.findUnique({ where: { id: '1' } });
expect(read.encrypted_value).toBeTruthy();
expect(read.encrypted_value).not.toBe('abc123');
});

it('Works with length validation', async () => {
const { enhance } = await loadSchema(
`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted() @length(0, 6)
@@allow('all', true)
}`,
{
enhanceOptions: { encryption: { encryptionKey } },
}
);

const db = enhance();

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});
expect(create.encrypted_value).toBe('abc123');

await expect(
db.user.create({
data: { id: '2', encrypted_value: 'abc1234' },
})
).toBeRejectedByPolicy();
});

it('Complains when encrypted fields are used in model-level policy rules', async () => {
await expect(
loadModelWithError(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()
@@allow('all', encrypted_value != 'abc123')
}
`)
).resolves.toContain(`Encrypted fields cannot be used in policy rules`);
});

it('Complains when encrypted fields are used in field-level policy rules', async () => {
await expect(
loadModelWithError(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()
value Int @allow('all', encrypted_value != 'abc123')
}
`)
).resolves.toContain(`Encrypted fields cannot be used in policy rules`);
});
});

0 comments on commit 8486f64

Please sign in to comment.