Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/workflowrun state handling #521

Merged
merged 4 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

WORKFLOW_NAME = "TestWorkflow"

STATUS_DRAFT = "DRAFT"
STATUS_START = "READY"
STATUS_RUNNING = "RUNNING"
STATUS_END = "SUCCEEDED"
Expand Down Expand Up @@ -59,7 +60,7 @@ def create_primary(generic_payload, libraries):
portal_run_id="1234",
workflow=wf
)
for state in [STATUS_START, STATUS_RUNNING, STATUS_FAIL]:
for state in [STATUS_DRAFT, STATUS_START, STATUS_RUNNING, STATUS_FAIL]:
StateFactory(workflow_run=wfr_1, status=state, payload=generic_payload)
for i in [0, 1, 2, 3]:
LibraryAssociation.objects.create(
Expand All @@ -75,7 +76,7 @@ def create_primary(generic_payload, libraries):
portal_run_id="1235",
workflow=wf
)
for state in [STATUS_START, STATUS_RUNNING, STATUS_END]:
for state in [STATUS_DRAFT, STATUS_START, STATUS_RUNNING, STATUS_END]:
StateFactory(workflow_run=wfr_2, status=state, payload=generic_payload)
for i in [0, 1, 2, 3]:
LibraryAssociation.objects.create(
Expand All @@ -102,7 +103,7 @@ def create_secondary(generic_payload, libraries):
portal_run_id="2345",
workflow=wf_qc
)
for state in [STATUS_START, STATUS_RUNNING, STATUS_END]:
for state in [STATUS_DRAFT, STATUS_START, STATUS_RUNNING, STATUS_END]:
StateFactory(workflow_run=wfr_qc_1, status=state, payload=generic_payload)
LibraryAssociation.objects.create(
workflow_run=wfr_qc_1,
Expand All @@ -117,7 +118,7 @@ def create_secondary(generic_payload, libraries):
portal_run_id="2346",
workflow=wf_qc
)
for state in [STATUS_START, STATUS_RUNNING, STATUS_END]:
for state in [STATUS_DRAFT, STATUS_START, STATUS_RUNNING, STATUS_END]:
StateFactory(workflow_run=wfr_qc_2, status=state, payload=generic_payload)
LibraryAssociation.objects.create(
workflow_run=wfr_qc_2,
Expand All @@ -133,7 +134,7 @@ def create_secondary(generic_payload, libraries):
portal_run_id="3456",
workflow=wf_align
)
for state in [STATUS_START, STATUS_RUNNING, STATUS_END]:
for state in [STATUS_DRAFT, STATUS_START, STATUS_RUNNING, STATUS_END]:
StateFactory(workflow_run=wfr_a, status=state, payload=generic_payload)
for i in [0, 1]:
LibraryAssociation.objects.create(
Expand All @@ -150,7 +151,7 @@ def create_secondary(generic_payload, libraries):
portal_run_id="4567",
workflow=wf_vc
)
for state in [STATUS_START, STATUS_RUNNING, STATUS_END]:
for state in [STATUS_DRAFT, STATUS_START, STATUS_RUNNING, STATUS_END]:
StateFactory(workflow_run=wfr_vc, status=state, payload=generic_payload)
for i in [0, 1]:
LibraryAssociation.objects.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
from .workflow_run import WorkflowRun, LibraryAssociation
from .library import Library
from .state import State
from .state import Status
from .utils import WorkflowRunUtil
Original file line number Diff line number Diff line change
@@ -1,10 +1,86 @@
from django.db import models

from enum import Enum
from typing import List
from workflow_manager.models.base import OrcaBusBaseModel, OrcaBusBaseManager
from workflow_manager.models.workflow_run import WorkflowRun
from workflow_manager.models.payload import Payload


