Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add syncing models utility to ivy #28818

Merged
merged 10 commits into from
Sep 19, 2024
26 changes: 24 additions & 2 deletions ivy/functional/backends/tensorflow/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,35 @@
import re
import os
import tensorflow as tf
import keras
import numpy as np
import functools
from tensorflow.python.util import nest
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union
from typing import (
NamedTuple,
Callable,
Any,
Tuple,
List,
Dict,
Type,
Union,
TYPE_CHECKING,
)
import inspect
from collections import OrderedDict
from packaging.version import parse
import keras

if TYPE_CHECKING:
import torch.nn as nn


if parse(keras.__version__).major > 2:
YushaArif99 marked this conversation as resolved.
Show resolved Hide resolved
KerasVariable = keras.src.backend.Variable
else:
KerasVariable = tf.Variable




def get_assignment_dict():
Expand Down
2 changes: 2 additions & 0 deletions ivy/stateful/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@
from .optimizers import *
from . import sequential
from .sequential import *
from . import utilities
from .utilities import sync_models_torch
Loading
Loading