Skip to content

Commit

Permalink
fix: fix policy generation for collection predicate expressions (zens…
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Sep 21, 2023
1 parent 2d41a9f commit b8a875e
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
isReferenceExpr,
} from '@zenstackhq/language/ast';
import { isFutureExpr, resolved } from '@zenstackhq/sdk';
import { ValidationAcceptor, streamAllContents } from 'langium';
import { ValidationAcceptor, streamAst } from 'langium';
import pluralize from 'pluralize';
import { AstValidator } from '../types';
import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils';
Expand Down Expand Up @@ -134,7 +134,7 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
this.validatePolicyKinds(kind, ['read', 'update', 'all'], attr, accept);

const expr = attr.args[1].value;
if ([expr, ...streamAllContents(expr)].some((node) => isFutureExpr(node))) {
if (streamAst(expr).some((node) => isFutureExpr(node))) {
accept('error', `"future()" is not allowed in field-level policy rules`, { node: expr });
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ import {
BinaryExpr,
Expression,
ExpressionType,
isBinaryExpr,
isDataModel,
isEnum,
isNullExpr,
isThisExpr,
} from '@zenstackhq/language/ast';
import { isDataModelFieldReference } from '@zenstackhq/sdk';
import { ValidationAcceptor } from 'langium';
import { isAuthInvocation } from '../../utils/ast-utils';
import { isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils';
import { AstValidator } from '../types';

/**
Expand All @@ -23,7 +22,7 @@ export default class ExpressionValidator implements AstValidator<Expression> {
if (isAuthInvocation(expr)) {
// check was done at link time
accept('error', 'auth() cannot be resolved because no "User" model is defined', { node: expr });
} else if (this.isCollectionPredicate(expr)) {
} else if (isCollectionPredicate(expr)) {
accept('error', 'collection predicate can only be used on an array of model type', { node: expr });
} else {
accept('error', 'expression cannot be resolved', {
Expand Down Expand Up @@ -142,8 +141,4 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}
}
}

private isCollectionPredicate(expr: Expression) {
return isBinaryExpr(expr) && ['?', '!', '^'].includes(expr.operator);
}
}
79 changes: 45 additions & 34 deletions packages/schema/src/plugins/access-policy/policy-guard-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {
DataModelFieldAttribute,
Enum,
Expression,
MemberAccessExpr,
Model,
isBinaryExpr,
isDataModel,
Expand Down Expand Up @@ -49,12 +48,12 @@ import {
resolved,
saveProject,
} from '@zenstackhq/sdk';
import { streamAllContents } from 'langium';
import { streamAllContents, streamAst, streamContents } from 'langium';
import { lowerCaseFirst } from 'lower-case-first';
import path from 'path';
import { FunctionDeclaration, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph';
import { name } from '.';
import { getIdFields, isAuthInvocation } from '../../utils/ast-utils';
import { getIdFields, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils';
import {
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
Expand Down Expand Up @@ -237,7 +236,7 @@ export default class PolicyGenerator {
}

private hasFutureReference(expr: Expression) {
for (const node of this.allNodes(expr)) {
for (const node of streamAst(expr)) {
if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) {
return true;
}
Expand Down Expand Up @@ -434,7 +433,7 @@ export default class PolicyGenerator {

private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) {
return [...allows, ...denies].every((rule) => {
return [...this.allNodes(rule)].every((expr) => {
return streamAst(rule).every((expr) => {
if (isThisExpr(expr)) {
return false;
}
Expand Down Expand Up @@ -487,6 +486,8 @@ export default class PolicyGenerator {
});
};

// visit a reference or member access expression to build a
// selection path
const visit = (node: Expression): string[] | undefined => {
if (isReferenceExpr(node)) {
const target = resolved(node.target);
Expand All @@ -509,35 +510,50 @@ export default class PolicyGenerator {
return undefined;
};

for (const rule of [...allows, ...denies]) {
for (const expr of [...this.allNodes(rule)].filter((node): node is Expression => isExpression(node))) {
if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) {
// a standalone `this` expression, include all id fields
const model = expr.$resolvedType?.decl as DataModel;
const idFields = getIdFields(model);
idFields.forEach((field) => addPath([field.name]));
continue;
}

// only care about member access and reference expressions
if (!isMemberAccessExpr(expr) && !isReferenceExpr(expr)) {
continue;
}

if (expr.$container.$type === MemberAccessExpr) {
// only visit top-level member access
continue;
}
// collect selection paths from the given expression
const collectReferencePaths = (expr: Expression): string[][] => {
if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) {
// a standalone `this` expression, include all id fields
const model = expr.$resolvedType?.decl as DataModel;
const idFields = getIdFields(model);
return idFields.map((field) => [field.name]);
}

if (isMemberAccessExpr(expr) || isReferenceExpr(expr)) {
const path = visit(expr);
if (path) {
if (isDataModel(expr.$resolvedType?.decl)) {
// member selection ended at a data model field, include its 'id'
path.push('id');
// member selection ended at a data model field, include its id fields
const idFields = getIdFields(expr.$resolvedType?.decl as DataModel);
return idFields.map((field) => [...path, field.name]);
} else {
return [path];
}
addPath(path);
} else {
return [];
}
} else if (isCollectionPredicate(expr)) {
const path = visit(expr.left);
if (path) {
// recurse into RHS
const rhs = collectReferencePaths(expr.right);
// combine path of LHS and RHS
return rhs.map((r) => [...path, ...r]);
} else {
return [];
}
} else {
// recurse
const children = streamContents(expr)
.filter((child): child is Expression => isExpression(child))
.toArray();
return children.flatMap((child) => collectReferencePaths(child));
}
};

for (const rule of [...allows, ...denies]) {
const paths = collectReferencePaths(rule);
paths.forEach((p) => addPath(p));
}

return Object.keys(result).length === 0 ? undefined : result;
Expand All @@ -556,7 +572,7 @@ export default class PolicyGenerator {
this.generateNormalizedAuthRef(model, allows, denies, statements);

const hasFieldAccess = [...denies, ...allows].some((rule) =>
[...this.allNodes(rule)].some(
streamAst(rule).some(
(child) =>
// this.???
isThisExpr(child) ||
Expand Down Expand Up @@ -724,7 +740,7 @@ export default class PolicyGenerator {
) {
// check if any allow or deny rule contains 'auth()' invocation
const hasAuthRef = [...allows, ...denies].some((rule) =>
[...this.allNodes(rule)].some((child) => isAuthInvocation(child))
streamAst(rule).some((child) => isAuthInvocation(child))
);

if (hasAuthRef) {
Expand All @@ -747,9 +763,4 @@ export default class PolicyGenerator {
);
}
}

private *allNodes(expr: Expression) {
yield expr;
yield* streamAllContents(expr);
}
}
6 changes: 6 additions & 0 deletions packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import {
BinaryExpr,
DataModel,
DataModelField,
Expression,
isArrayExpr,
isBinaryExpr,
isDataModel,
isDataModelField,
isInvocationExpr,
Expand Down Expand Up @@ -150,3 +152,7 @@ export function getAllDeclarationsFromImports(documents: LangiumDocuments, model
const imports = resolveTransitiveImports(documents, model);
return model.declarations.concat(...imports.map((imp) => imp.declarations));
}

export function isCollectionPredicate(expr: Expression): expr is BinaryExpr {
return isBinaryExpr(expr) && ['?', '!', '^'].includes(expr.operator);
}
81 changes: 53 additions & 28 deletions packages/schema/src/utils/typescript-expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
UnaryExpr,
} from '@zenstackhq/language/ast';
import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk';
import { match, P } from 'ts-pattern';
import { getIdFields } from './ast-utils';

export class TypeScriptExpressionTransformerError extends Error {
Expand Down Expand Up @@ -53,7 +54,7 @@ export class TypeScriptExpressionTransformer {
*
* @param isPostGuard indicates if we're writing for post-update conditions
*/
constructor(private readonly options?: Options) {}
constructor(private readonly options: Options) {}

/**
* Transforms the given expression to a TypeScript expression.
Expand Down Expand Up @@ -302,33 +303,57 @@ export class TypeScriptExpressionTransformer {
}

private binary(expr: BinaryExpr, normalizeUndefined: boolean): string {
if (expr.operator === 'in') {
return `(${this.transform(expr.right, false)}?.includes(${this.transform(
expr.left,
normalizeUndefined
)}) ?? false)`;
} else if (
(expr.operator === '==' || expr.operator === '!=') &&
(isThisExpr(expr.left) || isThisExpr(expr.right))
) {
// map equality comparison with `this` to id comparison
const _this = isThisExpr(expr.left) ? expr.left : expr.right;
const model = _this.$resolvedType?.decl as DataModel;
const idFields = getIdFields(model);
if (!idFields || idFields.length === 0) {
throw new TypeScriptExpressionTransformerError(`model "${model.name}" does not have an id field`);
}
let result = `allFieldsEqual(${this.transform(expr.left, false)},
const _default = `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform(
expr.right,
normalizeUndefined
)})`;

return match(expr.operator)
.with(
'in',
() =>
`(${this.transform(expr.right, false)}?.includes(${this.transform(
expr.left,
normalizeUndefined
)}) ?? false)`
)
.with(P.union('==', '!='), () => {
if (isThisExpr(expr.left) || isThisExpr(expr.right)) {
// map equality comparison with `this` to id comparison
const _this = isThisExpr(expr.left) ? expr.left : expr.right;
const model = _this.$resolvedType?.decl as DataModel;
const idFields = getIdFields(model);
if (!idFields || idFields.length === 0) {
throw new TypeScriptExpressionTransformerError(
`model "${model.name}" does not have an id field`
);
}
let result = `allFieldsEqual(${this.transform(expr.left, false)},
${this.transform(expr.right, false)}, [${idFields.map((f) => "'" + f.name + "'").join(', ')}])`;
if (expr.operator === '!=') {
result = `!${result}`;
}
return result;
} else {
return `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform(
expr.right,
normalizeUndefined
)})`;
}
if (expr.operator === '!=') {
result = `!${result}`;
}
return result;
} else {
return _default;
}
})
.with(P.union('?', '!', '^'), (op) => this.collectionPredicate(expr, op, normalizeUndefined))
.otherwise(() => _default);
}

private collectionPredicate(expr: BinaryExpr, operator: '?' | '!' | '^', normalizeUndefined: boolean) {
const operand = this.transform(expr.left, normalizeUndefined);
const innerTransformer = new TypeScriptExpressionTransformer({
...this.options,
fieldReferenceContext: '_item',
});
const predicate = innerTransformer.transform(expr.right, normalizeUndefined);

return match(operator)
.with('?', () => `!!((${operand})?.some((_item: any) => ${predicate}))`)
.with('!', () => `!!((${operand})?.every((_item: any) => ${predicate}))`)
.with('^', () => `!((${operand})?.some((_item: any) => ${predicate}))`)
.exhaustive();
}
}
Loading

0 comments on commit b8a875e

Please sign in to comment.