diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts index b120dd0e4351d..a73bc3eb6ce62 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts @@ -197,20 +197,15 @@ export interface ResourceConfig { * @experimental */ export interface VpcConfig { - /** - * VPC security groups. - */ - readonly securityGroups: ec2.ISecurityGroup[]; - /** * VPC id */ - readonly vpc: ec2.Vpc; + readonly vpc: ec2.IVpc; /** * VPC subnets. */ - readonly subnets: ec2.ISubnet[]; + readonly subnets?: ec2.SubnetSelection; } /** diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts index f4f062e55db29..9ca6d3af60a4e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts @@ -1,7 +1,7 @@ import ec2 = require('@aws-cdk/aws-ec2'); import iam = require('@aws-cdk/aws-iam'); import sfn = require('@aws-cdk/aws-stepfunctions'); -import { Construct, Duration, Stack } from '@aws-cdk/core'; +import { Duration, Lazy, Stack } from '@aws-cdk/core'; import { resourceArnSuffix } from './resource-arn-suffix'; import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, S3DataType, StoppingCondition, VpcConfig, } from './sagemaker-task-base-types'; @@ -53,7 +53,7 @@ export interface SagemakerTrainTaskProps { /** * Tags to be applied to the train job. */ - readonly tags?: {[key: string]: any}; + readonly tags?: {[key: string]: string}; /** * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. @@ -88,15 +88,6 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn */ public readonly connections: ec2.Connections = new ec2.Connections(); - /** - * The execution role for the Sagemaker training job. - * - * @default new role for Amazon SageMaker to assume is automatically created. - */ - public readonly role: iam.IRole; - - public readonly grantPrincipal: iam.IPrincipal; - /** * The Algorithm Specification */ @@ -117,9 +108,15 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn */ private readonly stoppingCondition: StoppingCondition; + private readonly vpc: ec2.IVpc; + private securityGroup: ec2.ISecurityGroup; + private readonly securityGroups: ec2.ISecurityGroup[] = []; + private readonly subnets: string[]; private readonly integrationPattern: sfn.ServiceIntegrationPattern; + private _role?: iam.IRole; + private _grantPrincipal?: iam.IPrincipal; - constructor(scope: Construct, private readonly props: SagemakerTrainTaskProps) { + constructor(private readonly props: SagemakerTrainTaskProps) { this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET; const supportedPatterns = [ @@ -143,8 +140,66 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn maxRuntime: Duration.hours(1) }; + // check that either algorithm name or image is defined + if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) { + throw new Error("Must define either an algorithm name or training image URI in the algorithm specification"); + } + + // set the input mode to 'File' if not defined + this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? + ( props.algorithmSpecification ) : + ( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } ); + + // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined + this.inputDataConfig = props.inputDataConfig.map(config => { + if (!config.dataSource.s3DataSource.s3DataType) { + return Object.assign({}, config, { dataSource: { s3DataSource: + { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } }); + } else { + return config; + } + }); + + // add the security groups to the connections object + if (props.vpcConfig) { + this.vpc = props.vpcConfig.vpc; + this.subnets = (props.vpcConfig.subnets) ? + (this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds) : this.vpc.selectSubnets().subnetIds; + } + } + + /** + * The execution role for the Sagemaker training job. + * + * Only available after task has been added to a state machine. + */ + public get role(): iam.IRole { + if (this._role === undefined) { + throw new Error(`role not available yet--use the object in a Task first`); + } + return this._role; + } + + public get grantPrincipal(): iam.IPrincipal { + if (this._grantPrincipal === undefined) { + throw new Error(`Principal not available yet--use the object in a Task first`); + } + return this._grantPrincipal; + } + + /** + * Add the security group to all instances via the launch configuration + * security groups array. + * + * @param securityGroup: The security group to add + */ + public addSecurityGroup(securityGroup: ec2.ISecurityGroup): void { + this.securityGroups.push(securityGroup); + } + + public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { // set the sagemaker role or create new one - this.grantPrincipal = this.role = props.role || new iam.Role(scope, 'SagemakerRole', { + this._grantPrincipal = this._role = this.props.role || new iam.Role(task, 'SagemakerRole', { assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), inlinePolicies: { CreateTrainingJob: new iam.PolicyDocument({ @@ -157,7 +212,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn 'logs:CreateLogGroup', 'logs:DescribeLogStreams', 'ecr:GetAuthorizationToken', - ...props.vpcConfig + ...this.props.vpcConfig ? [ 'ec2:CreateNetworkInterface', 'ec2:CreateNetworkInterfacePermission', @@ -178,36 +233,23 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn } }); - if (props.outputDataConfig.encryptionKey) { - props.outputDataConfig.encryptionKey.grantEncrypt(this.role); + if (this.props.outputDataConfig.encryptionKey) { + this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role); } - if (props.resourceConfig && props.resourceConfig.volumeEncryptionKey) { - props.resourceConfig.volumeEncryptionKey.grant(this.role, 'kms:CreateGrant'); + if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) { + this.props.resourceConfig.volumeEncryptionKey.grant(this._role, 'kms:CreateGrant'); } - // set the input mode to 'File' if not defined - this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? - ( props.algorithmSpecification ) : - ( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } ); - - // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined - this.inputDataConfig = props.inputDataConfig.map(config => { - if (!config.dataSource.s3DataSource.s3DataType) { - return Object.assign({}, config, { dataSource: { s3DataSource: - { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } }); - } else { - return config; - } - }); - - // add the security groups to the connections object - if (this.props.vpcConfig) { - this.props.vpcConfig.securityGroups.forEach(sg => this.connections.addSecurityGroup(sg)); + // create a security group if not defined + if (this.vpc && this.securityGroup === undefined) { + this.securityGroup = new ec2.SecurityGroup(task, 'TrainJobSecurityGroup', { + vpc: this.vpc + }); + this.connections.addSecurityGroup(this.securityGroup); + this.securityGroups.push(this.securityGroup); } - } - public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { return { resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + resourceArnSuffix.get(this.integrationPattern), parameters: this.renderParameters(), @@ -218,7 +260,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn private renderParameters(): {[key: string]: any} { return { TrainingJobName: this.props.trainingJobName, - RoleArn: this.role.roleArn, + RoleArn: this._role!.roleArn, ...(this.renderAlgorithmSpecification(this.algorithmSpecification)), ...(this.renderInputDataConfig(this.inputDataConfig)), ...(this.renderOutputDataConfig(this.props.outputDataConfig)), @@ -303,8 +345,8 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} { return (config) ? { VpcConfig: { - SecurityGroupIds: config.securityGroups.map(sg => ( sg.securityGroupId )), - Subnets: config.subnets.map(subnet => ( subnet.subnetId )), + SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }), + Subnets: this.subnets, }} : {}; } @@ -330,7 +372,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn }), new iam.PolicyStatement({ actions: ['iam:PassRole'], - resources: [this.role.roleArn], + resources: [this._role!.roleArn], conditions: { StringEquals: { "iam:PassedToService": "sagemaker.amazonaws.com" } } diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts index eeabb1db2984c..33a2ccc90ef66 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts @@ -1,7 +1,7 @@ import ec2 = require('@aws-cdk/aws-ec2'); import iam = require('@aws-cdk/aws-iam'); import sfn = require('@aws-cdk/aws-stepfunctions'); -import { Construct, Stack } from '@aws-cdk/core'; +import { Stack } from '@aws-cdk/core'; import { resourceArnSuffix } from './resource-arn-suffix'; import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; @@ -37,7 +37,7 @@ export interface SagemakerTransformProps { /** * Environment variables to set in the Docker container. */ - readonly environment?: {[key: string]: any}; + readonly environment?: {[key: string]: string}; /** * Maximum number of parallel requests that can be sent to each instance in a transform job. @@ -57,7 +57,7 @@ export interface SagemakerTransformProps { /** * Tags to be applied to the train job. */ - readonly tags?: {[key: string]: any}; + readonly tags?: {[key: string]: string}; /** * Dataset to be transformed and the Amazon S3 location where it is stored. @@ -82,13 +82,6 @@ export interface SagemakerTransformProps { */ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { - /** - * The execution role for the Sagemaker training job. - * - * @default new role for Amazon SageMaker to assume is automatically created. - */ - public readonly role: iam.IRole; - /** * Dataset to be transformed and the Amazon S3 location where it is stored. */ @@ -98,10 +91,10 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { * ML compute instances for the transform job. */ private readonly transformResources: TransformResources; - private readonly integrationPattern: sfn.ServiceIntegrationPattern; + private _role?: iam.IRole; - constructor(scope: Construct, private readonly props: SagemakerTransformProps) { + constructor(private readonly props: SagemakerTransformProps) { this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET; const supportedPatterns = [ @@ -114,12 +107,9 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { } // set the sagemaker role or create new one - this.role = props.role || new iam.Role(scope, 'SagemakerRole', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess') - ] - }); + if (props.role) { + this._role = props.role; + } // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) : @@ -140,6 +130,16 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { } public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { + // create new role if doesn't exist + if (this._role === undefined) { + this._role = new iam.Role(task, 'SagemakerTransformRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicies: [ + iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess') + ] + }); + } + return { resourceArn: 'arn:aws:states:::sagemaker:createTransformJob' + resourceArnSuffix.get(this.integrationPattern), parameters: this.renderParameters(), @@ -147,6 +147,18 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { }; } + /** + * The execution role for the Sagemaker training job. + * + * Only available after task has been added to a state machine. + */ + public get role(): iam.IRole { + if (this._role === undefined) { + throw new Error(`role not available yet--use the object in a Task first`); + } + return this._role; + } + private renderParameters(): {[key: string]: any} { return { ...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {}, diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.expected.json b/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.expected.json index a005b7ccd93ef..0050a0140a282 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.expected.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.expected.json @@ -51,7 +51,7 @@ "Principal": { "AWS": { "Fn::GetAtt": [ - "SagemakerRole5FDB64E1", + "TrainTaskSagemakerRole0A9B1CDD", "Arn" ] } @@ -68,7 +68,7 @@ "Principal": { "AWS": { "Fn::GetAtt": [ - "SagemakerRole5FDB64E1", + "TrainTaskSagemakerRole0A9B1CDD", "Arn" ] } @@ -79,8 +79,8 @@ "Version": "2012-10-17" } }, - "DeletionPolicy": "Delete", - "UpdateReplacePolicy": "Delete" + "UpdateReplacePolicy": "Delete", + "DeletionPolicy": "Delete" }, "TrainingData3FDB6D34": { "Type": "AWS::S3::Bucket", @@ -101,10 +101,10 @@ ] } }, - "DeletionPolicy": "Delete", - "UpdateReplacePolicy": "Delete" + "UpdateReplacePolicy": "Delete", + "DeletionPolicy": "Delete" }, - "SagemakerRole5FDB64E1": { + "TrainTaskSagemakerRole0A9B1CDD": { "Type": "AWS::IAM::Role", "Properties": { "AssumeRolePolicyDocument": { @@ -143,7 +143,7 @@ ] } }, - "SagemakerRoleDefaultPolicy9DD21C3C": { + "TrainTaskSagemakerRoleDefaultPolicyA28F72FA": { "Type": "AWS::IAM::Policy", "Properties": { "PolicyDocument": { @@ -238,10 +238,10 @@ ], "Version": "2012-10-17" }, - "PolicyName": "SagemakerRoleDefaultPolicy9DD21C3C", + "PolicyName": "TrainTaskSagemakerRoleDefaultPolicyA28F72FA", "Roles": [ { - "Ref": "SagemakerRole5FDB64E1" + "Ref": "TrainTaskSagemakerRole0A9B1CDD" } ] } @@ -322,7 +322,7 @@ "Effect": "Allow", "Resource": { "Fn::GetAtt": [ - "SagemakerRole5FDB64E1", + "TrainTaskSagemakerRole0A9B1CDD", "Arn" ] } @@ -348,11 +348,11 @@ "{\"StartAt\":\"TrainTask\",\"States\":{\"TrainTask\":{\"End\":true,\"Parameters\":{\"TrainingJobName\":\"MyTrainingJob\",\"RoleArn\":\"", { "Fn::GetAtt": [ - "SagemakerRole5FDB64E1", + "TrainTaskSagemakerRole0A9B1CDD", "Arn" ] }, - "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", + "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\",\"AlgorithmName\":\"GRADIENT_ASCENT\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", { "Ref": "AWS::Region" }, @@ -389,4 +389,4 @@ } } } -} +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.ts index 8a72022d0f959..53e39e9a59b2d 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.ts @@ -18,9 +18,15 @@ const trainingData = new Bucket(stack, 'TrainingData', { new StateMachine(stack, 'StateMachine', { definition: new Task(stack, 'TrainTask', { - task: new SagemakerTrainTask(stack, { - algorithmSpecification: {}, - inputDataConfig: [{ channelName: 'InputData', dataSource: { s3DataSource: { s3Location: S3Location.fromBucket(trainingData, 'data/') } } }], + task: new SagemakerTrainTask({ + algorithmSpecification: { + algorithmName: 'GRADIENT_ASCENT', + }, + inputDataConfig: [{ channelName: 'InputData', dataSource: { + s3DataSource: { + s3Location: S3Location.fromBucket(trainingData, 'data/') + } + } }], outputDataConfig: { s3OutputLocation: S3Location.fromBucket(trainingData, 'result/') }, trainingJobName: 'MyTrainingJob', }) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts index 2c8e66ad20b16..9218b1f8d2c5e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts @@ -13,11 +13,11 @@ let stack: cdk.Stack; beforeEach(() => { // GIVEN stack = new cdk.Stack(); - }); +}); test('create basic training job', () => { // WHEN - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ trainingJobName: "MyTrainJob", algorithmSpecification: { algorithmName: "BlazingText", @@ -70,7 +70,7 @@ test('create basic training job', () => { InstanceType: 'ml.m4.xlarge', VolumeSizeInGB: 10 }, - RoleArn: { "Fn::GetAtt": [ "SagemakerRole5FDB64E1", "Arn" ] }, + RoleArn: { "Fn::GetAtt": [ "TrainSagemakerSagemakerRole89E8C593", "Arn" ] }, StoppingCondition: { MaxRuntimeInSeconds: 3600 }, @@ -81,7 +81,7 @@ test('create basic training job', () => { test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration pattern', () => { expect(() => { - new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { + new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ integrationPattern: sfn.ServiceIntegrationPattern.WAIT_FOR_TASK_TOKEN, trainingJobName: "MyTrainJob", algorithmSpecification: { @@ -118,7 +118,7 @@ test('create complex training job', () => { ], }); - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { + const trainTask = new tasks.SagemakerTrainTask({ trainingJobName: "MyTrainJob", integrationPattern: sfn.ServiceIntegrationPattern.SYNC, role, @@ -178,10 +178,10 @@ test('create complex training job', () => { }, vpcConfig: { vpc, - subnets: vpc.privateSubnets, - securityGroups: [ securityGroup ] } - })}); + }); + trainTask.addSecurityGroup(securityGroup); + const task = new sfn.Task(stack, 'TrainSagemaker', { task: trainTask }); // THEN expect(stack.resolve(task.toStateJson())).toEqual({ @@ -250,7 +250,10 @@ test('create complex training job', () => { { Key: "Project", Value: "MyProject" } ], VpcConfig: { - SecurityGroupIds: [ { "Fn::GetAtt": [ "SecurityGroupDD263621", "GroupId" ] } ], + SecurityGroupIds: [ + { "Fn::GetAtt": [ "SecurityGroupDD263621", "GroupId" ] }, + { "Fn::GetAtt": [ "TrainSagemakerTrainJobSecurityGroup7C858EB9", "GroupId" ] }, + ], Subnets: [ { Ref: "VPCPrivateSubnet1Subnet8BCA10E0" }, { Ref: "VPCPrivateSubnet2SubnetCFCDAA7A" }, @@ -269,7 +272,7 @@ test('pass param to training job', () => { ], }); - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ trainingJobName: sfn.Data.stringAt('$.JobName'), role, algorithmSpecification: { @@ -339,3 +342,26 @@ test('pass param to training job', () => { }, }); }); + +test('Cannot create a SageMaker train task with both algorithm name and image name missing', () => { + + expect(() => new tasks.SagemakerTrainTask({ + trainingJobName: 'myTrainJob', + algorithmSpecification: {}, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket') + } + } + } + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/') + }, + })) + .toThrowError(/Must define either an algorithm name or training image URI in the algorithm specification/); +}); diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts index b89f0c90b8152..e92fa9bb99ac1 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts @@ -19,11 +19,11 @@ beforeEach(() => { iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess') ], }); - }); +}); test('create basic transform job', () => { // WHEN - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ transformJobName: "MyTransformJob", modelName: "MyModelName", transformInput: { @@ -67,7 +67,7 @@ test('create basic transform job', () => { test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration pattern', () => { expect(() => { - new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { + new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ integrationPattern: sfn.ServiceIntegrationPattern.WAIT_FOR_TASK_TOKEN, transformJobName: "MyTransformJob", modelName: "MyModelName", @@ -88,7 +88,7 @@ test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration patt test('create complex transform job', () => { // WHEN const kmsKey = new kms.Key(stack, 'Key'); - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ transformJobName: "MyTransformJob", modelName: "MyModelName", integrationPattern: sfn.ServiceIntegrationPattern.SYNC, @@ -161,7 +161,7 @@ test('create complex transform job', () => { test('pass param to transform job', () => { // WHEN - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ transformJobName: sfn.Data.stringAt('$.TransformJobName'), modelName: sfn.Data.stringAt('$.ModelName'), role,