Skip to content

Commit

Permalink
Fix FX tracing issues for Llama (huggingface#30619)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun authored May 2, 2024
1 parent 9719202 commit 39359e5
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,9 +714,14 @@ class HFCacheProxy(HFProxy):
Proxy that represents an instance of `transformers.cache_utils.Cache`.
"""

def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]):
self._orig_cache_cls = orig_cache_cls

@property
def __class__(self):
return ProxyableCache
if not hasattr(self, "_orig_cache_cls"):
raise RuntimeError("The original Cache class must be installed to the HFCacheProxy.")
return self.tracer._CLASSES_TO_PATCH[self._orig_cache_cls]


def create_wrapper(
Expand Down Expand Up @@ -806,23 +811,39 @@ def _proxies_to_metas(v):
return v


def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
global _CURRENT_TRACER
if not isinstance(_CURRENT_TRACER, HFTracer):
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
return HFCacheProxy(n, _CURRENT_TRACER)
def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
global _CURRENT_TRACER
if not isinstance(_CURRENT_TRACER, HFTracer):
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
cache_proxy.install_orig_cache_cls(orig_cache_cls)
return cache_proxy

return cache_proxy_factory_fn


# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn)
ProxyableCache = HFProxyableClassMeta(
"ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
)
ProxyableDynamicCache = HFProxyableClassMeta(
"ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableDynamicCache",
(DynamicCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
)
ProxyableSinkCache = HFProxyableClassMeta(
"ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableSinkCache",
(SinkCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
)
ProxyableStaticCache = HFProxyableClassMeta(
"ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableStaticCache",
(StaticCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)


Expand Down

0 comments on commit 39359e5

Please sign in to comment.