diff --git a/packages/@aws-cdk-testing/framework-integ/test/aws-stepfunctions-tasks/test/bedrock/integ.invoke-model-guardrail-trace.ts b/packages/@aws-cdk-testing/framework-integ/test/aws-stepfunctions-tasks/test/bedrock/integ.invoke-model-guardrail-trace.ts index 6b75dbd7f2673..75b57bc08d2d0 100644 --- a/packages/@aws-cdk-testing/framework-integ/test/aws-stepfunctions-tasks/test/bedrock/integ.invoke-model-guardrail-trace.ts +++ b/packages/@aws-cdk-testing/framework-integ/test/aws-stepfunctions-tasks/test/bedrock/integ.invoke-model-guardrail-trace.ts @@ -38,7 +38,7 @@ const prompt = new BedrockInvokeModel(stack, 'Prompt', { }, }, ), - guardrailConfiguration: { + guardrail: { guardrailIdentifier: guardrail.attrGuardrailId, guardrailVersion: guardrail.attrVersion, }, diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/README.md b/packages/aws-cdk-lib/aws-stepfunctions-tasks/README.md index bd4a6e5c645dd..41d3422d1931d 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/README.md +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/README.md @@ -357,7 +357,7 @@ const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', { }); ``` -You can apply a guardrail to the invocation by setting `guardrailConfiguration`. +You can apply a guardrail to the invocation by setting `guardrail`. ```ts import * as bedrock from 'aws-cdk-lib/aws-bedrock'; @@ -380,7 +380,7 @@ const task = new tasks.BedrockInvokeModel(this, 'Prompt Model with guardrail', { }, }, ), - guardrailConfiguration: { + guardrail: { guardrailIdentifier: myGuardrail.attrGuardrailId, guardrailVersion: myGuardrail.attrVersion, }, diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/invoke-model.ts b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/invoke-model.ts index 94110db9a8a29..c09035e5d1203 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/invoke-model.ts +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/invoke-model.ts @@ -42,9 +42,9 @@ export interface BedrockInvokeModelOutputProps { } /** - * Properties for the guardrail configuration. + * Properties for the guardrail. */ -export interface GuardrailConfiguration { +export interface Guardrail { /** * The unique identifier of the guardrail that you want to use. */ @@ -126,7 +126,7 @@ export interface BedrockInvokeModelProps extends sfn.TaskStateBaseProps { * * @default - No guardrail is applied to the invocation. */ - readonly guardrailConfiguration?: GuardrailConfiguration; + readonly guardrail?: Guardrail; /** * Specifies whether to enable or disable the Bedrock trace. @@ -173,7 +173,7 @@ export class BedrockInvokeModel extends sfn.TaskStateBase { throw new Error('Output S3 object version is not supported.'); } - this.validateGuardrailConfiguration(props); + this.validateGuardrail(props); this.taskPolicies = this.renderPolicyStatements(); } @@ -220,7 +220,7 @@ export class BedrockInvokeModel extends sfn.TaskStateBase { ); } - if (this.props.guardrailConfiguration) { + if (this.props.guardrail) { policyStatements.push( new iam.PolicyStatement({ actions: ['bedrock:ApplyGuardrail'], @@ -228,7 +228,7 @@ export class BedrockInvokeModel extends sfn.TaskStateBase { Stack.of(this).formatArn({ service: 'bedrock', resource: 'guardrail', - resourceName: this.props.guardrailConfiguration.guardrailIdentifier, + resourceName: this.props.guardrail.guardrailIdentifier, }), ], }), @@ -257,8 +257,8 @@ export class BedrockInvokeModel extends sfn.TaskStateBase { Output: this.props.output?.s3Location ? { S3Uri: `s3://${this.props.output.s3Location.bucketName}/${this.props.output.s3Location.objectKey}`, } : undefined, - GuardrailIdentifier: this.props.guardrailConfiguration?.guardrailIdentifier, - GuardrailVersion: this.props.guardrailConfiguration?.guardrailVersion, + GuardrailIdentifier: this.props.guardrail?.guardrailIdentifier, + GuardrailVersion: this.props.guardrail?.guardrailVersion, Trace: this.props.traceEnabled === undefined ? undefined : this.props.traceEnabled @@ -268,18 +268,18 @@ export class BedrockInvokeModel extends sfn.TaskStateBase { }; } - private validateGuardrailConfiguration(props: BedrockInvokeModelProps) { - if (!props.guardrailConfiguration) return; + private validateGuardrail(props: BedrockInvokeModelProps) { + if (!props.guardrail) return; - const { guardrailIdentifier, guardrailVersion } = props.guardrailConfiguration; + const { guardrailIdentifier, guardrailVersion } = props.guardrail; if (!Token.isUnresolved(guardrailIdentifier)) { - const guardrailConfigurationPattern = /^(([a-z0-9]+)|(arn:aws(-[^:]+)?:bedrock:[a-z0-9-]{1,20}:[0-9]{12}:guardrail\/[a-z0-9]+))$/; - if (!guardrailConfigurationPattern.test(guardrailIdentifier)) { + const guardrailPattern = /^(([a-z0-9]+)|(arn:aws(-[^:]+)?:bedrock:[a-z0-9-]{1,20}:[0-9]{12}:guardrail\/[a-z0-9]+))$/; + if (!guardrailPattern.test(guardrailIdentifier)) { throw new Error(`You must set guardrailIdentifier to the id or the arn of Guardrail, got ${guardrailIdentifier}`); } if (props.contentType !== 'application/json') { - throw new Error(`You must set contentType to \'application/json\' when using guardrailConfiguration, got '${props.contentType}'.`); + throw new Error(`You must set contentType to \'application/json\' when using guardrail, got '${props.contentType}'.`); } if (guardrailIdentifier.length > 2048) { throw new Error(`\`guardrailIdentifier\` length must be between 0 and 2048, got ${guardrailIdentifier.length}.`); diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts b/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts index 1aa80f631ec41..1508ef9ac021a 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts @@ -361,7 +361,7 @@ describe('Invoke Model', () => { }).toThrow(/Output S3 object version is not supported./); }); - test('guardrail configuration', () => { + test('guardrail', () => { // GIVEN const stack = new cdk.Stack(); const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); @@ -375,7 +375,7 @@ describe('Invoke Model', () => { prompt: 'Hello world', }, ), - guardrailConfiguration: { + guardrail: { guardrailIdentifier: 'arn:aws:bedrock:us-turbo-2:123456789012:guardrail/testid', guardrailVersion: 'DRAFT', }, @@ -409,7 +409,7 @@ describe('Invoke Model', () => { }); }); - test('guardrail configuration fails when invalid guardrailIdentifier is set', () => { + test('guardrail fails when invalid guardrailIdentifier is set', () => { // GIVEN const stack = new cdk.Stack(); const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); @@ -424,7 +424,7 @@ describe('Invoke Model', () => { prompt: 'Hello world', }, ), - guardrailConfiguration: { + guardrail: { guardrailIdentifier: 'invalid-id', guardrailVersion: 'DRAFT', }, @@ -433,7 +433,7 @@ describe('Invoke Model', () => { }).toThrow('You must set guardrailIdentifier to the id or the arn of Guardrail, got invalid-id'); }); - test('guardrail configuration fails when guardrailIdentifier length is invalid', () => { + test('guardrail fails when guardrailIdentifier length is invalid', () => { // GIVEN const stack = new cdk.Stack(); const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); @@ -449,7 +449,7 @@ describe('Invoke Model', () => { prompt: 'Hello world', }, ), - guardrailConfiguration: { + guardrail: { guardrailIdentifier, guardrailVersion: 'DRAFT', }, @@ -458,7 +458,7 @@ describe('Invoke Model', () => { }).toThrow(`\`guardrailIdentifier\` length must be between 0 and 2048, got ${guardrailIdentifier.length}.`); }); - test('guardrail configuration fails when invalid guardrailVersion is set', () => { + test('guardrail fails when invalid guardrailVersion is set', () => { // GIVEN const stack = new cdk.Stack(); const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); @@ -473,7 +473,7 @@ describe('Invoke Model', () => { prompt: 'Hello world', }, ), - guardrailConfiguration: { + guardrail: { guardrailIdentifier: 'abcdef', guardrailVersion: 'test', }, @@ -482,7 +482,7 @@ describe('Invoke Model', () => { }).toThrow('guardrailVersion must match the ^(([1-9][0-9]{0,7})|(DRAFT))$ pattern, got test'); }); - test('guardrail configuration fails when contentType is not \'application/json\'', () => { + test('guardrail fails when contentType is not \'application/json\'', () => { // GIVEN const stack = new cdk.Stack(); const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); @@ -497,13 +497,13 @@ describe('Invoke Model', () => { prompt: 'Hello world', }, ), - guardrailConfiguration: { + guardrail: { guardrailIdentifier: 'abcdef', guardrailVersion: 'DRAFT', }, }); // THEN - }).toThrow('You must set contentType to \'application/json\' when using guardrailConfiguration, got \'text/plain\'.'); + }).toThrow('You must set contentType to \'application/json\' when using guardrail, got \'text/plain\'.'); }); test('trace configuration', () => {