Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(aws-stepfunctions): refactor sagemaker tasks and fix default role issue #3014

Merged
merged 17 commits into from
Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
76ab6c3
fix(aws-stepfunctions) refactor and fix default role issue
mmcclean-aws Jun 23, 2019
1d8c043
Merge branch 'master' into sfn-sagemaker-fixes
mmcclean-aws Jun 23, 2019
c1e9140
fix(aws-stepfunctions) removed console log statements and fixed s3 pr…
mmcclean-aws Jun 23, 2019
d34847a
fix(aws-stepfunctions) removed construct from contructor for sagemake…
mmcclean-aws Jun 24, 2019
9286f27
Merge branch 'master' into sfn-sagemaker-fixes
mmcclean-aws Jun 24, 2019
1bce980
fix(aws-stepfunctions) renamed cdk core package reference
mmcclean-aws Jun 24, 2019
73ead6e
Merge branch 'master' into sfn-sagemaker-fixes
mmcclean-aws Jun 27, 2019
6c639b4
Merge branch 'master' into sfn-sagemaker-fixes
mmcclean-aws Jun 27, 2019
d60d604
Merge branch 'master' into sfn-sagemaker-fixes
mmcclean-aws Jul 5, 2019
159ba7a
Merge remote-tracking branch 'origin/master' into sfn-sagemaker-fixes
rix0rrr Aug 5, 2019
2a6f13c
Merge branch 'master' into sfn-sagemaker-fixes
rix0rrr Aug 8, 2019
017cff1
Merge remote-tracking branch 'origin/master' into sfn-sagemaker-fixes
rix0rrr Aug 19, 2019
167e81a
Merge branch 'sfn-sagemaker-fixes' of github.com:mattmcclean/aws-cdk …
rix0rrr Aug 19, 2019
39c13da
Merge remote-tracking branch 'origin/master' into sfn-sagemaker-fixes
rix0rrr Aug 19, 2019
b110216
Merge remote-tracking branch 'origin/master' into sfn-sagemaker-fixes
rix0rrr Aug 21, 2019
b00316f
Update tests
rix0rrr Aug 21, 2019
01392a7
Merge branch 'master' into sfn-sagemaker-fixes
mergify[bot] Aug 21, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,20 +184,15 @@ export interface ResourceConfig {
* @experimental
*/
export interface VpcConfig {
/**
* VPC security groups.
*/
readonly securityGroups: ec2.ISecurityGroup[];

/**
* VPC id
*/
readonly vpc: ec2.Vpc;
readonly vpc: ec2.IVpc;

/**
* VPC subnets.
*/
readonly subnets: ec2.ISubnet[];
readonly subnets?: ec2.ISubnet[];
mattmcclean marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ec2 = require('@aws-cdk/aws-ec2');
import iam = require('@aws-cdk/aws-iam');
import sfn = require('@aws-cdk/aws-stepfunctions');
import { Construct, Duration, Stack } from '@aws-cdk/cdk';
import { Construct, Duration, Lazy, Stack } from '@aws-cdk/cdk';
import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig,
S3DataType, StoppingCondition, VpcConfig, } from './sagemaker-task-base-types';

Expand Down Expand Up @@ -45,7 +45,7 @@ export interface SagemakerTrainProps {
/**
* Tags to be applied to the train job.
*/
readonly tags?: {[key: string]: any};
readonly tags?: {[key: string]: string};

/**
* Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training.
Expand Down Expand Up @@ -76,7 +76,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
/**
* Allows specify security group connections for instances of this fleet.
*/
public readonly connections: ec2.Connections = new ec2.Connections();
public readonly connections: ec2.Connections;

/**
* The execution role for the Sagemaker training job.
Expand Down Expand Up @@ -105,6 +105,11 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
*/
private readonly stoppingCondition: StoppingCondition;

private readonly vpc: ec2.IVpc;
private readonly securityGroup: ec2.ISecurityGroup;
private readonly securityGroups: ec2.ISecurityGroup[] = [];
private readonly subnets: string[];

constructor(scope: Construct, private readonly props: SagemakerTrainProps) {
mattmcclean marked this conversation as resolved.
Show resolved Hide resolved

// set the default resource config if not defined.
Expand All @@ -120,13 +125,18 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
};

// set the sagemaker role or create new one
this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
this.role = props.role || new iam.Role(scope, 'SagemakerTrainRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
managedPolicies: [
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
mattmcclean marked this conversation as resolved.
Show resolved Hide resolved
]
});

// check that either algorithm name or image is defined
if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) {
throw new Error("Must define either an algorithm name or training image URI in the algorithm specification");
}

// set the input mode to 'File' if not defined
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
( props.algorithmSpecification ) :
Expand All @@ -143,11 +153,27 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
});

