Skip to content

Commit

Permalink
feat: add support for comparing fields in the same model (zenstackhq#631
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ymc9 authored Aug 15, 2023
1 parent 4bc72a8 commit 4776685
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 61 deletions.
44 changes: 22 additions & 22 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

args = this.utils.clone(args);
if (!(await this.utils.injectForRead(this.model, args))) {
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
return null;
}

Expand All @@ -86,7 +86,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

args = this.utils.clone(args);
if (!(await this.utils.injectForRead(this.model, args))) {
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
throw this.utils.notFound(this.model);
}

Expand All @@ -100,7 +100,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async findFirst(args: any) {
args = args ? this.utils.clone(args) : {};
if (!(await this.utils.injectForRead(this.model, args))) {
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
return null;
}

Expand All @@ -114,7 +114,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async findFirstOrThrow(args: any) {
args = args ? this.utils.clone(args) : {};
if (!(await this.utils.injectForRead(this.model, args))) {
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
throw this.utils.notFound(this.model);
}

Expand All @@ -128,7 +128,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async findMany(args: any) {
args = args ? this.utils.clone(args) : {};
if (!(await this.utils.injectForRead(this.model, args))) {
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
return [];
}

Expand All @@ -152,7 +152,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
}

await this.utils.tryReject(this.model, 'create');
await this.utils.tryReject(this.prisma, this.model, 'create');

const origArgs = args;
args = this.utils.clone(args);
Expand Down Expand Up @@ -404,7 +404,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
}

this.utils.tryReject(this.model, 'create');
this.utils.tryReject(this.prisma, this.model, 'create');

args = this.utils.clone(args);

Expand Down Expand Up @@ -635,7 +635,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

if (thisModelUpdate) {
this.utils.tryReject(this.model, 'update');
this.utils.tryReject(db, this.model, 'update');

// check pre-update guard
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db);
Expand All @@ -660,7 +660,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

updateMany: async (model, args, context) => {
// injects auth guard into where clause
await this.utils.injectAuthGuard(args, model, 'update');
await this.utils.injectAuthGuard(db, args, model, 'update');

// prepare for post-update check
if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) {
Expand All @@ -671,7 +671,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
const reversedQuery = await this.utils.buildReversedQuery(context);
const currentSetQuery = { select, where: reversedQuery };
await this.utils.injectAuthGuard(currentSetQuery, model, 'read');
await this.utils.injectAuthGuard(db, currentSetQuery, model, 'read');

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findMany\` ${model}:\n${formatObject(currentSetQuery)}`);
Expand Down Expand Up @@ -794,7 +794,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

deleteMany: async (model, args, context) => {
// inject delete guard
const guard = await this.utils.getAuthGuard(model, 'delete');
const guard = await this.utils.getAuthGuard(db, model, 'delete');
context.parent.deleteMany = this.utils.and(args, guard);
},
});
Expand Down Expand Up @@ -822,10 +822,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
}

await this.utils.tryReject(this.model, 'update');
await this.utils.tryReject(this.prisma, this.model, 'update');

args = this.utils.clone(args);
await this.utils.injectAuthGuard(args, this.model, 'update');
await this.utils.injectAuthGuard(this.prisma, args, this.model, 'update');

if (this.utils.hasAuthGuard(this.model, 'postUpdate') || this.utils.getZodSchema(this.model)) {
// use a transaction to do post-update checks
Expand All @@ -838,7 +838,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
select = { ...select, ...preValueSelect };
}
const currentSetQuery = { select, where: args.where };
await this.utils.injectAuthGuard(currentSetQuery, this.model, 'read');
await this.utils.injectAuthGuard(tx, currentSetQuery, this.model, 'read');

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`);
Expand Down Expand Up @@ -885,8 +885,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'update field is required in query argument');
}

await this.utils.tryReject(this.model, 'create');
await this.utils.tryReject(this.model, 'update');
await this.utils.tryReject(this.prisma, this.model, 'create');
await this.utils.tryReject(this.prisma, this.model, 'update');

// We can call the native "upsert" because we can't tell if an entity was created or updated
// for doing post-write check accordingly. Instead, decompose it into create or update.
Expand Down Expand Up @@ -930,7 +930,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
}

await this.utils.tryReject(this.model, 'delete');
await this.utils.tryReject(this.prisma, this.model, 'delete');

const { result, error } = await this.transaction(async (tx) => {
// do a read-back before delete
Expand Down Expand Up @@ -961,11 +961,11 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

async deleteMany(args: any) {
await this.utils.tryReject(this.model, 'delete');
await this.utils.tryReject(this.prisma, this.model, 'delete');

// inject policy conditions
args = args ?? {};
await this.utils.injectAuthGuard(args, this.model, 'delete');
await this.utils.injectAuthGuard(this.prisma, args, this.model, 'delete');

// conduct the deletion
if (this.shouldLogQuery) {
Expand All @@ -984,7 +984,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

// inject policy conditions
await this.utils.injectAuthGuard(args, this.model, 'read');
await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`);
Expand All @@ -998,7 +998,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

// inject policy conditions
await this.utils.injectAuthGuard(args, this.model, 'read');
await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`);
Expand All @@ -1009,7 +1009,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
async count(args: any) {
// inject policy conditions
args = args ?? {};
await this.utils.injectAuthGuard(args, this.model, 'read');
await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`);
Expand Down
Loading

0 comments on commit 4776685

Please sign in to comment.