Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FX tracing issues for Llama #30619

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,8 +1256,11 @@ def _update_causal_mask(
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
Comment on lines 1258 to +1263
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

main has this change!

elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
Expand Down
41 changes: 31 additions & 10 deletions src/transformers/utils/fx.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could / should live outside transformers in the long run ("à terme"), WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

Original file line number Diff line number Diff line change
Expand Up @@ -714,9 +714,14 @@ class HFCacheProxy(HFProxy):
Proxy that represents an instance of `transformers.cache_utils.Cache`.
"""

def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]):
    """
    Store `orig_cache_cls` on this proxy as `_orig_cache_cls`.

    The `__class__` property of this proxy looks for that attribute and raises a `RuntimeError` when it is
    missing, so this must be called before the proxy's class is inspected.
    """
    self._orig_cache_cls = orig_cache_cls

@property
def __class__(self):
return ProxyableCache
if not hasattr(self, "_orig_cache_cls"):
raise RuntimeError("The original Cache class must be installed to the HFCacheProxy.")
return self.tracer._CLASSES_TO_PATCH[self._orig_cache_cls]


def create_wrapper(
Expand Down Expand Up @@ -806,23 +811,39 @@ def _proxies_to_metas(v):
return v


def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
    """
    Build an `HFCacheProxy` for node `n`, attached to the tracer stored in `_CURRENT_TRACER`.

    Raises:
        RuntimeError: if `_CURRENT_TRACER` is not an `HFTracer` instance.
    """
    global _CURRENT_TRACER
    tracer = _CURRENT_TRACER
    if isinstance(tracer, HFTracer):
        return HFCacheProxy(n, tracer)
    raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
    """
    Return a proxy factory whose proxies have `orig_cache_cls` installed on them.

    Args:
        orig_cache_cls: the cache class passed to `install_orig_cache_cls` on every proxy the returned
            factory creates.

    Returns:
        A callable taking a `Node` and returning an `HFCacheProxy` attached to the tracer stored in
        `_CURRENT_TRACER`. The callable raises `RuntimeError` if `_CURRENT_TRACER` is not an `HFTracer`.
    """

    def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
        global _CURRENT_TRACER
        tracer = _CURRENT_TRACER
        if not isinstance(tracer, HFTracer):
            raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
        proxy = HFCacheProxy(n, tracer)
        # The class is installed right after construction, before the proxy is handed back to the caller.
        proxy.install_orig_cache_cls(orig_cache_cls)
        return proxy

    return cache_proxy_factory_fn


# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn)
ProxyableCache = HFProxyableClassMeta(
"ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
)
ProxyableDynamicCache = HFProxyableClassMeta(
"ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableDynamicCache",
(DynamicCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
)
ProxyableSinkCache = HFProxyableClassMeta(
"ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableSinkCache",
(SinkCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
)
ProxyableStaticCache = HFProxyableClassMeta(
"ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableStaticCache",
(StaticCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)


Expand Down
Loading