Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): Implement Registry client #7597

Merged
merged 35 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a3f3529
Implement registry client
chongyouquan Mar 18, 2022
7f47fa2
Update registry client code
chongyouquan Mar 21, 2022
7dbe8d3
Add test skeleton
chongyouquan Mar 25, 2022
03d82ee
Add some tests
chongyouquan Apr 1, 2022
7c13f50
Update code
chongyouquan Apr 1, 2022
e989dfd
add tests
chongyouquan Apr 15, 2022
173a767
update tests
chongyouquan Apr 18, 2022
285d1a4
Merge branch 'kubeflow:master' into registryclient
chongyouquan Apr 25, 2022
5dcaa3d
update tests
chongyouquan Apr 21, 2022
aa2513c
Rename Client -> RegistryClient
chongyouquan Apr 25, 2022
0073e36
Update wrt comments
chongyouquan Apr 25, 2022
4d005bc
add type annotations
chongyouquan Apr 25, 2022
e1e2058
fix renaming in __init__.py
chongyouquan Apr 25, 2022
f97128d
remove unused imports
chongyouquan Apr 25, 2022
fe2b6e3
extract host variable in test
chongyouquan Apr 25, 2022
1f62bc5
format using yapf
chongyouquan Apr 25, 2022
fc5dd3c
remove locals and use arg keywords
chongyouquan Apr 26, 2022
bcebc7f
remove json conversion
chongyouquan Apr 27, 2022
6075f70
fix header
chongyouquan Apr 27, 2022
a2e6b91
write bytes when downloading file
chongyouquan Apr 27, 2022
e9a1838
fix create_tag; fix tests
chongyouquan Apr 27, 2022
0115190
fix request_body for update_tag and create_tag using json.dumps
chongyouquan Apr 27, 2022
e7af3cb
simply return json for delete_tag
chongyouquan Apr 28, 2022
f7d7a09
rename files
chongyouquan May 3, 2022
802a7ab
format files
chongyouquan May 3, 2022
d7e9823
update return types and format double quotes
chongyouquan May 3, 2022
c4e8372
add comments and format files
chongyouquan May 3, 2022
7a0991f
add todos
chongyouquan May 3, 2022
f2858a4
update credentials and change open to use context
chongyouquan May 3, 2022
4861334
format using yapf
chongyouquan May 3, 2022
1314981
move request into context
chongyouquan May 3, 2022
a6b2162
Update comments
chongyouquan May 5, 2022
1131cbc
Update release notes
chongyouquan May 5, 2022
1405b54
Update release notes
chongyouquan May 5, 2022
7441c5a
Merge branch 'master' into registryclient
chongyouquan May 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions sdk/python/kfp/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kfp.registry.client import Client
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
264 changes: 264 additions & 0 deletions sdk/python/kfp/registry/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for KFP Registry Client."""

import logging
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved

import google.auth
import json
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
import requests
import re
from typing import Any, Optional, List, Tuple
from google.protobuf import json_format
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved

_KNOWN_HOSTS_REGEX = {
"kfp_pkg_dev": r'(^https\:\/\/(?P<location>[\w\-]+)\-kfp\.pkg\.dev\/(?P<project_id>.*)\/(?P<repo_id>.*))',
}

class _SafeDict(dict):
def __missing__(self, key):
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
return '{' + key + '}'

class ApiAuth(requests.auth.AuthBase):
def __init__(self, token):
self.token = token
def __call__(self, request):
request.headers['authorization'] = 'Bearer ' + self.token
return request

class Client:
def __init__(self,
host: str,
auth: Optional[requests.auth.AuthBase] = None
):
self._host = host.rstrip('/')
self._config = self.load_config()
self._known_host_key = ""
for key in _KNOWN_HOSTS_REGEX.keys():
if re.match(_KNOWN_HOSTS_REGEX[key], self._host):
self._known_host_key = key
break
if credentials:
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
self._auth = auth
elif self._is_ar_host():
logger = logging.getLogger('google.auth._default')
logging_warning_filter = utils.LoggingFilter(logging.WARNING)
logger.addFilter(logging_warning_filter)
self._creds, _ = google.auth.default()
logger.removeFilter(logging_warning_filter)

def _request(self, request_url: str, request_body: str = '',
http_request: str = 'get', extra_headers: str = '') -> Any:
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
"""Call the HTTP request"""
if self._is_ar_host():
if not self._auth.token.valid:
self._auth.token.refresh(google.auth.transport.requests.Request())
headers = {
'Content-type': 'application/json',
}

http_request_fn = getattr(requests, http_request)
response = http_request_fn(
url=request_url, data=request_body, headers=headers, auth=self._auth).json()
response.raise_for_status()

return response

def _is_ar_host():
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
return self._known_host_key == "kfp_pkg_dev"
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved

def load_config(self):
config = {}
if self._is_ar_host():
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
repo_resource_format = ''
try:
matched = re.match(_AR_HOST_TEMPLATE_GROUPS, self._host)
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
repo_resource_format = ('projects/'
'{project_id}/locations/{location}/'
'repositories/{repo_id}'.format_map(
_SafeDict(matched.groupdict())))
except AttributeError as err:
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError('Invalid host URL')
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
registry_endpoint = 'https://artifactregistry.googleapis.com/v1/'
api_endpoint = f'{registry_endpoint}/{repo_resource_format}'
package_endpoint = f'{api_endpoint}/packages'
package_name_endpoint = f'{package_endpoint}/{{package_name}}'
tags_endpoint = f'{package_name_endpoint}/tags'
versions_endpoint = f'{package_name_endpoint}/versions'
config = {
'host': self._host,
'upload_url': self._host,
'download_version_url': f'{self._host}/{{package_name}}/sha256:{{version}}',
'download_tag_url': f'{self._host}/{{package_name}}/{{tag}}',
'get_package_url': f'{package_name_endpoint}',
'list_packages_url': f'{package_endpoint}/',
'delete_package_url': f'{package_name_endpoint}',
'get_tag_url': f'{tags_endpoint}/{{tag}}',
'list_tags_url': f'{tags_endpoint}/',
'delete_tag_url': f'{tags_endpoint}/{{tag}}',
'create_tag_url': f'{tags_endpoint}?tagId={{tag}}',
'update_tag_url': f'{tags_endpoint}/{{tag}}?updateMask=version',
'get_version_url': f'{versions_endpoint}/{{version}}',
'list_versions_url': f'{versions_endpoint}/',
'delete_version_url': f'{versions_endpoint}/{{version}}',
'package_format': f'{repo_resource_format}/packages/{{package_name}}',
'tag_format': f'{repo_resource_format}/packages/{{package_name}}/tags/{{tag}}',
'version_format': f'{repo_resource_format}/packages/{{package_name}}/versions/{{version}}',
}
else:
raise ValueError(f'load_config not implemented for host: {self._host}')
return config

def upload_pipeline(self, file_name: str, tags: Optional[List[str]],
extra_headers: Optional[dict]) -> Tuple[str, str]:
url = self._config['upload_url']
if self._is_ar_host():
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
if not self._auth.token.valid:
self._auth.token.refresh(google.auth.transport.requests.Request())
request_body = {}
if tags:
request_body = {'tags': ','.join(tags)}

files = {'content': open(file_name, 'rb')}
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
response = requests.post(url=url,
data=request_body, headers=extra_headers,
files=files, auth=self._auth).json()
response.raise_for_status()

return response

def _get_download_url(package_name: str,
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
version: Optional[str] = None,
tag: Optional[str] = None) -> str:
if (not version) and (not tag):
raise ValueError('Either version or tag must be specified.')
if version:
if version.startswith('sha256:'):
version = version[len('sha256:'):]
url = self._config['download_version_url'].format(**locals())
if tag:
if version:
logging.info(
'Both version and tag are specified, using version only.')
else:
url = self._config['download_tag_url'].format(**locals())
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
return url

def download_pipeline(self, package_name: str,
version: Optional[str] = None,
tag: Optional[str] = None,
file_name: str = None) -> str:
url = self._get_download_url(package_name, version, tag)
response = self._request(request_url=url)

if not file_name:
file_name = package_name + '_'
if version:
if version.startswith('sha256:'):
file_name += version[len('sha256:'):]
else:
file_name += version
elif tag:
file_name += tag
file_name += '.yaml'

with open(file_name, 'w') as f:
f.write(response.content)

return file_name

def get_package(self, package_name: str) -> dict:
url = self._config['get_package_url'].format(**locals())
response = self._request(request_url=url)

return response.json()

def list_packages(self) -> List[dict]:
url = self._config['list_packages_url'].format(**locals())
response = self._request(request_url=url)

return response.json()

def delete_package(self, package_name: str) -> bool:
url = self._config['delete_package_url'].format(**locals())
response = self._request(request_url=url, http_request='delete')
response_json = response.json()

return response_json['done']

def get_version(self, package_name: str, version: str) -> dict:
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
url = self._config['get_version_url'].format(**locals())
response = self._request(request_url=url)

return response.json()

def list_versions(self, package_name: str) -> List[dict]:
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
url = self._config['list_versions_url'].format(**locals())
response = self._request(request_url=url)

return response.json()

def delete_version(self, package_name: str, version: str) -> bool:
url = self._config['delete_version_url'].format(**locals())
response = self._request(request_url=url, http_request='delete')
response_json = response.json()

return response_json['done']

def create_tag(self, package_name: str, version: str, tag: str) -> dict:
url = self._config['update_tag_url'].format(**locals())
new_tag = {
'name' : self._config['tag_resource_format'].format(**locals()),
'version' : self._config['version_resource_format'].format(**locals())
}
response = self._request(
request_url=url,
request_body=new_tag,
http_request='patch'
)

return response.json()

def get_tag(self, package_name: str, tag: str) -> dict:
url = self._config['get_tag_url'].format(**locals())
response = self._request(request_url=url)

return response.json()

def update_tag(self, package_name: str, version: str, tag: str) -> dict:
url = self._config['update_tag_url'].format(**locals())
new_tag = {
'name' : self._config['tag_resource_format'].format(**locals()),
'version' : ''
}
response = self._request(
request_url=url,
request_body=new_tag,
http_request='post'
)

return response.json()

def list_tags(self, package_name: str) -> List[dict]:
url = self._config['list_tags_url'].format(**locals())
response = self._request(request_url=url)

return response.json()

def delete_tag(self, package_name: str, tag: str) -> bool:
url = self._config['delete_tag_url'].format(**locals())
response = self._request(request_url=url, http_request='delete')
response_json = response.json()

return response_json['done']
Loading