// add the security groups to the connections object
if (this.props.vpcConfig) {
this.props.vpcConfig.securityGroups.forEach(sg => this.connections.addSecurityGroup(sg));
if (props.vpcConfig) {
mattmcclean marked this conversation as resolved.
Show resolved Hide resolved
this.vpc = props.vpcConfig.vpc;
this.securityGroup = new ec2.SecurityGroup(scope, 'TrainJobSecurityGroup', {
vpc: this.vpc
});
this.connections = new ec2.Connections({ securityGroups: [this.securityGroup] });
this.securityGroups.push(this.securityGroup);
this.subnets = (props.vpcConfig.subnets) ? (props.vpcConfig.subnets.map(s => (s.subnetId))) : this.vpc.selectSubnets().subnetIds;
}
}

/**
* Add the security group to all instances via the launch configuration
* security groups array.
*
* @param securityGroup: The security group to add
*/
public addSecurityGroup(securityGroup: ec2.ISecurityGroup): void {
this.securityGroups.push(securityGroup);
mattmcclean marked this conversation as resolved.
Show resolved Hide resolved
}

public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
return {
resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + (this.props.synchronous ? '.sync' : ''),
Expand Down Expand Up @@ -244,8 +270,8 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT

private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} {
return (config) ? { VpcConfig: {
SecurityGroupIds: config.securityGroups.map(sg => ( sg.securityGroupId )),
Subnets: config.subnets.map(subnet => ( subnet.subnetId )),
SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }),
Subnets: this.subnets,
}} : {};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export interface SagemakerTransformProps {
/**
* Environment variables to set in the Docker container.
*/
readonly environment?: {[key: string]: any};
readonly environment?: {[key: string]: string};

/**
* Maximum number of parallel requests that can be sent to each instance in a transform job.
Expand All @@ -52,7 +52,7 @@ export interface SagemakerTransformProps {
/**
* Tags to be applied to the train job.
*/
readonly tags?: {[key: string]: any};
readonly tags?: {[key: string]: string};

/**
* Dataset to be transformed and the Amazon S3 location where it is stored.
Expand Down Expand Up @@ -97,7 +97,7 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
constructor(scope: Construct, private readonly props: SagemakerTransformProps) {

// set the sagemaker role or create new one
this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
this.role = props.role || new iam.Role(scope, 'SagemakerTransformRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
managedPolicies: [
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ let stack: cdk.Stack;
beforeEach(() => {
// GIVEN
stack = new cdk.Stack();
});
});

test('create basic training job', () => {
// WHEN
Expand Down Expand Up @@ -64,7 +64,7 @@ test('create basic training job', () => {
InstanceType: 'ml.m4.xlarge',
VolumeSizeInGB: 10
},
RoleArn: { "Fn::GetAtt": [ "SagemakerRole5FDB64E1", "Arn" ] },
RoleArn: { "Fn::GetAtt": [ "SagemakerTrainRoleCBF0A724", "Arn" ] },
StoppingCondition: {
MaxRuntimeInSeconds: 3600
},
Expand All @@ -87,7 +87,7 @@ test('create complex training job', () => {
],
});

const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, {
const trainTask = new tasks.SagemakerTrainTask(stack, {
trainingJobName: "MyTrainJob",
synchronous: true,
role,
Expand Down Expand Up @@ -148,9 +148,10 @@ test('create complex training job', () => {
vpcConfig: {
vpc,
subnets: vpc.privateSubnets,
securityGroups: [ securityGroup ]
}
})});
});
trainTask.addSecurityGroup(securityGroup);
const task = new sfn.Task(stack, 'TrainSagemaker', { task: trainTask });

// THEN
expect(stack.resolve(task.toStateJson())).toEqual({
Expand Down Expand Up @@ -213,7 +214,10 @@ test('create complex training job', () => {
{ Key: "Project", Value: "MyProject" }
],
VpcConfig: {
SecurityGroupIds: [ { "Fn::GetAtt": [ "SecurityGroupDD263621", "GroupId" ] } ],
SecurityGroupIds: [
{ "Fn::GetAtt": [ "TrainJobSecurityGroupBECEDCDC", "GroupId" ] },
{ "Fn::GetAtt": [ "SecurityGroupDD263621", "GroupId" ] },
],
Subnets: [
{ Ref: "VPCPrivateSubnet1Subnet8BCA10E0" },
{ Ref: "VPCPrivateSubnet2SubnetCFCDAA7A" },
Expand Down Expand Up @@ -300,3 +304,26 @@ test('pass param to training job', () => {
},
});
});

test('Cannot create a SageMaker train task with both algorithm name and image name missing', () => {

expect(() => new tasks.SagemakerTrainTask(stack, {
trainingJobName: 'myTrainJob',
algorithmSpecification: {},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Uri: sfn.Data.stringAt('$.S3Bucket')
}
}
}
],
outputDataConfig: {
s3OutputPath: 's3://mybucket/myoutputpath'
},
}))
.toThrowError(/Must define either an algorithm name or training image URI in the algorithm specification/);
});
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ beforeEach(() => {
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
],
});
});
});

test('create basic transform job', () => {
// WHEN
Expand Down