-
Notifications
You must be signed in to change notification settings - Fork 3
/
tf_model_download.py
84 lines (73 loc) · 2.82 KB
/
tf_model_download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Copyright 2022 Bell Eapen
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from copy import deepcopy
from typing import Any, Dict, Optional

import tensorflow as tf
import tensorflow_hub as hub
from kedro.io.core import AbstractDataSet
# https://gist.github.com/dermatologist/062c46eafe8c118334a004f6cfab663d
class TfModelDownload(AbstractDataSet):
    """Read-only Kedro dataset that loads a TensorFlow Hub model.

    The model referenced by ``model_url`` is wrapped in a single
    ``hub.KerasLayer`` inside a ``tf.keras.Sequential`` and built for a
    batched input of shape ``[None, height, width, channels]``.
    Saving is not supported; ``_save`` is a no-op.
    """

    # Defaults that user-supplied ``load_args`` are merged over in __init__.
    DEFAULT_LOAD_ARGS = {
        "trainable": False,
        "height": 224,
        "width": 224,
        "channels": 3,
    }  # type: Dict[str, Any]

    def __init__(
        self,
        model_url: str,
        load_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialise the dataset.

        Args:
            model_url: Handle/URL of the TensorFlow Hub model; passed
                unchanged to ``hub.KerasLayer``.
            load_args: Optional overrides for ``DEFAULT_LOAD_ARGS``. The keys
                read by ``_load`` are ``trainable`` (forwarded to
                ``hub.KerasLayer``) and ``height``, ``width`` and
                ``channels`` (used to build the input shape). Other keys are
                stored but not used by this class.
        """
        self._model_url = model_url
        # Deep-copy the class-level defaults so one instance's overrides can
        # never leak into DEFAULT_LOAD_ARGS or into other instances.
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)

    def _load(self) -> Any:
        """Load the Hub model and return it as a built Keras model.

        Returns:
            A ``tf.keras.Sequential`` holding one ``hub.KerasLayer`` for
            ``model_url``, built with input shape
            ``[None, height, width, channels]`` (``None`` = batch dimension).
        """
        trainable = self._load_args.get("trainable", False)
        height = self._load_args.get("height", 224)
        width = self._load_args.get("width", 224)
        channels = self._load_args.get("channels", 3)
        model = tf.keras.Sequential([
            hub.KerasLayer(self._model_url, trainable=trainable)
        ])
        model.build([None, height, width, channels])  # Batch input shape.
        return model

    def _save(self, data: Any) -> None:
        """Do nothing: this dataset is read-only.

        Args:
            data: Ignored.
        """
        # NOTE(review): writes are silently ignored rather than rejected;
        # consider raising kedro's DataSetError if saving should be an error.

    def _describe(self) -> Dict[str, Any]:
        """Return a dict describing this dataset's configuration.

        Returns:
            The model URL and the effective (defaults merged with overrides)
            load arguments, since both determine which model ``_load`` builds.
        """
        return dict(
            model_url=self._model_url,
            load_args=self._load_args,
        )