Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add evaluteOnExit for aws batch retry #40

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/src/main/scala/cromwell/backend/backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ object CommonBackendConfigurationAttributes {
"default-runtime-attributes.docker",
"default-runtime-attributes.queueArn",
"default-runtime-attributes.awsBatchRetryAttempts",
"default-runtime-attributes.awsBatchEvaluateOnExit",
"default-runtime-attributes.ulimits",
"default-runtime-attributes.efsDelocalize",
"default-runtime-attributes.efsMakeMD5",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ object AwsBatchAttributes {
"numSubmitAttempts",
"default-runtime-attributes.scriptBucketName",
"awsBatchRetryAttempts",
"awsBatchEvaluateOnExit",
"ulimits",
"efsDelocalize",
"efsMakeMD5",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ package cromwell.backend.impl.aws
import scala.collection.mutable.ListBuffer
import cromwell.backend.BackendJobDescriptor
import cromwell.backend.io.JobPaths
import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, RetryStrategy, Ulimit, Volume}
import software.amazon.awssdk.services.batch.model.{ContainerProperties, EvaluateOnExit, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, RetryAction, RetryStrategy, Ulimit, Volume}
import cromwell.backend.impl.aws.io.AwsBatchVolume

import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -183,9 +183,30 @@ trait AwsBatchJobDefinitionBuilder {

/**
  * Builds the AWS Batch retry strategy for a job definition.
  *
  * The strategy carries the configured number of attempts plus one `EvaluateOnExit`
  * rule per non-empty map in the `awsBatchEvaluateOnExit` runtime attribute. Keys and
  * the `action` value are matched case-insensitively; unknown keys are ignored here.
  *
  * @param context job definition context providing the runtime attributes
  * @return the retry strategy builder, together with a string made of the retry
  *         attempts followed by the raw `awsBatchEvaluateOnExit` value
  */
def retryStrategyBuilder(context: AwsBatchJobDefinitionContext): (RetryStrategy.Builder, String) = {
  val attempts = context.runtimeAttributes.awsBatchRetryAttempts
  val rules = context.runtimeAttributes.awsBatchEvaluateOnExit

  // An empty map would produce an EvaluateOnExit without an `action`, which the AWS Batch
  // API documents as a required field, so empty rules are skipped.
  val evaluations = rules.filter(_.nonEmpty).map { rule =>
    rule.foldLeft(EvaluateOnExit.builder()) { case (acc, (k, v)) =>
      (k.toLowerCase, v.toLowerCase) match {
        case ("action", "retry") => acc.action(RetryAction.RETRY)
        case ("action", "exit") => acc.action(RetryAction.EXIT)
        // Patterns are passed through with their original casing.
        case ("onexitcode", _) => acc.onExitCode(v)
        case ("onreason", _) => acc.onReason(v)
        case ("onstatusreason", _) => acc.onStatusReason(v)
        case _ => acc
      }
    }.build()
  }

  val builder = RetryStrategy.builder()
    .attempts(attempts)
    .evaluateOnExit(evaluations.asJava)

  // The string is built from the unfiltered attribute value so it stays identical to
  // what was produced before for the same runtime attributes.
  (builder, s"${attempts.toString}${rules.toString}")
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@ import wom.RuntimeAttributesKeys
import wom.format.MemorySize
import wom.types._
import wom.values._
import com.typesafe.config.{ConfigException,ConfigValueFactory}
import com.typesafe.config.{ConfigException, ConfigValueFactory}

import scala.util.matching.Regex
import org.slf4j.{Logger, LoggerFactory}

import scala.util.{Failure, Success, Try}
import scala.jdk.CollectionConverters._


/**
* Attributes that are provided to the job at runtime
* @param cpu number of vCPU
Expand All @@ -63,6 +67,7 @@ import org.slf4j.{Logger, LoggerFactory}
* @param scriptS3BucketName the s3 bucket where the execution command or script will be written and, from there, fetched into the container and executed
* @param fileSystem the filesystem type, default is "s3"
* @param awsBatchRetryAttempts number of attempts that AWS Batch will retry the task if it fails
* @param awsBatchEvaluateOnExit Evaluate on exit strategy setting for AWS batch retry
* @param ulimits ulimit values to be passed to the container
* @param efsDelocalize should we delocalize efs files to s3
* @param efsMakeMD5 should we make a sibling md5 file as part of the job
Expand All @@ -78,6 +83,7 @@ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
noAddress: Boolean,
scriptS3BucketName: String,
awsBatchRetryAttempts: Int,
awsBatchEvaluateOnExit: Vector[Map[String, String]],
ulimits: Vector[Map[String, String]],
efsDelocalize: Boolean,
efsMakeMD5 : Boolean,
Expand All @@ -91,6 +97,10 @@ object AwsBatchRuntimeAttributes {

val awsBatchRetryAttemptsKey = "awsBatchRetryAttempts"

// Name of the runtime attribute that holds the AWS Batch `evaluateOnExit` retry rules.
val awsBatchEvaluateOnExitKey = "awsBatchEvaluateOnExit"
// Built-in default: an array containing a single, empty String -> String map.
private val awsBatchEvaluateOnExitDefault = WomArray(WomArrayType(WomMapType(WomStringType,WomStringType)), Vector(WomMap(Map.empty[WomValue, WomValue])))


val awsBatchefsDelocalizeKey = "efsDelocalize"
val awsBatchefsMakeMD5Key = "efsMakeMD5"

Expand Down Expand Up @@ -157,6 +167,11 @@ object AwsBatchRuntimeAttributes {
.configDefaultWomValue(runtimeConfig).getOrElse(WomInteger(0)))
}

/**
  * Validation for the `awsBatchEvaluateOnExit` runtime attribute.
  *
  * The default comes from the backend runtime configuration when it yields a value,
  * otherwise from the built-in `awsBatchEvaluateOnExitDefault`.
  */
def awsBatchEvaluateOnExitValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Vector[Map[String, String]]] = {
  val configuredDefault = AwsBatchEvaluateOnExitValidation.fromConfig(runtimeConfig)
  val effectiveDefault = configuredDefault.getOrElse(awsBatchEvaluateOnExitDefault)
  AwsBatchEvaluateOnExitValidation.withDefault(effectiveDefault)
}

private def awsBatchefsDelocalizeValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Boolean] = {
AwsBatchefsDelocalizeValidation(awsBatchefsDelocalizeKey).withDefault(AwsBatchefsDelocalizeValidation(awsBatchefsDelocalizeKey)
.configDefaultWomValue(runtimeConfig).getOrElse(WomBoolean(false)))
Expand Down Expand Up @@ -199,7 +214,6 @@ object AwsBatchRuntimeAttributes {

def runtimeAttributesBuilder(configuration: AwsBatchConfiguration): StandardValidatedRuntimeAttributesBuilder = {
val runtimeConfig = aggregateDisksInRuntimeConfig(configuration)

def validationsS3backend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
cpuValidation(runtimeConfig),
cpuMinValidation(runtimeConfig),
Expand All @@ -212,6 +226,7 @@ object AwsBatchRuntimeAttributes {
queueArnValidation(runtimeConfig),
scriptS3BucketNameValidation(runtimeConfig),
awsBatchRetryAttemptsValidation(runtimeConfig),
awsBatchEvaluateOnExitValidation(runtimeConfig),
ulimitsValidation(runtimeConfig),
awsBatchefsDelocalizeValidation(runtimeConfig),
awsBatchefsMakeMD5Validation(runtimeConfig)
Expand All @@ -227,6 +242,7 @@ object AwsBatchRuntimeAttributes {
dockerValidation,
queueArnValidation(runtimeConfig),
awsBatchRetryAttemptsValidation(runtimeConfig),
awsBatchEvaluateOnExitValidation(runtimeConfig),
ulimitsValidation(runtimeConfig),
awsBatchefsDelocalizeValidation(runtimeConfig),
awsBatchefsMakeMD5Validation(runtimeConfig)
Expand Down Expand Up @@ -254,6 +270,8 @@ object AwsBatchRuntimeAttributes {
case _ => ""
}
val awsBatchRetryAttempts: Int = RuntimeAttributesValidation.extract(awsBatchRetryAttemptsValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val awsBatchEvaluateOnExit: Vector[Map[String, String]] = RuntimeAttributesValidation.extract(awsBatchEvaluateOnExitValidation(runtimeAttrsConfig), validatedRuntimeAttributes)

val ulimits: Vector[Map[String, String]] = RuntimeAttributesValidation.extract(ulimitsValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val efsDelocalize: Boolean = RuntimeAttributesValidation.extract(awsBatchefsDelocalizeValidation(runtimeAttrsConfig),validatedRuntimeAttributes)
val efsMakeMD5: Boolean = RuntimeAttributesValidation.extract(awsBatchefsMakeMD5Validation(runtimeAttrsConfig),validatedRuntimeAttributes)
Expand All @@ -270,6 +288,7 @@ object AwsBatchRuntimeAttributes {
noAddress,
scriptS3BucketName,
awsBatchRetryAttempts,
awsBatchEvaluateOnExit,
ulimits,
efsDelocalize,
efsMakeMD5,
Expand Down Expand Up @@ -473,6 +492,115 @@ class AwsBatchRetryAttemptsValidation(key: String) extends IntRuntimeAttributesV
override protected def missingValueMessage: String = s"Expecting $key runtime attribute to be an Integer"
}

/**
  * Validation for the `awsBatchEvaluateOnExit` runtime attribute.
  *
  * The attribute is an array of String -> String maps. Each non-empty map may only use
  * the keys `action`, `onExitCode`, `onReason` and `onStatusReason` (compared
  * case-insensitively) and must contain an `action` whose value is `retry` or `exit`.
  * Empty maps are accepted and validate to an empty map.
  */
object AwsBatchEvaluateOnExitValidation extends RuntimeAttributesValidation[Vector[Map[String, String]]] {

  val requiredKey = "action"
  private val acceptedKeys = Set(requiredKey, "onExitCode", "onReason", "onStatusReason")
  // Lower-cased once so key comparison does not recompute it for every rule.
  private val acceptedKeysLowerCase = acceptedKeys.map(_.toLowerCase)
  private val validActions = Set("retry", "exit")
  private val retryStrategyDocUrl = "https://docs.aws.amazon.com/batch/latest/APIReference/API_RetryStrategy.html"

  /**
    * Reads the default value of this attribute from the runtime configuration.
    *
    * @param runtimeConfig optional configuration expected to hold an object list under `key`
    * @return `None` when there is no configuration or the list cannot be read under `key`;
    *         otherwise the first successful coercion of the raw list, or a
    *         `BadDefaultAttribute` when no accepted type can coerce it
    */
  def fromConfig(runtimeConfig: Option[Config]): Option[WomValue] = {
    // A failure to read the list (for example an absent key) deliberately means
    // "no configured default" rather than an error.
    val rawRules = runtimeConfig.flatMap { config =>
      Try(config.getObjectList(key)).toOption.map { objects =>
        objects.asScala.map(_.unwrapped().asScala.toMap).toList
      }
    }

    rawRules.map { rules =>
      // Lazily try each accepted type, coercing once per type, and keep the first success.
      coercion.iterator
        .map(_.coerceRawValue(rules))
        .collectFirst { case Success(womValue) => womValue }
        .getOrElse(BadDefaultAttribute(WomString(rules.toString)))
    }
  }

  override def coercion: Iterable[WomType] = {
    Set(WomStringType, WomArrayType(WomMapType(WomStringType, WomStringType)))
  }

  override protected def validateValue: PartialFunction[WomValue, ErrorOr[Vector[Map[String, String]]]] = {
    case WomArray(womType, value)
      if womType.memberType == WomMapType(WomStringType, WomStringType) =>
      validateRules(value.toVector)
    case WomMap(_, _) =>
      s"Expecting $key runtime attribute to be an Array of Map[String, String], but a single Map was provided".invalidNel
  }

  /** Validates every element of the array and accumulates all errors. */
  private def validateRules(rules: Vector[WomValue]): ErrorOr[Vector[Map[String, String]]] = {
    val validatedRules: Vector[ErrorOr[Map[String, String]]] = rules.map {
      case WomMap(_, entries) => validateKeys(entries)
      case _ => s"Expecting every entry of $key runtime attribute to be a Map[String, String]".invalidNel
    }
    sequenceNels(validatedRules)
  }

  /** Checks that `action` is present (case-insensitive key) with a value of `retry` or `exit`. */
  private def validateActionKey(dict: Map[WomValue, WomValue]): ErrorOr[Map[String, String]] = {
    // `valueString` is total, unlike matching on `WomString`, so an unexpected value
    // type can no longer raise a MatchError.
    val convertedMap = dict.map { case (k, v) => k.valueString -> v.valueString }
    val hasValidAction = convertedMap.exists { case (entryKey, entryValue) =>
      entryKey.toLowerCase == requiredKey && validActions.contains(entryValue.toLowerCase)
    }
    if (hasValidAction) {
      convertedMap.validNel
    } else {
      s"Missing or invalid $requiredKey key/value for runtime attribute: $key. Refer to $retryStrategyDocUrl".invalidNel
    }
  }

  /** Accepts an empty map as-is, rejects unknown keys, then validates the `action` entry. */
  private def validateKeys(dict: Map[WomValue, WomValue]): ErrorOr[Map[String, String]] = {
    val unrecognizedKeys = dict.keySet.map(_.valueString.toLowerCase).diff(acceptedKeysLowerCase)
    if (dict.isEmpty) {
      Map.empty[String, String].validNel
    } else if (unrecognizedKeys.nonEmpty) {
      s"Invalid keys in $key runtime attribute: $unrecognizedKeys. Only $acceptedKeys are accepted. Refer to $retryStrategyDocUrl".invalidNel
    } else {
      validateActionKey(dict)
    }
  }

  /** Turns a vector of validations into a validation of a vector, preserving order. */
  private def sequenceNels(
    nels: Vector[ErrorOr[Map[String, String]]]
  ): ErrorOr[Vector[Map[String, String]]] = {
    val emptyNel: ErrorOr[Vector[Map[String, String]]] =
      Vector.empty[Map[String, String]].validNel
    nels.foldLeft(emptyNel) { (acc, next) =>
      (acc, next).mapN(_ :+ _)
    }
  }

  override protected def missingValueMessage: String = s"Expecting $key runtime attribute to be defined"

  /**
    * Returns the key of the runtime attribute.
    *
    * @return The key of the runtime attribute.
    */
  override def key: String = AwsBatchRuntimeAttributes.awsBatchEvaluateOnExitKey
}

object AwsBatchefsDelocalizeValidation {
def apply(key: String): AwsBatchefsDelocalizeValidation = new AwsBatchefsDelocalizeValidation(key)
}
Expand Down Expand Up @@ -531,7 +659,7 @@ object UlimitsValidation
accepted_keys.diff(map_keys) union map_keys.diff(accepted_keys)

if (!dict.nonEmpty){
Map.empty[String, String].validNel
Map.empty[String, String].validNel
}else if (unrecognizedKeys.nonEmpty) {
s"Invalid keys in $key runtime attribute. Refer to 'ulimits' section on https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html#containerProperties".invalidNel
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,37 @@ runtime {
}
```

### `awsBatchEvaluateOnExit`

*Default: _[]_* - will always retry

This runtime attribute sets the `evaluateOnExit` conditions for [*AWS Batch Automated Job Retries*](https://docs.aws.amazon.com/batch/latest/userguide/job_retries.html) and specifies under which conditions a failed job is retried.

This configuration works with `awsBatchRetryAttempts` and is useful if you only want to retry on certain failures.

For instance, to retry only when a job fails because its Spot instance was terminated:

```
runtime {
awsBatchEvaluateOnExit: [
{
Action: "RETRY",
onStatusReason: "Host EC2*"
},
{
onReason : "*",
Action: "EXIT"
}
]
}
```

For more information on the batch retry strategy, please refer to:

* General Doc: [userguide/job_retries.html](https://docs.aws.amazon.com/batch/latest/userguide/job_retries.html)
* Blog: [Introducing retry strategies](https://aws.amazon.com/blogs/compute/introducing-retry-strategies-for-aws-batch/)


### `ulimits`

*Default: _empty_*
Expand Down
Loading
Loading