-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add kfp-tensorflow notebook to confirm NVIDIA GPU access (#139)
Add `kfp-tensorflow` notebook in order to confirm access to a GPU. The notebook spins uses kfp SDK to create an experiment and a run that succeeds when: * the run's pod is scheduled on a node with an NVIDIA GPU * the run's code, and more specifically Tensorflow framework, has access to an NVIDIA GPU. Closes #128
- Loading branch information
Showing
4 changed files
with
317 additions
and
21 deletions.
There are no files selected for viewing
205 changes: 205 additions & 0 deletions
205
tests/notebooks/gpu/kfp-tensorflow/kfp-tensorflow-integration.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Test KFP Integration\n", | ||
"\n", | ||
"- create an experiment\n", | ||
"- create a run\n", | ||
"- check that the run passes. This happens only when both of the following are true:\n", | ||
" * the run's pod is scheduled on a node with an NVIDIA GPU\n", | ||
" * the code, and more specifically Tensorflow framework, has access to a GPU" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Please check the requirements.in file for more details\n", | ||
"!pip install -r requirements.txt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import kfp\n", | ||
"import os\n", | ||
"\n", | ||
"from kfp import dsl\n", | ||
"from tenacity import retry, stop_after_attempt, wait_exponential" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"client = kfp.Client()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"EXPERIMENT_NAME = 'Check access to GPU'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"HTTP_PROXY = HTTPS_PROXY = NO_PROXY = None\n", | ||
"\n", | ||
"if os.environ.get('HTTP_PROXY') and os.environ.get('HTTPS_PROXY') and os.environ.get('NO_PROXY'):\n", | ||
" HTTP_PROXY = os.environ['HTTP_PROXY']\n", | ||
" HTTPS_PROXY = os.environ['HTTPS_PROXY']\n", | ||
" NO_PROXY = os.environ['NO_PROXY']\n", | ||
"\n", | ||
"def add_proxy(task: dsl.PipelineTask, http_proxy=HTTP_PROXY, https_proxy=HTTPS_PROXY, no_proxy=NO_PROXY) -> dsl.PipelineTask:\n", | ||
" \"\"\"Adds the proxy env vars to the PipelineTask object.\"\"\"\n", | ||
" return (\n", | ||
" task.set_env_variable(name='http_proxy', value=http_proxy)\n", | ||
" .set_env_variable(name='https_proxy', value=https_proxy)\n", | ||
" .set_env_variable(name='HTTP_PROXY', value=http_proxy)\n", | ||
" .set_env_variable(name='HTTPS_PROXY', value=https_proxy)\n", | ||
" .set_env_variable(name='no_proxy', value=no_proxy)\n", | ||
" .set_env_variable(name='NO_PROXY', value=no_proxy)\n", | ||
" )\n", | ||
"\n", | ||
"def proxy_envs_set():\n", | ||
" \"\"\"Check if the proxy env vars are set\"\"\"\n", | ||
" if HTTP_PROXY and HTTPS_PROXY and NO_PROXY:\n", | ||
" return True\n", | ||
" return False" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"@dsl.component(base_image=\"kubeflownotebookswg/jupyter-tensorflow-cuda:v1.9.0\")\n", | ||
"def gpu_check() -> str:\n", | ||
" \"\"\"Check access to a GPU.\"\"\"\n", | ||
" import tensorflow as tf\n", | ||
"\n", | ||
" gpus = tf.config.list_physical_devices('GPU')\n", | ||
" print(\"GPU list:\", gpus)\n", | ||
" if not gpus:\n", | ||
" raise RuntimeError(\"No GPU has been detected.\")\n", | ||
" return str(len(gpus)>0)\n", | ||
"\n", | ||
"def add_gpu_request(task: dsl.PipelineTask) -> dsl.PipelineTask:\n", | ||
" \"\"\"Add a request field for a GPU to the container created by the PipelineTask object.\"\"\"\n", | ||
" return ( task.add_node_selector_constraint(accelerator = \"nvidia.com/gpu\").set_accelerator_limit(limit = 1) )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"@dsl.pipeline\n", | ||
"def gpu_check_pipeline() -> str:\n", | ||
" \"\"\"Create a pipeline that runs code to check access to a GPU.\"\"\"\n", | ||
" gpu_check1 = add_gpu_request(gpu_check())\n", | ||
" return gpu_check1.output\n", | ||
"\n", | ||
"@dsl.pipeline\n", | ||
"def gpu_check_pipeline_proxy() -> str:\n", | ||
" \"\"\"Create a pipeline that runs code to check access to a GPU and sets the appropriate proxy ENV variables.\"\"\"\n", | ||
" gpu_check1 = add_proxy(add_gpu_request(gpu_check()))\n", | ||
" return gpu_check1.output" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Setting enable_caching to False to overcome https://github.com/canonical/bundle-kubeflow/issues/1067\n", | ||
"if proxy_envs_set():\n", | ||
" run = client.create_run_from_pipeline_func(\n", | ||
" gpu_check_pipeline_proxy,\n", | ||
" experiment_name=EXPERIMENT_NAME,\n", | ||
" enable_caching=False,\n", | ||
" )\n", | ||
"else:\n", | ||
" run = client.create_run_from_pipeline_func(\n", | ||
" gpu_check_pipeline,\n", | ||
" experiment_name=EXPERIMENT_NAME,\n", | ||
" enable_caching=False,\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"client.list_experiments().experiments" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"client.get_run(run.run_id).state" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"@retry(\n", | ||
" wait=wait_exponential(multiplier=2, min=1, max=10),\n", | ||
" stop=stop_after_attempt(30),\n", | ||
" reraise=True,\n", | ||
")\n", | ||
"def assert_run_succeeded(client, run_id):\n", | ||
" \"\"\"Wait for the run to complete successfully.\"\"\"\n", | ||
" status = client.get_run(run_id).state\n", | ||
" assert status == \"SUCCEEDED\", f\"KFP run in {status} state.\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# fetch KFP experiment to ensure it exists\n", | ||
"client.get_experiment(experiment_name=EXPERIMENT_NAME)\n", | ||
"\n", | ||
"assert_run_succeeded(client, run.run_id)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
kfp>=2.4,<3.0 | ||
tenacity |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# | ||
# This file is autogenerated by pip-compile with Python 3.11 | ||
# by the following command: | ||
# | ||
# pip-compile requirements.in | ||
# | ||
cachetools==5.5.0 | ||
# via google-auth | ||
certifi==2024.8.30 | ||
# via | ||
# kfp-server-api | ||
# kubernetes | ||
# requests | ||
charset-normalizer==3.4.0 | ||
# via requests | ||
click==8.1.7 | ||
# via kfp | ||
docstring-parser==0.16 | ||
# via kfp | ||
google-api-core==2.23.0 | ||
# via | ||
# google-cloud-core | ||
# google-cloud-storage | ||
# kfp | ||
google-auth==2.36.0 | ||
# via | ||
# google-api-core | ||
# google-cloud-core | ||
# google-cloud-storage | ||
# kfp | ||
# kubernetes | ||
google-cloud-core==2.4.1 | ||
# via google-cloud-storage | ||
google-cloud-storage==2.18.2 | ||
# via kfp | ||
google-crc32c==1.6.0 | ||
# via | ||
# google-cloud-storage | ||
# google-resumable-media | ||
google-resumable-media==2.7.2 | ||
# via google-cloud-storage | ||
googleapis-common-protos==1.66.0 | ||
# via google-api-core | ||
idna==3.10 | ||
# via requests | ||
kfp==2.10.1 | ||
# via -r requirements.in | ||
kfp-pipeline-spec==0.5.0 | ||
# via kfp | ||
kfp-server-api==2.3.0 | ||
# via kfp | ||
kubernetes==30.1.0 | ||
# via kfp | ||
oauthlib==3.2.2 | ||
# via | ||
# kubernetes | ||
# requests-oauthlib | ||
proto-plus==1.25.0 | ||
# via google-api-core | ||
protobuf==4.25.5 | ||
# via | ||
# google-api-core | ||
# googleapis-common-protos | ||
# kfp | ||
# kfp-pipeline-spec | ||
# proto-plus | ||
pyasn1==0.6.1 | ||
# via | ||
# pyasn1-modules | ||
# rsa | ||
pyasn1-modules==0.4.1 | ||
# via google-auth | ||
python-dateutil==2.9.0.post0 | ||
# via | ||
# kfp-server-api | ||
# kubernetes | ||
pyyaml==6.0.2 | ||
# via | ||
# kfp | ||
# kubernetes | ||
requests==2.32.3 | ||
# via | ||
# google-api-core | ||
# google-cloud-storage | ||
# kubernetes | ||
# requests-oauthlib | ||
# requests-toolbelt | ||
requests-oauthlib==2.0.0 | ||
# via kubernetes | ||
requests-toolbelt==0.10.1 | ||
# via kfp | ||
rsa==4.9 | ||
# via google-auth | ||
six==1.16.0 | ||
# via | ||
# kfp-server-api | ||
# kubernetes | ||
# python-dateutil | ||
tabulate==0.9.0 | ||
# via kfp | ||
tenacity==9.0.0 | ||
# via -r requirements.in | ||
urllib3==1.26.20 | ||
# via | ||
# kfp | ||
# kfp-server-api | ||
# kubernetes | ||
# requests | ||
websocket-client==1.8.0 | ||
# via kubernetes |
This file was deleted.
Oops, something went wrong.