KEP-2170: Create model and dataset initializers (#2303)
* KEP-2170: Create model and dataset initializers

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Add abstract classes

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Add storage URI to config

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Update .gitignore

Co-authored-by: Kevin Hannon <kehannon@redhat.com>
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Fix the misspelling for initializer

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

* Add .pt and .pth to ignore_patterns

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>

---------

Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
Co-authored-by: Kevin Hannon <kehannon@redhat.com>
andreyvelich and kannon92 authored Oct 27, 2024
1 parent 7659239 commit 3f7ec16
Showing 14 changed files with 246 additions and 2 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/publish-core-images.yaml
@@ -30,6 +30,14 @@ jobs:
            dockerfile: cmd/training-operator.v2alpha1/Dockerfile
            platforms: linux/amd64,linux/arm64,linux/ppc64le
            tag-prefix: v2alpha1
          - component-name: model-initializer-v2
            dockerfile: cmd/initializer_v2/model/Dockerfile
            platforms: linux/amd64,linux/arm64
            tag-prefix: v2
          - component-name: dataset-initializer-v2
            dockerfile: cmd/initializer_v2/dataset/Dockerfile
            platforms: linux/amd64,linux/arm64
            tag-prefix: v2
          - component-name: kubectl-delivery
            dockerfile: build/images/kubectl-delivery/Dockerfile
            platforms: linux/amd64,linux/arm64,linux/ppc64le
4 changes: 2 additions & 2 deletions .gitignore
@@ -11,8 +11,8 @@ cover.out
.vscode/
__debug_bin

# Compiled python files.
*.pyc
# Python cache files
__pycache__/

# Emacs temporary files
*~
13 changes: 13 additions & 0 deletions cmd/initializer_v2/dataset/Dockerfile
@@ -0,0 +1,13 @@
FROM python:3.11-alpine

WORKDIR /workspace

# Copy the required Python modules.
COPY cmd/initializer_v2/dataset/requirements.txt .
COPY sdk/python/kubeflow sdk/python/kubeflow
COPY pkg/initializer_v2 pkg/initializer_v2

# Install the needed packages.
RUN pip install -r requirements.txt

ENTRYPOINT ["python", "-m", "pkg.initializer_v2.dataset"]
1 change: 1 addition & 0 deletions cmd/initializer_v2/dataset/requirements.txt
@@ -0,0 +1 @@
huggingface_hub==0.23.4
13 changes: 13 additions & 0 deletions cmd/initializer_v2/model/Dockerfile
@@ -0,0 +1,13 @@
FROM python:3.11-alpine

WORKDIR /workspace

# Copy the required Python modules.
COPY cmd/initializer_v2/model/requirements.txt .
COPY sdk/python/kubeflow sdk/python/kubeflow
COPY pkg/initializer_v2 pkg/initializer_v2

# Install the needed packages.
RUN pip install -r requirements.txt

ENTRYPOINT ["python", "-m", "pkg.initializer_v2.model"]
1 change: 1 addition & 0 deletions cmd/initializer_v2/model/requirements.txt
@@ -0,0 +1 @@
huggingface_hub==0.23.4
31 changes: 31 additions & 0 deletions pkg/initializer_v2/dataset/__main__.py
@@ -0,0 +1,31 @@
import logging
import os
from urllib.parse import urlparse

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.dataset.huggingface import HuggingFace

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)

if __name__ == "__main__":
    logging.info("Starting dataset initialization")

    try:
        storage_uri = os.environ[utils.STORAGE_URI_ENV]
    except Exception as e:
        logging.error("STORAGE_URI env variable must be set.")
        raise e

    match urlparse(storage_uri).scheme:
        # TODO (andreyvelich): Implement more dataset providers.
        case utils.HF_SCHEME:
            hf = HuggingFace()
            hf.load_config()
            hf.download_dataset()
        case _:
            logging.error("STORAGE_URI must have the valid dataset provider")
            raise Exception
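
For reference, a minimal local-run sketch (not part of this commit): the dataset initializer is launched through its __main__ module with STORAGE_URI set in the environment, mirroring the Dockerfile ENTRYPOINT above; the dataset name below is a hypothetical example.

# Hedged sketch, not committed code: run the dataset initializer locally.
import os
import runpy

os.environ["STORAGE_URI"] = "hf://yelp/yelp_review_full"  # hypothetical dataset
runpy.run_module("pkg.initializer_v2.dataset", run_name="__main__")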
9 changes: 9 additions & 0 deletions pkg/initializer_v2/dataset/config.py
@@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Optional


# TODO (andreyvelich): This should be moved under Training V2 SDK.
@dataclass
class HuggingFaceDatasetConfig:
    storage_uri: str
    access_token: Optional[str] = None
42 changes: 42 additions & 0 deletions pkg/initializer_v2/dataset/huggingface.py
@@ -0,0 +1,42 @@
import logging
from urllib.parse import urlparse

import huggingface_hub

import pkg.initializer_v2.utils.utils as utils

# TODO (andreyvelich): This should be moved to SDK V2 constants.
import sdk.python.kubeflow.storage_initializer.constants as constants
from pkg.initializer_v2.dataset.config import HuggingFaceDatasetConfig

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)


class HuggingFace(utils.DatasetProvider):

    def load_config(self):
        config_dict = utils.get_config_from_env(HuggingFaceDatasetConfig)
        logging.info(f"Config for HuggingFace dataset initializer: {config_dict}")
        self.config = HuggingFaceDatasetConfig(**config_dict)

    def download_dataset(self):
        storage_uri_parsed = urlparse(self.config.storage_uri)
        dataset_uri = storage_uri_parsed.netloc + storage_uri_parsed.path

        logging.info(f"Downloading dataset: {dataset_uri}")
        logging.info("-" * 40)

        if self.config.access_token:
            huggingface_hub.login(self.config.access_token)

        huggingface_hub.snapshot_download(
            repo_id=dataset_uri,
            repo_type="dataset",
            local_dir=constants.VOLUME_PATH_DATASET,
        )

        logging.info("Dataset has been downloaded")
33 changes: 33 additions & 0 deletions pkg/initializer_v2/model/__main__.py
@@ -0,0 +1,33 @@
import logging
import os
from urllib.parse import urlparse

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.model.huggingface import HuggingFace

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)

if __name__ == "__main__":
    logging.info("Starting pre-trained model initialization")

    try:
        storage_uri = os.environ[utils.STORAGE_URI_ENV]
    except Exception as e:
        logging.error("STORAGE_URI env variable must be set.")
        raise e

    match urlparse(storage_uri).scheme:
        # TODO (andreyvelich): Implement more model providers.
        case utils.HF_SCHEME:
            hf = HuggingFace()
            hf.load_config()
            hf.download_model()
        case _:
            logging.error(
                f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
            )
            raise Exception
9 changes: 9 additions & 0 deletions pkg/initializer_v2/model/config.py
@@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Optional


# TODO (andreyvelich): This should be moved under Training V2 SDK.
@dataclass
class HuggingFaceModelInputConfig:
    storage_uri: str
    access_token: Optional[str] = None
47 changes: 47 additions & 0 deletions pkg/initializer_v2/model/huggingface.py
@@ -0,0 +1,47 @@
import logging
from urllib.parse import urlparse

import huggingface_hub

import pkg.initializer_v2.utils.utils as utils

# TODO (andreyvelich): This should be moved to SDK V2 constants.
import sdk.python.kubeflow.storage_initializer.constants as constants
from pkg.initializer_v2.model.config import HuggingFaceModelInputConfig

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)


class HuggingFace(utils.ModelProvider):

    def load_config(self):
        config_dict = utils.get_config_from_env(HuggingFaceModelInputConfig)
        logging.info(f"Config for HuggingFace model initializer: {config_dict}")
        self.config = HuggingFaceModelInputConfig(**config_dict)

    def download_model(self):
        storage_uri_parsed = urlparse(self.config.storage_uri)
        model_uri = storage_uri_parsed.netloc + storage_uri_parsed.path

        logging.info(f"Downloading model: {model_uri}")
        logging.info("-" * 40)

        if self.config.access_token:
            huggingface_hub.login(self.config.access_token)

        # TODO (andreyvelich): We should consider to follow vLLM approach with allow patterns.
        # Ref: https://github.com/kubeflow/training-operator/pull/2303#discussion_r1815913663
        # TODO (andreyvelich): We should update patterns for Mistral model
        # Ref: https://github.com/kubeflow/training-operator/pull/2303#discussion_r1815914270
        huggingface_hub.snapshot_download(
            repo_id=model_uri,
            local_dir=constants.VOLUME_PATH_MODEL,
            allow_patterns=["*.json", "*.safetensors", "*.model"],
            ignore_patterns=["*.msgpack", "*.h5", "*.bin", ".pt", ".pth"],
        )

        logging.info("Model has been downloaded")
Empty file.
37 changes: 37 additions & 0 deletions pkg/initializer_v2/utils/utils.py
@@ -0,0 +1,37 @@
import os
from abc import ABC, abstractmethod
from dataclasses import fields
from typing import Dict

STORAGE_URI_ENV = "STORAGE_URI"
HF_SCHEME = "hf"


class ModelProvider(ABC):
    @abstractmethod
    def load_config(self):
        raise NotImplementedError()

    @abstractmethod
    def download_model(self):
        raise NotImplementedError()


class DatasetProvider(ABC):
    @abstractmethod
    def load_config(self):
        raise NotImplementedError()

    @abstractmethod
    def download_dataset(self):
        raise NotImplementedError()


# Get DataClass config from the environment variables.
# Env names must be equal to the DataClass parameters.
def get_config_from_env(config) -> Dict[str, str]:
    config_from_env = {}
    for field in fields(config):
        config_from_env[field.name] = os.getenv(field.name.upper())

    return config_from_env
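
For reference, a minimal usage sketch (not part of this commit) of the env-to-dataclass mapping above: environment variable names are the upper-cased dataclass field names; the values below are hypothetical examples.

# Hedged sketch: load a config dataclass from environment variables.
import os

from pkg.initializer_v2.dataset.config import HuggingFaceDatasetConfig
from pkg.initializer_v2.utils.utils import get_config_from_env

os.environ["STORAGE_URI"] = "hf://yelp/yelp_review_full"  # hypothetical value
os.environ["ACCESS_TOKEN"] = "hf_..."                     # hypothetical value

config_dict = get_config_from_env(HuggingFaceDatasetConfig)
# -> {"storage_uri": "hf://yelp/yelp_review_full", "access_token": "hf_..."}
config = HuggingFaceDatasetConfig(**config_dict)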
