KEP-2170: Create model and dataset initializers #2303

Merged

8 changes: 8 additions & 0 deletions .github/workflows/publish-core-images.yaml
@@ -30,6 +30,14 @@ jobs:
dockerfile: cmd/training-operator.v2alpha1/Dockerfile
platforms: linux/amd64,linux/arm64,linux/ppc64le
tag-prefix: v2alpha1
- component-name: model-initializer-v2
dockerfile: cmd/initializer_v2/model/Dockerfile
platforms: linux/amd64,linux/arm64
tag-prefix: v2
- component-name: dataset-initializer-v2
dockerfile: cmd/initializer_v2/dataset/Dockerfile
platforms: linux/amd64,linux/arm64
tag-prefix: v2
- component-name: kubectl-delivery
dockerfile: build/images/kubectl-delivery/Dockerfile
platforms: linux/amd64,linux/arm64,linux/ppc64le
4 changes: 2 additions & 2 deletions .gitignore
@@ -10,8 +10,8 @@ cover.out
.vscode/
__debug_bin

-# Compiled python files.
-*.pyc
+# Python cache files
+__pycache__/

# Emacs temporary files
*~
13 changes: 13 additions & 0 deletions cmd/initializer_v2/dataset/Dockerfile
@@ -0,0 +1,13 @@
FROM python:3.11-alpine

@tenzen-y (Member) commented on Oct 27, 2024:

Suggested change:
-FROM python:3.11-alpine
+FROM python:3.11-slim-bookworm

@andreyvelich Could you use the Debian image, since Alpine has a performance penalty due to musl libc? Python still depends on C code.

@andreyvelich (Member, Author) replied:

Sure, let me create an issue.

WORKDIR /workspace

# Copy the required Python modules.
COPY cmd/initializer_v2/dataset/requirements.txt .
COPY sdk/python/kubeflow sdk/python/kubeflow
COPY pkg/initializer_v2 pkg/initializer_v2

# Install the needed packages.
RUN pip install -r requirements.txt

ENTRYPOINT ["python", "-m", "pkg.initializer_v2.dataset"]
1 change: 1 addition & 0 deletions cmd/initializer_v2/dataset/requirements.txt
@@ -0,0 +1 @@
huggingface_hub==0.23.4
13 changes: 13 additions & 0 deletions cmd/initializer_v2/model/Dockerfile
@@ -0,0 +1,13 @@
FROM python:3.11-alpine

WORKDIR /workspace

# Copy the required Python modules.
COPY cmd/initializer_v2/model/requirements.txt .
COPY sdk/python/kubeflow sdk/python/kubeflow
COPY pkg/initializer_v2 pkg/initializer_v2

# Install the needed packages.
RUN pip install -r requirements.txt

ENTRYPOINT ["python", "-m", "pkg.initializer_v2.model"]
1 change: 1 addition & 0 deletions cmd/initializer_v2/model/requirements.txt
@@ -0,0 +1 @@
huggingface_hub==0.23.4
31 changes: 31 additions & 0 deletions pkg/initializer_v2/dataset/__main__.py
@@ -0,0 +1,31 @@
import logging
import os
from urllib.parse import urlparse

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.dataset.huggingface import HuggingFace

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)

if __name__ == "__main__":
    logging.info("Starting dataset initialization")

    try:
        storage_uri = os.environ[utils.STORAGE_URI_ENV]
    except Exception as e:
        logging.error("STORAGE_URI env variable must be set.")
        raise e

    match urlparse(storage_uri).scheme:
        # TODO (andreyvelich): Implement more dataset providers.
        case utils.HF_SCHEME:
            hf = HuggingFace()
            hf.load_config()
            hf.download_dataset()
        case _:
            logging.error("STORAGE_URI must have the valid dataset provider")
            raise Exception
9 changes: 9 additions & 0 deletions pkg/initializer_v2/dataset/config.py
@@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Optional


# TODO (andreyvelich): This should be moved under Training V2 SDK.
@dataclass
class HuggingFaceDatasetConfig:
    storage_uri: str
    access_token: Optional[str] = None
42 changes: 42 additions & 0 deletions pkg/initializer_v2/dataset/huggingface.py
@@ -0,0 +1,42 @@
import logging
from urllib.parse import urlparse

import huggingface_hub

import pkg.initializer_v2.utils.utils as utils

# TODO (andreyvelich): This should be moved to SDK V2 constants.
import sdk.python.kubeflow.storage_initializer.constants as constants
from pkg.initializer_v2.dataset.config import HuggingFaceDatasetConfig

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)


class HuggingFace(utils.DatasetProvider):

    def load_config(self):
        config_dict = utils.get_config_from_env(HuggingFaceDatasetConfig)
        logging.info(f"Config for HuggingFace dataset initializer: {config_dict}")
        self.config = HuggingFaceDatasetConfig(**config_dict)

    def download_dataset(self):
        storage_uri_parsed = urlparse(self.config.storage_uri)
        dataset_uri = storage_uri_parsed.netloc + storage_uri_parsed.path

        logging.info(f"Downloading dataset: {dataset_uri}")
        logging.info("-" * 40)

        if self.config.access_token:
            huggingface_hub.login(self.config.access_token)

        huggingface_hub.snapshot_download(
            repo_id=dataset_uri,
            repo_type="dataset",
            local_dir=constants.VOLUME_PATH_DATASET,
        )

        logging.info("Dataset has been downloaded")
33 changes: 33 additions & 0 deletions pkg/initializer_v2/model/__main__.py
@@ -0,0 +1,33 @@
import logging
import os
from urllib.parse import urlparse

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.model.huggingface import HuggingFace

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)

if __name__ == "__main__":
    logging.info("Starting pre-trained model initialization")

    try:
        storage_uri = os.environ[utils.STORAGE_URI_ENV]
    except Exception as e:
        logging.error("STORAGE_URI env variable must be set.")
        raise e

    match urlparse(storage_uri).scheme:
        # TODO (andreyvelich): Implement more model providers.
        case utils.HF_SCHEME:
            hf = HuggingFace()
            hf.load_config()
            hf.download_model()
        case _:
            logging.error(
                f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
            )
            raise Exception
9 changes: 9 additions & 0 deletions pkg/initializer_v2/model/config.py
@@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Optional


# TODO (andreyvelich): This should be moved under Training V2 SDK.
@dataclass
class HuggingFaceModelInputConfig:
    storage_uri: str
    access_token: Optional[str] = None
47 changes: 47 additions & 0 deletions pkg/initializer_v2/model/huggingface.py
@@ -0,0 +1,47 @@
import logging
from urllib.parse import urlparse

import huggingface_hub

import pkg.initializer_v2.utils.utils as utils

# TODO (andreyvelich): This should be moved to SDK V2 constants.
import sdk.python.kubeflow.storage_initializer.constants as constants
from pkg.initializer_v2.model.config import HuggingFaceModelInputConfig

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
    level=logging.INFO,
)


class HuggingFace(utils.ModelProvider):

    def load_config(self):
        config_dict = utils.get_config_from_env(HuggingFaceModelInputConfig)
        logging.info(f"Config for HuggingFace model initializer: {config_dict}")
        self.config = HuggingFaceModelInputConfig(**config_dict)

    def download_model(self):
        storage_uri_parsed = urlparse(self.config.storage_uri)
        model_uri = storage_uri_parsed.netloc + storage_uri_parsed.path

        logging.info(f"Downloading model: {model_uri}")
        logging.info("-" * 40)

        if self.config.access_token:
            huggingface_hub.login(self.config.access_token)

        # TODO (andreyvelich): We should consider to follow vLLM approach with allow patterns.
        # Ref: https://github.com/kubeflow/training-operator/pull/2303#discussion_r1815913663
        # TODO (andreyvelich): We should update patterns for Mistral model
        # Ref: https://github.com/kubeflow/training-operator/pull/2303#discussion_r1815914270
        huggingface_hub.snapshot_download(
            repo_id=model_uri,
            local_dir=constants.VOLUME_PATH_MODEL,
            allow_patterns=["*.json", "*.safetensors", "*.model"],
            ignore_patterns=["*.msgpack", "*.h5", "*.bin", ".pt", ".pth"],
        )

        logging.info("Model has been downloaded")
Empty file.
37 changes: 37 additions & 0 deletions pkg/initializer_v2/utils/utils.py
@@ -0,0 +1,37 @@
import os
from abc import ABC, abstractmethod
from dataclasses import fields
from typing import Dict

STORAGE_URI_ENV = "STORAGE_URI"
HF_SCHEME = "hf"


class ModelProvider(ABC):
    @abstractmethod
    def load_config(self):
        raise NotImplementedError()

    @abstractmethod
    def download_model(self):
        raise NotImplementedError()


class DatasetProvider(ABC):
    @abstractmethod
    def load_config(self):
        raise NotImplementedError()

    @abstractmethod
    def download_dataset(self):
        raise NotImplementedError()


# Get DataClass config from the environment variables.
# Env names must be equal to the DataClass parameters.
def get_config_from_env(config) -> Dict[str, str]:
    config_from_env = {}
    for field in fields(config):
        config_from_env[field.name] = os.getenv(field.name.upper())

    return config_from_env
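
To illustrate the convention encoded in get_config_from_env (each dataclass field is read from an upper-cased environment variable of the same name), here is a small sketch using the HuggingFaceDatasetConfig added in this PR and made-up values:

import os

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.dataset.config import HuggingFaceDatasetConfig

# Hypothetical values for illustration only.
os.environ["STORAGE_URI"] = "hf://example-org/example-dataset"
os.environ["ACCESS_TOKEN"] = "hf_xxx"

config_dict = utils.get_config_from_env(HuggingFaceDatasetConfig)
# {"storage_uri": "hf://example-org/example-dataset", "access_token": "hf_xxx"}
config = HuggingFaceDatasetConfig(**config_dict)
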