Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inject git commit information into workflow #482

Merged
merged 9 commits into from
Aug 2, 2024
11 changes: 10 additions & 1 deletion latch/resources/workflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import inspect
import os
from dataclasses import is_dataclass
from textwrap import dedent
from typing import Callable, Dict, Union, get_args, get_origin

import click
import os
from flytekit import workflow as _workflow
from flytekit.core.workflow import PythonFunctionWorkflow

Expand Down Expand Up @@ -98,6 +98,15 @@ def decorator(f: Callable):
)
raise click.exceptions.Exit(1)

git_hash = os.environ.get("GIT_COMMIT_HASH")
is_dirty = os.environ.get("GIT_IS_DIRTY")

if git_hash is not None:
metadata._non_standard["git_commit_hash"] = git_hash
metadata._non_standard["git_is_dirty"] = (
False if is_dirty is None else is_dirty == "True"
)

_inject_metadata(f, metadata)

# note(aidan): used for only serialize_in_container
Expand Down
56 changes: 49 additions & 7 deletions latch_cli/centromere/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
generate_temporary_ssh_credentials,
hash_directory,
)
from latch_cli.workflow_config import AutoVersionMethod


@dataclass
Expand All @@ -52,7 +53,7 @@ class _CentromereCtx:
dkr_client: Optional[docker.APIClient] = None
ssh_client: Optional[paramiko.SSHClient] = None
pkg_root: Optional[Path] = None # root
disable_auto_version: bool = False
version_method: AutoVersionMethod = False
rahuldesai1 marked this conversation as resolved.
Show resolved Hide resolved
image_full = None
version = None
serialize_dir = None
Expand All @@ -75,11 +76,14 @@ class _CentromereCtx:
internal_ip: Optional[str] = None
username: Optional[str] = None

git_commit_hash: Optional[str] = None
git_is_dirty: bool = False

