Skip to content

Commit

Permalink
feat: bump tensorflow and tensorflow probability versions (#336)
Browse files Browse the repository at this point in the history
* feat: bump tensorflow and tensorflow probability versions

* chore: update github flows python version

* chore: update github flows python version for pull request
  • Loading branch information
fabclmnt authored May 3, 2024
1 parent b7c05c2 commit e6334f4
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 14 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/prerelease.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jobs:
id: version
run: echo ::set-output name=value::${GITHUB_REF#refs/*/}

- name: Setup Python 3.8
- name: Setup Python 3.10
uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.10'

- name: Install dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Setup Python 3.8
- name: Setup Python 3.10
uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.10'

- name: Cache pip
id: cache
Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
requests>=2.30, <2.31
requests>=2.28, <2.31
pandas<3
numpy<2
scikit-learn<2
matplotlib<4
tensorflow==2.12.0
tensorflow-probability==0.19.0
tensorflow==2.15.*
tensorflow-probability[tf]
easydict==1.10
pmlb==1.0.*
tqdm<5.0
typeguard==4.0.*
typeguard==4.2.*
pytest==7.4.*
14 changes: 8 additions & 6 deletions src/ydata_synthetic/synthesizers/saving_keras.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import tensorflow.python.keras as tf_keras
from keras import __version__
tf_keras.__version__ = __version__

from tensorflow.keras import Model
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils

def unpack(model, training_config, weights):
restored_model = deserialize(model)
restored_model = tf_keras.layers.deserialize(model)
if training_config is not None:
restored_model.compile(**saving_utils.compile_args_from_training_config(training_config))
restored_model.compile(**tf_keras.saving.saving_utils.compile_args_from_training_config(training_config))
restored_model.set_weights(weights)
return restored_model

def make_keras_picklable():
def __reduce__(self):
model_metadata = saving_utils.model_metadata(self)
model_metadata = tf_keras.saving.saving_utils.model_metadata(self)
training_config = model_metadata.get("training_config", None)
model = serialize(self)
model = tf_keras.layers.serialize(self)
weights = self.get_weights()
return (unpack, (model, training_config, weights))

Expand Down

0 comments on commit e6334f4

Please sign in to comment.