Skip to content

Commit

Permalink
chore: more CLI options (zenstackhq#695)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Sep 17, 2023
1 parent eba3390 commit e8f7a2d
Show file tree
Hide file tree
Showing 18 changed files with 432 additions and 137 deletions.
1 change: 1 addition & 0 deletions packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"semver": "^7.3.8",
"sleep-promise": "^9.1.0",
"strip-color": "^0.1.0",
"tiny-invariant": "^1.3.1",
"ts-morph": "^16.0.0",
"ts-pattern": "^4.3.0",
"upper-case-first": "^2.0.2",
Expand Down
17 changes: 10 additions & 7 deletions packages/schema/src/cli/actions/generate.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import { PluginError } from '@zenstackhq/sdk';
import colors from 'colors';
import path from 'path';
import { Context } from '../../types';
import { PackageManagers } from '../../utils/pkg-utils';
import { CliError } from '../cli-error';
import {
checkNewVersion,
Expand All @@ -11,13 +9,15 @@ import {
loadDocument,
requiredPrismaVersion,
} from '../cli-util';
import { PluginRunner } from '../plugin-runner';
import { PluginRunner, PluginRunnerOptions } from '../plugin-runner';

type Options = {
schema: string;
packageManager: PackageManagers | undefined;
output?: string;
dependencyCheck: boolean;
versionCheck: boolean;
compile: boolean;
defaultPlugins: boolean;
};

/**
Expand Down Expand Up @@ -53,14 +53,17 @@ export async function generate(projectPath: string, options: Options) {

async function runPlugins(options: Options) {
const model = await loadDocument(options.schema);
const context: Context = {

const runnerOpts: PluginRunnerOptions = {
schema: model,
schemaPath: path.resolve(options.schema),
outDir: path.dirname(options.schema),
defaultPlugins: options.defaultPlugins,
output: options.output,
compile: options.compile,
};

try {
await new PluginRunner().run(context);
await new PluginRunner().run(runnerOpts);
} catch (err) {
if (err instanceof PluginError) {
console.error(colors.red(`${err.plugin}: ${err.message}`));
Expand Down
6 changes: 4 additions & 2 deletions packages/schema/src/cli/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export function createProgram() {
.addOption(configOption)
.addOption(pmOption)
.addOption(new Option('--prisma <file>', 'location of Prisma schema file to bootstrap from'))
.addOption(new Option('--tag [tag]', 'the NPM package tag to use when installing dependencies'))
.addOption(new Option('--tag <tag>', 'the NPM package tag to use when installing dependencies'))
.addOption(noVersionCheckOption)
.argument('[path]', 'project path', '.')
.action(initAction);
Expand All @@ -90,8 +90,10 @@ export function createProgram() {
.command('generate')
.description('Run code generation.')
.addOption(schemaOption)
.addOption(new Option('-o, --output <path>', 'default output directory for built-in plugins'))
.addOption(configOption)
.addOption(pmOption)
.addOption(new Option('--no-default-plugins', 'do not run default plugins'))
.addOption(new Option('--no-compile', 'do not compile the output of built-in plugins'))
.addOption(noVersionCheckOption)
.addOption(noDependencyCheck)
.action(generateAction);
Expand Down
130 changes: 85 additions & 45 deletions packages/schema/src/cli/plugin-runner.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-var-requires */
import type { DMMF } from '@prisma/generator-helper';
import { isPlugin, Plugin } from '@zenstackhq/language/ast';
import { isPlugin, Model, Plugin } from '@zenstackhq/language/ast';
import {
getDataModels,
getDMMF,
Expand All @@ -19,9 +19,7 @@ import ora from 'ora';
import path from 'path';
import { ensureDefaultOutputFolder } from '../plugins/plugin-utils';
import telemetry from '../telemetry';
import type { Context } from '../types';
import { getVersion } from '../utils/version-utils';
import { config } from './config';

type PluginInfo = {
name: string;
Expand All @@ -32,23 +30,31 @@ type PluginInfo = {
module: any;
};

export type PluginRunnerOptions = {
schema: Model;
schemaPath: string;
output?: string;
defaultPlugins: boolean;
compile: boolean;
};

/**
* ZenStack plugin runner
*/
export class PluginRunner {
/**
* Runs a series of nested generators
*/
async run(context: Context): Promise<void> {
async run(options: PluginRunnerOptions): Promise<void> {
const version = getVersion();
console.log(colors.bold(`⌛️ ZenStack CLI v${version}, running plugins`));

ensureDefaultOutputFolder();
ensureDefaultOutputFolder(options);

const plugins: PluginInfo[] = [];
const pluginDecls = context.schema.declarations.filter((d): d is Plugin => isPlugin(d));
const pluginDecls = options.schema.declarations.filter((d): d is Plugin => isPlugin(d));

let prismaOutput = resolvePath('./prisma/schema.prisma', { schemaPath: context.schemaPath, name: '' });
let prismaOutput = resolvePath('./prisma/schema.prisma', { schemaPath: options.schemaPath, name: '' });

for (const pluginDecl of pluginDecls) {
const pluginProvider = this.getPluginProvider(pluginDecl);
Expand All @@ -73,59 +79,35 @@ export class PluginRunner {

const dependencies = this.getPluginDependencies(pluginModule);
const pluginName = this.getPluginName(pluginModule, pluginProvider);
const options: PluginOptions = { schemaPath: context.schemaPath, name: pluginName };
const pluginOptions: PluginOptions = { schemaPath: options.schemaPath, name: pluginName };

pluginDecl.fields.forEach((f) => {
const value = getLiteral(f.value) ?? getLiteralArray(f.value);
if (value === undefined) {
throw new PluginError(pluginName, `Invalid option value for ${f.name}`);
}
options[f.name] = value;
pluginOptions[f.name] = value;
});

plugins.push({
name: pluginName,
provider: pluginProvider,
dependencies,
options,
options: pluginOptions,
run: pluginModule.default as PluginFunction,
module: pluginModule,
});

if (pluginProvider === '@core/prisma' && typeof options.output === 'string') {
if (pluginProvider === '@core/prisma' && typeof pluginOptions.output === 'string') {
// record custom prisma output path
prismaOutput = resolvePath(options.output, options);
prismaOutput = resolvePath(pluginOptions.output, pluginOptions);
}
}

// make sure prerequisites are included
const corePlugins: Array<{ provider: string; options?: Record<string, unknown> }> = [
{ provider: '@core/prisma' },
{ provider: '@core/model-meta' },
{ provider: '@core/access-policy' },
];

if (getDataModels(context.schema).some((model) => hasValidationAttributes(model))) {
// '@core/zod' plugin is auto-enabled if there're validation rules
corePlugins.push({ provider: '@core/zod', options: { modelOnly: true } });
}

// core plugins introduced by dependencies
plugins
.flatMap((p) => p.dependencies)
.forEach((dep) => {
if (dep.startsWith('@core/')) {
const existing = corePlugins.find((p) => p.provider === dep);
if (existing) {
// reset options to default
existing.options = undefined;
} else {
// add core dependency
corePlugins.push({ provider: dep });
}
}
});
// get core plugins that need to be enabled
const corePlugins = this.calculateCorePlugins(options, plugins);

// shift/insert core plugins to the front
for (const corePlugin of corePlugins.reverse()) {
const existingIdx = plugins.findIndex((p) => p.provider === corePlugin.provider);
if (existingIdx >= 0) {
Expand All @@ -141,7 +123,7 @@ export class PluginRunner {
name: pluginName,
provider: corePlugin.provider,
dependencies: [],
options: { schemaPath: context.schemaPath, name: pluginName, ...corePlugin.options },
options: { schemaPath: options.schemaPath, name: pluginName, ...corePlugin.options },
run: pluginModule.default,
module: pluginModule,
});
Expand All @@ -161,12 +143,17 @@ export class PluginRunner {
}
}

if (plugins.length === 0) {
console.log(colors.yellow('No plugins configured.'));
return;
}

const warnings: string[] = [];

let dmmf: DMMF.Document | undefined = undefined;
for (const { name, provider, run, options } of plugins) {
for (const { name, provider, run, options: pluginOptions } of plugins) {
// const start = Date.now();
await this.runPlugin(name, run, context, options, dmmf, warnings);
await this.runPlugin(name, run, options, pluginOptions, dmmf, warnings);
// console.log(`✅ Plugin ${colors.bold(name)} (${provider}) completed in ${Date.now() - start}ms`);
if (provider === '@core/prisma') {
// load prisma DMMF
Expand All @@ -175,14 +162,64 @@ export class PluginRunner {
});
}
}

console.log(colors.green(colors.bold('\n👻 All plugins completed successfully!')));

warnings.forEach((w) => console.warn(colors.yellow(w)));

console.log(`Don't forget to restart your dev server to let the changes take effect.`);
}

private calculateCorePlugins(options: PluginRunnerOptions, plugins: PluginInfo[]) {
const corePlugins: Array<{ provider: string; options?: Record<string, unknown> }> = [];

if (options.defaultPlugins) {
corePlugins.push(
{ provider: '@core/prisma' },
{ provider: '@core/model-meta' },
{ provider: '@core/access-policy' }
);
} else if (plugins.length > 0) {
// "@core/prisma" plugin is always enabled if any plugin is configured
corePlugins.push({ provider: '@core/prisma' });
}

// "@core/access-policy" has implicit requirements
if ([...plugins, ...corePlugins].find((p) => p.provider === '@core/access-policy')) {
// make sure "@core/model-meta" is enabled
if (!corePlugins.find((p) => p.provider === '@core/model-meta')) {
corePlugins.push({ provider: '@core/model-meta' });
}

// '@core/zod' plugin is auto-enabled by "@core/access-policy"
// if there're validation rules
if (!corePlugins.find((p) => p.provider === '@core/zod') && this.hasValidation(options.schema)) {
corePlugins.push({ provider: '@core/zod', options: { modelOnly: true } });
}
}

// core plugins introduced by dependencies
plugins
.flatMap((p) => p.dependencies)
.forEach((dep) => {
if (dep.startsWith('@core/')) {
const existing = corePlugins.find((p) => p.provider === dep);
if (existing) {
// reset options to default
existing.options = undefined;
} else {
// add core dependency
corePlugins.push({ provider: dep });
}
}
});

return corePlugins;
}

private hasValidation(schema: Model) {
return getDataModels(schema).some((model) => hasValidationAttributes(model));
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
private getPluginName(pluginModule: any, pluginProvider: string): string {
return typeof pluginModule.name === 'string' ? (pluginModule.name as string) : pluginProvider;
Expand All @@ -200,7 +237,7 @@ export class PluginRunner {
private async runPlugin(
name: string,
run: PluginFunction,
context: Context,
runnerOptions: PluginRunnerOptions,
options: PluginOptions,
dmmf: DMMF.Document | undefined,
warnings: string[]
Expand All @@ -216,7 +253,10 @@ export class PluginRunner {
options,
},
async () => {
let result = run(context.schema, options, dmmf, config);
let result = run(runnerOptions.schema, options, dmmf, {
output: runnerOptions.output,
compile: runnerOptions.compile,
});
if (result instanceof Promise) {
result = await result;
}
Expand Down
11 changes: 6 additions & 5 deletions packages/schema/src/plugins/access-policy/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Model } from '@zenstackhq/language/ast';
import { PluginOptions } from '@zenstackhq/sdk';
import { PluginFunction } from '@zenstackhq/sdk';
import PolicyGenerator from './policy-guard-generator';

export const name = 'Access Policy';

export default async function run(model: Model, options: PluginOptions) {
return new PolicyGenerator().generate(model, options);
}
const run: PluginFunction = async (model, options, _dmmf, globalOptions) => {
return new PolicyGenerator().generate(model, options, globalOptions);
};

export default run;
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
import {
ExpressionContext,
PluginError,
PluginGlobalOptions,
PluginOptions,
RUNTIME_PACKAGE,
analyzePolicies,
Expand Down Expand Up @@ -65,8 +66,8 @@ import { ExpressionWriter, FALSE, TRUE } from './expression-writer';
* Generates source file that contains Prisma query guard objects used for injecting database queries
*/
export default class PolicyGenerator {
async generate(model: Model, options: PluginOptions) {
let output = options.output ? (options.output as string) : getDefaultOutputFolder();
async generate(model: Model, options: PluginOptions, globalOptions?: PluginGlobalOptions) {
let output = options.output ? (options.output as string) : getDefaultOutputFolder(globalOptions);
if (!output) {
throw new PluginError(options.name, `Unable to determine output path, not running plugin`);
}
Expand Down Expand Up @@ -147,7 +148,14 @@ export default class PolicyGenerator {

sf.addStatements('export default policy');

const shouldCompile = options.compile !== false;
let shouldCompile = true;
if (typeof options.compile === 'boolean') {
// explicit override
shouldCompile = options.compile;
} else if (globalOptions) {
shouldCompile = globalOptions.compile;
}

if (!shouldCompile || options.preserveTsFiles === true) {
// save ts files
await saveProject(project);
Expand Down
Loading

0 comments on commit e8f7a2d

Please sign in to comment.