class Status(Enum):
DRAFT = "DRAFT", ['DRAFT', 'INITIAL', 'CREATED']
READY = "READY", ['READY']
RUNNING = "RUNNING", ['RUNNING', 'IN_PROGRESS']
SUCCEEDED = "SUCCEEDED", ['SUCCEEDED', 'SUCCESS']
FAILED = "FAILED", ['FAILED', 'FAILURE']
ABORTED = "ABORTED", ['ABORTED', 'CANCELLED', 'CANCELED']

def __init__(self, convention: str, aliases: List[str]):
self.convention = convention
self.aliases = aliases

def __str__(self):
return self.convention

@staticmethod
def get_convention(status: str):
# enforce upper case convention
status = status.upper()
status = status.replace("-", "_")
# TODO: handle other characters?
for s in Status:
if status in s.aliases:
return s.convention

# retain all uncontrolled states
return status

@staticmethod
def is_supported(status: str) -> bool:
# enforce upper case convention
status = status.upper()
for s in Status:
if status in s.aliases:
return True
return False

@staticmethod
def is_terminal(status: str) -> bool:
# enforce upper case convention
status = status.upper()
for s in [Status.SUCCEEDED, Status.FAILED, Status.ABORTED]:
if status in s.aliases:
return True
return False

@staticmethod
def is_draft(status: str) -> bool:
# enforce upper case convention
status = status.upper()
if status in Status.DRAFT.aliases:
return True
else:
return False

@staticmethod
def is_running(status: str) -> bool:
# enforce upper case convention
status = status.upper()
if status in Status.RUNNING.aliases:
return True
else:
return False

@staticmethod
def is_ready(status: str) -> bool:
# enforce upper case convention
status = status.upper()
if status in Status.READY.aliases:
return True
else:
return False


class StateManager(OrcaBusBaseManager):
pass

Expand All @@ -17,7 +93,7 @@ class Meta:

# --- mandatory fields
workflow_run = models.ForeignKey(WorkflowRun, on_delete=models.CASCADE)
status = models.CharField(max_length=255)
status = models.CharField(max_length=255) # TODO: How and where to enforce conventions?
timestamp = models.DateTimeField()

comment = models.CharField(max_length=255, null=True, blank=True)
Expand All @@ -40,3 +116,14 @@ def to_dict(self):
"payload": self.payload.to_dict() if (self.payload is not None) else None,
}

def is_terminal(self) -> bool:
return Status.is_terminal(str(self.status))

def is_draft(self) -> bool:
return Status.is_draft(str(self.status))

def is_ready(self) -> bool:
return Status.is_ready(str(self.status))

def is_running(self) -> bool:
return Status.is_running(str(self.status))
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import logging
from datetime import timedelta
from typing import List

from workflow_manager.models import Status, State, WorkflowRun

logger = logging.getLogger()
logger.setLevel(logging.INFO)

RUNNING_MIN_TIME_DELTA_SEC = timedelta(hours=1).total_seconds()
TIMEDELTA_1H = timedelta(hours=1)


class WorkflowRunUtil:
"""
Utility methods for a WorkflowRun.
# TODO: this could be integrated into the WorflowRun model class? (figure out performance / implications)
"""

def __init__(self, workflow_run: WorkflowRun):
self.workflow_run = workflow_run
self.states: List[State] = list(self.workflow_run.get_all_states())

def get_current_state(self):
if len(self.states) < 1:
return None
elif len(self.states) == 1:
return self.states[0]
else:
return WorkflowRunUtil.get_latest_state(self.states)

def is_complete(self):
return self.get_current_state().is_terminal()

def is_draft(self):
# There may be multiple DRAFT states. We assume they are in order, e.g. no other state inbetween
return self.get_current_state().is_draft()

def is_ready(self):
return self.get_current_state().is_ready()

def is_running(self):
return self.get_current_state().is_running()

def contains_status(self, status: str):
# NOTE: we assume status is following conventions
for s in self.states:
if status == s.status:
return True
return False