def __init__(
self,
pkg_root: Path,
*,
disable_auto_version: bool = False,
version_method: AutoVersionMethod = False,
rahuldesai1 marked this conversation as resolved.
Show resolved Hide resolved
remote: bool = False,
metadata_root: Optional[Path] = None,
snakefile: Optional[Path] = None,
Expand All @@ -88,7 +92,7 @@ def __init__(
):
self.use_new_centromere = use_new_centromere
self.remote = remote
self.disable_auto_version = disable_auto_version
self.version_method = version_method

try:
self.token = retrieve_or_login()
Expand Down Expand Up @@ -298,10 +302,48 @@ def __init__(
)
self.version = self.version.strip()

if not self.disable_auto_version:
hash = hash_directory(self.pkg_root)
self.version = f"{self.version}-{hash[:6]}"
click.echo(f" {self.version}\n")
from git import GitError, Repo
rahuldesai1 marked this conversation as resolved.
Show resolved Hide resolved

try:
repo = Repo(pkg_root)
self.git_commit_hash = repo.head.commit.hexsha
self.git_is_dirty = repo.is_dirty()
except GitError:
pass
except Exception as e:
click.secho(
"WARN: Exception occured while getting git hash from"
rahuldesai1 marked this conversation as resolved.
Show resolved Hide resolved
f" {self.pkg_root}: {e}",
fg="yellow",
)

if self.version_method != AutoVersionMethod.none:
if self.version_method == AutoVersionMethod.directory:
hash = hash_directory(self.pkg_root)[:6]
elif self.version_method == AutoVersionMethod.git:
if self.git_commit_hash is None:
click.secho(
dedent(f"""
Failed to extract git commit hash. Please ensure that
the project is a git repository and that the git executable
is available on your PATH.
"""),
fg="red",
)
raise click.exceptions.Exit(1)

hash = self.git_commit_hash[:6]
if self.git_is_dirty:
click.secho(
dedent("""
The git repository is dirty. The version will be suffixed
with '-wip' until the changes are committed or removed.
"""),
fg="yellow",
)
hash += "-wip"

self.version = f"{self.version}-{hash}"

if self.nucleus_check_version(self.version, self.workflow_name):
click.secho(
Expand Down
15 changes: 13 additions & 2 deletions latch_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
get_latest_package_version,
get_local_package_version,
)
from latch_cli.workflow_config import BaseImageOptions
from latch_cli.workflow_config import AutoVersionMethod, BaseImageOptions

latch_cli.click_utils.patch()

Expand Down Expand Up @@ -455,6 +455,13 @@ def execute(
" is called."
),
)
@click.option(
"--version-method",
help="Method to use when generating the workflow version",
type=EnumChoice(AutoVersionMethod, case_sensitive=False),
default="directory",
show_default=True,
)
@click.option(
"--remote/--no-remote",
is_flag=True,
Expand Down Expand Up @@ -526,6 +533,7 @@ def execute(
def register(
pkg_root: str,
disable_auto_version: bool,
version_method: AutoVersionMethod,
remote: bool,
docker_progress: str,
yes: bool,
Expand Down Expand Up @@ -554,11 +562,14 @@ def register(
)
raise click.exceptions.Exit(1)

if disable_auto_version:
version_method = AutoVersionMethod.none

from latch_cli.services.register import register

register(
pkg_root,
disable_auto_version=disable_auto_version,
version_method=version_method,
remote=remote,
skip_confirmation=yes,
open=open,
Expand Down
9 changes: 6 additions & 3 deletions latch_cli/services/register/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
upload_image,
)
from latch_cli.utils import WorkflowType
from latch_cli.workflow_config import AutoVersionMethod


def _delete_lines(num: int):
Expand Down Expand Up @@ -255,7 +256,7 @@ def _recursive_list(directory: Path) -> List[Path]:
def register(
pkg_root: str,
*,
disable_auto_version: bool = False,
version_method: AutoVersionMethod = AutoVersionMethod.directory,
remote: bool = False,
open: bool = False,
skip_confirmation: bool = False,
Expand Down Expand Up @@ -321,7 +322,7 @@ def register(

with _CentromereCtx(
Path(pkg_root),
disable_auto_version=disable_auto_version,
version_method=version_method,
remote=remote,
metadata_root=metadata_root,
snakefile=snakefile,
Expand Down Expand Up @@ -395,7 +396,9 @@ def register(
from ...snakemake.serialize import generate_jit_register_code
from ...snakemake.workflow import build_jit_register_wrapper

sm_jit_wf = build_jit_register_wrapper(cache_tasks)
sm_jit_wf = build_jit_register_wrapper(
cache_tasks, ctx.git_commit_hash, ctx.git_is_dirty
)
generate_jit_register_code(
sm_jit_wf,
ctx.pkg_root,
Expand Down
13 changes: 12 additions & 1 deletion latch_cli/services/register/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

import boto3
import click
import docker
import requests
from latch_sdk_config.latch import config
Expand Down Expand Up @@ -116,14 +117,24 @@ def upload_image(ctx: _CentromereCtx, image_name: str) -> List[str]:


def serialize_pkg_in_container(
ctx: _CentromereCtx, image_name: str, serialize_dir: str, wf_name_override: Optional[str] = None
ctx: _CentromereCtx,
image_name: str,
serialize_dir: str,
wf_name_override: Optional[str] = None,
) -> Tuple[List[str], str]:
assert ctx.dkr_client is not None

_env = {"LATCH_DKR_REPO": ctx.dkr_repo, "LATCH_VERSION": ctx.version}
if wf_name_override is not None:
_env["LATCH_WF_NAME_OVERRIDE"] = wf_name_override

if ctx.git_commit_hash is not None:
click.secho(
f"Tagging workflow version with git commit {ctx.git_commit_hash}", fg="blue"
)
_env["GIT_COMMIT_HASH"] = ctx.git_commit_hash
_env["GIT_IS_DIRTY"] = str(ctx.git_is_dirty)

_serialize_cmd = ["make", "serialize"]
container = ctx.dkr_client.create_container(
f"{ctx.dkr_repo}/{image_name}",
Expand Down
24 changes: 21 additions & 3 deletions latch_cli/snakemake/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,12 @@ def interface_to_parameters(
class JITRegisterWorkflow(WorkflowBase, ClassStorageTaskResolver):
out_parameter_name = "o0" # must be "o0"

def __init__(self, cache_tasks: bool = False):
def __init__(
self,
cache_tasks: bool = False,
git_commit_hash: Optional[str] = None,
git_is_dirty: bool = False,
):
self.cache_tasks = cache_tasks

assert metadata._snakemake_metadata is not None
Expand All @@ -336,6 +341,15 @@ def __init__(self, cache_tasks: bool = False):

desc = about_page_path.read_text()

if git_commit_hash is not None:
click.secho(
f"Tagging workflow version with git commit {git_commit_hash}", fg="blue"
)
metadata._snakemake_metadata._non_standard["git_commit_hash"] = (
git_commit_hash
)
metadata._snakemake_metadata._non_standard["git_is_dirty"] = git_is_dirty

docstring = Docstring(
f"{display_name}\n\n{desc}\n\n" + str(metadata._snakemake_metadata)
)
Expand Down Expand Up @@ -1014,8 +1028,12 @@ def execute(self, **kwargs):
return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)


def build_jit_register_wrapper(cache_tasks: bool = False) -> JITRegisterWorkflow:
wrapper_wf = JITRegisterWorkflow(cache_tasks)
def build_jit_register_wrapper(
cache_tasks: bool = False,
git_commit_hash: Optional[str] = None,
git_is_dirty: bool = False,
) -> JITRegisterWorkflow:
wrapper_wf = JITRegisterWorkflow(cache_tasks, git_commit_hash, git_is_dirty)
out_parameter_name = wrapper_wf.out_parameter_name

python_interface = wrapper_wf.python_interface
Expand Down
6 changes: 6 additions & 0 deletions latch_cli/workflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
from latch_cli.constants import latch_constants


class AutoVersionMethod(Enum):
directory = "directory"
rahuldesai1 marked this conversation as resolved.
Show resolved Hide resolved
git = "git"
none = "none"


class BaseImageOptions(str, Enum):
default = "default"
cuda = "cuda"
Expand Down
Loading