Skip to content

Commit

Permalink
Create custom class for no exist attribute in patch_attr
Browse files Browse the repository at this point in the history
Signed-off-by: Daemyung Jang <quic_daemyung@quicinc.com>
  • Loading branch information
quic-daemyung committed Mar 5, 2024
1 parent 8378866 commit 8ddf939
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def patch_attr(obj, attr_name, new_attr)-> _ContextManager:
if attr_name in obj._parameters or attr_name in obj._buffers: # pylint: disable=protected-access
return _patch_param_or_buffer(obj, attr_name, new_attr)

old_attr = getattr(obj, attr_name, None)
class _NullAttribute:
pass

old_attr = getattr(obj, attr_name, _NullAttribute())
action = lambda: setattr(obj, attr_name, new_attr)

def cleanup():
Expand All @@ -127,7 +130,7 @@ def cleanup():
except AttributeError:
pass

if not hasattr(obj, attr_name) and old_attr is not None:
if not hasattr(obj, attr_name) and not isinstance(old_attr, _NullAttribute):
setattr(obj, attr_name, old_attr)

return _ContextManager(action, cleanup)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,8 @@ def test_patch_attr():

replica = conv._replicate_for_data_parallel()
assert replica.forward.__self__ is replica

with patch_attr(conv, 'no_exist_attribute', 1):
assert conv.no_exist_attribute == 1

assert not hasattr(conv, 'no_exist_attribute')

0 comments on commit 8ddf939

Please sign in to comment.