def transition_to(self, new_state: State) -> bool:
"""
Parameter:
new_state: the new state to transition to
Process:
Transition to the new state if possible and update the WorkflowRun.
Return:
False: if the transition is not possible
True: if the state was updated
# TODO: consider race conditions?
"""
# enforce status conventions on new state
new_state.status = Status.get_convention(new_state.status) # TODO: encapsulate into State ?!

# If it's a brand new WorkflowRun we expect the first state to be DRAFT
# TODO: handle exceptions;
# BCL Convert may not create a DRAFT state
if not self.get_current_state():
if new_state.is_draft():
self.persist_state(new_state)
return True
else:
logger.warning(f"WorkflowRun does not have state yet, but new state is not DRAFT: {new_state}")
self.persist_state(new_state) # FIXME: remove once convention is enforced
return True

# Ignore any state that's older than the current one
if new_state.timestamp < self.get_current_state().timestamp:
return False

# Don't allow any changes once in terminal state
if self.is_complete():
logger.info(f"WorkflowRun in terminal state, can't transition to: {new_state.status}")
return False

# Allowed transitions from DRAFT state
if self.is_draft():
if new_state.is_draft(): # allow "updates" of the DRAFT state
self.persist_state(new_state)
return True
elif new_state.is_ready(): # allow transition from DRAFT to READY state
self.persist_state(new_state)
return True
else:
return False # Don't allow any other transitions from DRAFT state

# Allowed transitions from READY state
if self.is_ready():
if new_state.is_draft(): # no going back
return False
if new_state.is_ready(): # no updates to READY state
return False
# Transitions to other states is allowed (may not be controlled states though, so we can't control)

# Allowed transitions from RUNNING state
if self.is_running():
if new_state.is_draft(): # no going back
return False
if new_state.is_ready(): # no going back
return False
if new_state.is_running():
# Only allow updates every so often
time_delta = new_state.timestamp - self.get_current_state().timestamp
if time_delta.total_seconds() < TIMEDELTA_1H.total_seconds():
# Avoid too frequent updates for RUNNING state
return False
else:
self.persist_state(new_state)
return True

# Allowed transitions from other state
if self.contains_status(new_state.status):
# Don't allow updates/duplications of other states
return False

# Assume other state transitions are OK
self.persist_state(new_state)
return True

def persist_state(self, new_state):
new_state.workflow_run = self.workflow_run
new_state.payload.save() # Need to save Payload before we can save State
new_state.save()

@staticmethod
def get_latest_state(states: List[State]) -> State:
last: State = states[0]
for s in states:
if s.timestamp > last.timestamp:
last = s
return last

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class WorkflowRunManager(OrcaBusBaseManager):


class WorkflowRun(OrcaBusBaseModel):

id = models.BigAutoField(primary_key=True)

# --- mandatory fields
Expand All @@ -35,7 +34,8 @@ class WorkflowRun(OrcaBusBaseModel):
objects = WorkflowRunManager()

def __str__(self):
return f"ID: {self.id}, portal_run_id: {self.portal_run_id}"
return f"ID: {self.id}, portal_run_id: {self.portal_run_id}, workflow_run_name: {self.workflow_run_name}, " \
f"workflow: {self.workflow.workflow_name} "

def to_dict(self):
return {
Expand All @@ -47,6 +47,10 @@ def to_dict(self):
"workflow": self.workflow.to_dict() if (self.workflow is not None) else None
}

def get_all_states(self):
# retrieve all states (DB records rather than a queryset)
return list(self.state_set.all()) # TODO: ensure order by timestamp ?


class LibraryAssociationManager(OrcaBusBaseManager):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@ class WorkflowRunViewSet(ReadOnlyModelViewSet):
search_fields = WorkflowRun.get_base_fields()

def get_queryset(self):
print(self.request.query_params)
return WorkflowRun.objects.get_by_keyword(**self.request.query_params)
Loading