Skip to content

Commit

Permalink
feat(state): add modified_keys accessor
Browse files Browse the repository at this point in the history
  • Loading branch information
jourdain committed Jan 12, 2025
1 parent 3231702 commit 12733a1
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 0 deletions.
85 changes: 85 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,88 @@ def test_dunder():

state.flush()
assert state.to_dict() == {}


@pytest.mark.asyncio
async def test_modified_keys():
"""
0 msg : test_modified_keys
1 push : {'a': 1, 'b': 2, 'c': 3}
2 msg : get initial a,b,c
3 msg : changed should be => a
4 push : {'a': 2}
5 msg : changed ['a']
6 msg : End of flush 1
7 msg : changed should be => a, b
8 push : {'a': 3, 'b': 4}
9 msg : changed ['a', 'b']
10 msg : End of flush 2
11 msg : changed should be => a, b, c
12 push : {'a': 4, 'b': 6, 'c': 6}
13 msg : changed ['a', 'b', 'c']
14 msg : side effect c => a + b
15 push : {'a': 4.5, 'b': 6.5}
16 msg : changed ['a', 'b']
17 msg : End of flush 3
"""
server = FakeServer()
server.add_event("test_modified_keys")
state = State(commit_fn=server._push_state)

NAMES = ["a", "b", "c"]
state.update(
{
"a": 1,
"b": 2,
"c": 3,
}
)
state.ready()
server.add_event("get initial a,b,c")
await asyncio.sleep(0.01)

@state.change(*NAMES)
def on_change(**_):
m_keys = list(state.modified_keys)
m_keys.sort()
server.add_event(f"changed {m_keys}")

@state.change("c")
def trigger_side_effect(**_):
server.add_event("side effect c => a + b")
state.a += 0.5
state.b += 0.5

with state:
state.a += 1
server.add_event("changed should be => a")

# yield
await asyncio.sleep(0.01)
server.add_event("End of flush 1")

with state:
state.a += 1
state.b += 2
server.add_event("changed should be => a, b")

# yield
await asyncio.sleep(0.01)
server.add_event("End of flush 2")

with state:
state.a += 1
state.b += 2
state.c += 3
server.add_event("changed should be => a, b, c")

# yield
await asyncio.sleep(0.1)
server.add_event("End of flush 3")

result = [line.strip() for line in str(server).split("\n")]
expected = [line.strip() for line in str(test_modified_keys.__doc__).split("\n")]

print(result)

assert expected == result
38 changes: 38 additions & 0 deletions trame_server/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
self._push_state_fn = commit_fn
self._hot_reload = hot_reload
self._translator = translator if translator else Translator()
self._modified_keys = share(internal, "_modified_keys", set())
self._change_callbacks = share(internal, "_change_callbacks", {})
self._pending_update = share(internal, "_pending_update", {})
self._pushed_state = share(internal, "_pushed_state", {})
Expand Down Expand Up @@ -220,6 +221,39 @@ def update(self, _dict):
if _dict[key] == self._pushed_state.get(key, TRAME_NON_INIT_VALUE):
self._pending_update.pop(key, None)

@property
def modified_keys(self):
"""
Return the set of state's keys that are modified
for the current state.change update.
Usage example:
--------------
>>> NAMES = ["a", "b", "c"]
>>> state.update({"a": 1, "b": 2, "c": 3})
>>> @state.change(*NAMES)
... def on_change(*_):
... for name in state.modified_keys:
... print(f"{name} value updated to {state[name]}")
>>> with state:
... state.a += 1
>>> with state:
... state.a += 1
... state.b += 2
>>> with state:
... state.a += 1
... state.b += 2
... state.c += 3
"""
# for child server we may need to run the translator on them
return self._modified_keys

def flush(self):
"""Force pushing modified state and execute any @state.change listener"""
if not self.is_ready:
Expand All @@ -232,6 +266,10 @@ def flush(self):
while len(_keys):
keys |= _keys

# update modified keys for current update batch
self._modified_keys.clear()
self._modified_keys |= _keys

# Do the flush
if self._push_state_fn:
self._push_state_fn(self._pending_update)
Expand Down

0 comments on commit 12733a1

Please sign in to comment.