diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 0faf7e0d6ea956..b19efac1306cc5 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -714,9 +714,14 @@ class HFCacheProxy(HFProxy): Proxy that represents an instance of `transformers.cache_utils.Cache`. """ + def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]): + self._orig_cache_cls = orig_cache_cls + @property def __class__(self): - return ProxyableCache + if not hasattr(self, "_orig_cache_cls"): + raise RuntimeError("The original Cache class must be installed to the HFCacheProxy.") + return self.tracer._CLASSES_TO_PATCH[self._orig_cache_cls] def create_wrapper( @@ -806,23 +811,39 @@ def _proxies_to_metas(v): return v -def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: - global _CURRENT_TRACER - if not isinstance(_CURRENT_TRACER, HFTracer): - raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.") - return HFCacheProxy(n, _CURRENT_TRACER) +def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]: + def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: + global _CURRENT_TRACER + if not isinstance(_CURRENT_TRACER, HFTracer): + raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.") + cache_proxy = HFCacheProxy(n, _CURRENT_TRACER) + cache_proxy.install_orig_cache_cls(orig_cache_cls) + return cache_proxy + + return cache_proxy_factory_fn # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`. -ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn) +ProxyableCache = HFProxyableClassMeta( + "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache) +) ProxyableDynamicCache = HFProxyableClassMeta( - "ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn + "ProxyableDynamicCache", + (DynamicCache,), + {}, + proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache), ) ProxyableSinkCache = HFProxyableClassMeta( - "ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn + "ProxyableSinkCache", + (SinkCache,), + {}, + proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache), ) ProxyableStaticCache = HFProxyableClassMeta( - "ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn + "ProxyableStaticCache", + (StaticCache,), + {}, + proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache), )