diff --git a/packages/plugins/openapi/src/generator.ts b/packages/plugins/openapi/src/generator.ts index f72e28cfc..e29b4ee68 100644 --- a/packages/plugins/openapi/src/generator.ts +++ b/packages/plugins/openapi/src/generator.ts @@ -1,7 +1,14 @@ // Inspired by: https://github.com/omar-dulaimi/prisma-trpc-generator import { DMMF } from '@prisma/generator-helper'; -import { AUXILIARY_FIELDS, getDataModels, hasAttribute, PluginError, PluginOptions } from '@zenstackhq/sdk'; +import { + analyzePolicies, + AUXILIARY_FIELDS, + getDataModels, + hasAttribute, + PluginError, + PluginOptions, +} from '@zenstackhq/sdk'; import { DataModel, isDataModel, type Model } from '@zenstackhq/sdk/ast'; import { addMissingInputObjectTypesForAggregate, @@ -201,11 +208,15 @@ export class OpenAPIGenerator { inputType?: object; outputType: object; successCode?: number; + security?: Array>; }; const definitions: OperationDefinition[] = []; const hasRelation = zmodel.fields.some((f) => isDataModel(f.type.reference?.ref)); + // analyze access policies to determine default security + const { create, read, update, delete: del } = analyzePolicies(zmodel); + if (ops['createOne']) { definitions.push({ method: 'post', @@ -225,6 +236,7 @@ export class OpenAPIGenerator { outputType: this.ref(model.name), description: `Create a new ${model.name}`, successCode: 201, + security: create === true ? [] : undefined, }); } @@ -245,6 +257,7 @@ export class OpenAPIGenerator { outputType: this.ref('BatchPayload'), description: `Create several ${model.name}`, successCode: 201, + security: create === true ? [] : undefined, }); } @@ -266,6 +279,7 @@ export class OpenAPIGenerator { ), outputType: this.ref(model.name), description: `Find one unique ${model.name}`, + security: read === true ? [] : undefined, }); } @@ -287,6 +301,7 @@ export class OpenAPIGenerator { ), outputType: this.ref(model.name), description: `Find the first ${model.name} matching the given condition`, + security: read === true ? [] : undefined, }); } @@ -308,6 +323,7 @@ export class OpenAPIGenerator { ), outputType: this.array(this.ref(model.name)), description: `Find a list of ${model.name}`, + security: read === true ? [] : undefined, }); } @@ -330,6 +346,7 @@ export class OpenAPIGenerator { ), outputType: this.ref(model.name), description: `Update a ${model.name}`, + security: update === true ? [] : undefined, }); } @@ -350,6 +367,7 @@ export class OpenAPIGenerator { ), outputType: this.ref('BatchPayload'), description: `Update ${model.name}s matching the given condition`, + security: update === true ? [] : undefined, }); } @@ -373,6 +391,7 @@ export class OpenAPIGenerator { ), outputType: this.ref(model.name), description: `Upsert a ${model.name}`, + security: create === true && update == true ? [] : undefined, }); } @@ -394,6 +413,7 @@ export class OpenAPIGenerator { ), outputType: this.ref(model.name), description: `Delete one unique ${model.name}`, + security: del === true ? [] : undefined, }); } @@ -413,6 +433,7 @@ export class OpenAPIGenerator { ), outputType: this.ref('BatchPayload'), description: `Delete ${model.name}s matching the given condition`, + security: del === true ? [] : undefined, }); } @@ -433,6 +454,7 @@ export class OpenAPIGenerator { ), outputType: this.oneOf({ type: 'integer' }, this.ref(`${model.name}CountAggregateOutputType`)), description: `Find a list of ${model.name}`, + security: read === true ? [] : undefined, }); if (ops['aggregate']) { @@ -456,6 +478,7 @@ export class OpenAPIGenerator { ), outputType: this.ref(`Aggregate${model.name}`), description: `Aggregate ${model.name}s`, + security: read === true ? [] : undefined, }); } @@ -481,13 +504,14 @@ export class OpenAPIGenerator { ), outputType: this.array(this.ref(`${model.name}GroupByOutputType`)), description: `Group ${model.name}s by fields`, + security: read === true ? [] : undefined, }); } // get meta specified with @@openapi.meta const resourceMeta = getModelResourceMeta(zmodel); - for (const { method, operation, description, inputType, outputType, successCode } of definitions) { + for (const { method, operation, description, inputType, outputType, successCode, security } of definitions) { const meta = resourceMeta?.[operation]; if (meta?.ignore === true) { @@ -511,7 +535,8 @@ export class OpenAPIGenerator { description: meta?.description ?? description, tags: meta?.tags || [camelCase(model.name)], summary: meta?.summary, - security: meta?.security, + // security priority: operation-level > model-level > inferred + security: meta?.security ?? resourceMeta?.security ?? security, deprecated: meta?.deprecated, responses: { [successCode !== undefined ? successCode : '200']: { diff --git a/packages/plugins/openapi/src/meta.ts b/packages/plugins/openapi/src/meta.ts index ce2fff972..6a038938a 100644 --- a/packages/plugins/openapi/src/meta.ts +++ b/packages/plugins/openapi/src/meta.ts @@ -6,6 +6,7 @@ import { DataModel } from '@zenstackhq/sdk/ast'; */ export type ModelMeta = { tagDescription?: string; + security?: Array>; }; /** diff --git a/packages/plugins/openapi/tests/openapi.test.ts b/packages/plugins/openapi/tests/openapi.test.ts index ffef100d8..a0fb7d678 100644 --- a/packages/plugins/openapi/tests/openapi.test.ts +++ b/packages/plugins/openapi/tests/openapi.test.ts @@ -196,18 +196,52 @@ model User { ); }); - it('security override', async () => { + it('security model level override', async () => { const { model, dmmf, modelFile } = await loadZModelAndDmmf(` plugin openapi { provider = '${process.cwd()}/dist' + securitySchemes = { + myBasic: { type: 'http', scheme: 'basic' } + } +} + +model User { + id String @id + + @@openapi.meta({ + security: [] + }) +} + `); + + const { name: output } = tmp.fileSync({ postfix: '.yaml' }); + const options = buildOptions(model, modelFile, output); + await generate(model, options, dmmf); + + console.log('OpenAPI specification generated:', output); + + const api = await OpenAPIParser.validate(output); + expect(api.paths?.['/user/findMany']?.['get']?.security).toHaveLength(0); + }); + + it('security operation level override', async () => { + const { model, dmmf, modelFile } = await loadZModelAndDmmf(` +plugin openapi { + provider = '${process.cwd()}/dist' + securitySchemes = { + myBasic: { type: 'http', scheme: 'basic' } + } } model User { id String @id + @@allow('read', true) + @@openapi.meta({ + security: [], findMany: { - security: [] + security: [{ myBasic: [] }] } }) } @@ -220,7 +254,33 @@ model User { console.log('OpenAPI specification generated:', output); const api = await OpenAPIParser.validate(output); - expect(api.paths?.['/user/findMany']?.['get']?.security).toHaveLength(0); + expect(api.paths?.['/user/findMany']?.['get']?.security).toHaveLength(1); + }); + + it('security inferred', async () => { + const { model, dmmf, modelFile } = await loadZModelAndDmmf(` +plugin openapi { + provider = '${process.cwd()}/dist' + securitySchemes = { + myBasic: { type: 'http', scheme: 'basic' } + } +} + +model User { + id String @id + @@allow('create', true) +} + `); + + const { name: output } = tmp.fileSync({ postfix: '.yaml' }); + const options = buildOptions(model, modelFile, output); + await generate(model, options, dmmf); + + console.log('OpenAPI specification generated:', output); + + const api = await OpenAPIParser.validate(output); + expect(api.paths?.['/user/create']?.['post']?.security).toHaveLength(0); + expect(api.paths?.['/user/findMany']?.['get']?.security).toBeUndefined(); }); it('v3.1.0 fields', async () => { diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index 1468311c8..4a0c6000d 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -6,13 +6,12 @@ import { isLiteralExpr, ReferenceExpr, } from '@zenstackhq/language/ast'; +import { analyzePolicies, getLiteral } from '@zenstackhq/sdk'; import { ValidationAcceptor } from 'langium'; -import { analyzePolicies } from '../../utils/ast-utils'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; import { getIdFields, getUniqueFields } from '../utils'; import { validateAttributeApplication, validateDuplicatedDeclarations } from './utils'; -import { getLiteral } from '@zenstackhq/sdk'; /** * Validates data model declarations. diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index c1979d590..15549e964 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -15,6 +15,7 @@ import { } from '@zenstackhq/language/ast'; import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime'; import { + analyzePolicies, getDataModels, getLiteral, GUARD_FIELD_NAME, @@ -29,7 +30,7 @@ import path from 'path'; import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind } from 'ts-morph'; import { name } from '.'; import { isFromStdlib } from '../../language-server/utils'; -import { analyzePolicies, getIdFields } from '../../utils/ast-utils'; +import { getIdFields } from '../../utils/ast-utils'; import { ALL_OPERATION_KINDS, getDefaultOutputFolder } from '../plugin-utils'; import { ExpressionWriter } from './expression-writer'; import { isFutureExpr } from './utils'; diff --git a/packages/schema/src/plugins/access-policy/zod-schema-generator.ts b/packages/schema/src/plugins/access-policy/zod-schema-generator.ts index 222e17186..96d45ecc2 100644 --- a/packages/schema/src/plugins/access-policy/zod-schema-generator.ts +++ b/packages/schema/src/plugins/access-policy/zod-schema-generator.ts @@ -1,8 +1,7 @@ import { DataModel, DataModelField, DataModelFieldAttribute, isDataModelField } from '@zenstackhq/language/ast'; -import { AUXILIARY_FIELDS, getLiteral } from '@zenstackhq/sdk'; +import { AUXILIARY_FIELDS, VALIDATION_ATTRIBUTES, getLiteral } from '@zenstackhq/sdk'; import { camelCase } from 'change-case'; import { CodeBlockWriter } from 'ts-morph'; -import { VALIDATION_ATTRIBUTES } from '../../utils/ast-utils'; /** * Writes Zod schema for data models. diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 6d2c32e9b..d0045fbe6 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -20,6 +20,7 @@ import { Model, } from '@zenstackhq/language/ast'; import { + analyzePolicies, getLiteral, getLiteralArray, GUARD_FIELD_NAME, @@ -31,13 +32,13 @@ import { import fs from 'fs'; import { writeFile } from 'fs/promises'; import path from 'path'; -import { analyzePolicies } from '../../utils/ast-utils'; import { execSync } from '../../utils/exec-utils'; import { + ModelFieldType, AttributeArg as PrismaAttributeArg, AttributeArgValue as PrismaAttributeArgValue, - ContainerAttribute as PrismaModelAttribute, ContainerDeclaration as PrismaContainerDeclaration, + Model as PrismaDataModel, DataSourceUrl as PrismaDataSourceUrl, Enum as PrismaEnum, FieldAttribute as PrismaFieldAttribute, @@ -45,10 +46,9 @@ import { FieldReferenceArg as PrismaFieldReferenceArg, FunctionCall as PrismaFunctionCall, FunctionCallArg as PrismaFunctionCallArg, - Model as PrismaDataModel, - ModelFieldType, - PassThroughAttribute as PrismaPassThroughAttribute, PrismaModel, + ContainerAttribute as PrismaModelAttribute, + PassThroughAttribute as PrismaPassThroughAttribute, SimpleField, } from './prisma-builder'; import ZModelCodeGenerator from './zmodel-code-generator'; diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index dd295ae4c..50fdf00d4 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -1,108 +1,17 @@ import { DataModel, - DataModelAttribute, DataModelField, Expression, isArrayExpr, - isDataModel, isDataModelField, isEnumField, isInvocationExpr, isMemberAccessExpr, isReferenceExpr, - Model, ReferenceExpr, } from '@zenstackhq/language/ast'; -import type { PolicyOperationKind } from '@zenstackhq/runtime'; -import { getLiteral } from '@zenstackhq/sdk'; import { isFromStdlib } from '../language-server/utils'; -export function extractDataModelsWithAllowRules(model: Model): DataModel[] { - return model.declarations.filter( - (d) => isDataModel(d) && d.attributes.some((attr) => attr.decl.ref?.name === '@@allow') - ) as DataModel[]; -} - -export function analyzePolicies(dataModel: DataModel) { - const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow'); - const denies = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@deny'); - - const create = toStaticPolicy('create', allows, denies); - const read = toStaticPolicy('read', allows, denies); - const update = toStaticPolicy('update', allows, denies); - const del = toStaticPolicy('delete', allows, denies); - const hasFieldValidation = dataModel.fields.some((field) => - field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText)) - ); - - return { - allows, - denies, - create, - read, - update, - delete: del, - allowAll: create === true && read === true && update === true && del === true, - denyAll: create === false && read === false && update === false && del === false, - hasFieldValidation, - }; -} - -function toStaticPolicy( - operation: PolicyOperationKind, - allows: DataModelAttribute[], - denies: DataModelAttribute[] -): boolean | undefined { - const filteredDenies = forOperation(operation, denies); - if (filteredDenies.some((rule) => getLiteral(rule.args[1].value) === true)) { - // any constant true deny rule - return false; - } - - const filteredAllows = forOperation(operation, allows); - if (filteredAllows.length === 0) { - // no allow rule - return false; - } - - if ( - filteredDenies.length === 0 && - filteredAllows.some((rule) => getLiteral(rule.args[1].value) === true) - ) { - // any constant true allow rule - return true; - } - return undefined; -} - -function forOperation(operation: PolicyOperationKind, rules: DataModelAttribute[]) { - return rules.filter((rule) => { - const ops = getLiteral(rule.args[0].value); - if (!ops) { - return false; - } - if (ops === 'all') { - return true; - } - const splitOps = ops.split(',').map((p) => p.trim()); - return splitOps.includes(operation); - }); -} - -export const VALIDATION_ATTRIBUTES = [ - '@length', - '@regex', - '@startsWith', - '@endsWith', - '@email', - '@url', - '@datetime', - '@gt', - '@gte', - '@lt', - '@lte', -]; - export function getIdFields(dataModel: DataModel) { const fieldLevelId = dataModel.fields.find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id')); if (fieldLevelId) { diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 130e9ea8d..e82960d44 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -2,3 +2,4 @@ export * from './code-gen'; export * from './constants'; export * from './types'; export * from './utils'; +export * from './policy'; diff --git a/packages/sdk/src/policy.ts b/packages/sdk/src/policy.ts new file mode 100644 index 000000000..ef10fa633 --- /dev/null +++ b/packages/sdk/src/policy.ts @@ -0,0 +1,82 @@ +import type { DataModel, DataModelAttribute } from './ast'; +import { getLiteral } from './utils'; + +export const VALIDATION_ATTRIBUTES = [ + '@length', + '@regex', + '@startsWith', + '@endsWith', + '@email', + '@url', + '@datetime', + '@gt', + '@gte', + '@lt', + '@lte', +]; + +export function analyzePolicies(dataModel: DataModel) { + const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow'); + const denies = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@deny'); + + const create = toStaticPolicy('create', allows, denies); + const read = toStaticPolicy('read', allows, denies); + const update = toStaticPolicy('update', allows, denies); + const del = toStaticPolicy('delete', allows, denies); + const hasFieldValidation = dataModel.fields.some((field) => + field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText)) + ); + + return { + allows, + denies, + create, + read, + update, + delete: del, + allowAll: create === true && read === true && update === true && del === true, + denyAll: create === false && read === false && update === false && del === false, + hasFieldValidation, + }; +} + +function toStaticPolicy( + operation: string, + allows: DataModelAttribute[], + denies: DataModelAttribute[] +): boolean | undefined { + const filteredDenies = forOperation(operation, denies); + if (filteredDenies.some((rule) => getLiteral(rule.args[1].value) === true)) { + // any constant true deny rule + return false; + } + + const filteredAllows = forOperation(operation, allows); + if (filteredAllows.length === 0) { + // no allow rule + return false; + } + + if ( + filteredDenies.length === 0 && + filteredAllows.some((rule) => getLiteral(rule.args[1].value) === true) + ) { + // any constant true allow rule + return true; + } + return undefined; +} + +function forOperation(operation: string, rules: DataModelAttribute[]) { + return rules.filter((rule) => { + const ops = getLiteral(rule.args[0].value); + if (!ops) { + return false; + } + if (ops === 'all') { + return true; + } + const splitOps = ops.split(',').map((p) => p.trim()); + return splitOps.includes(operation); + }); +} diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index d27a35da2..638f27312 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -51,7 +51,7 @@ export function getLiteralArray< if (!arr) { return undefined; } - return arr.map((item) => getLiteral(item)); + return arr.map((item) => getLiteral(item) ?? getObjectLiteral(item)); } export function getObjectLiteral(expr: Expression | undefined): T | undefined {