Skip to content

Commit

Permalink
Cast only if number has same value
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Sep 28, 2022
1 parent 8df97d6 commit edf06b7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
17 changes: 13 additions & 4 deletions jupyter_ydoc/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
from typing import Dict, List, Union
from typing import Dict, List, Type, Union

INT = Type[int]
FLOAT = Type[float]

def cast_all(o: Union[List, Dict], from_type, to_type) -> Union[List, Dict]:

def cast_all(
o: Union[List, Dict], from_type: Union[INT, FLOAT], to_type: Union[FLOAT, INT]
) -> Union[List, Dict]:
if isinstance(o, list):
for i, v in enumerate(o):
if type(v) == from_type:
o[i] = to_type(v)
v2 = to_type(v)
if v == v2:
o[i] = v2
elif isinstance(v, (list, dict)):
cast_all(v, from_type, to_type)
elif isinstance(o, dict):
for k, v in o.items():
if type(v) == from_type:
o[k] = to_type(v)
v2 = to_type(v)
if v == v2:
o[k] = v2
elif isinstance(v, (list, dict)):
cast_all(v, from_type, to_type)
return o
6 changes: 3 additions & 3 deletions jupyter_ydoc/ydoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, *args, **kwargs):
def get_cell(self, index: int) -> Dict[str, Any]:
meta = self._ymeta.to_json()
cell = self._ycells[index].to_json()
cast_all(cell, float, int)
cast_all(cell, float, int) # cells coming from Yjs have e.g. execution_count as float
if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4:
# strip cell IDs if we have notebook format 4.0-4.4
del cell["id"]
Expand Down Expand Up @@ -136,7 +136,7 @@ def set_ycell(self, index: int, ycell: Y.YMap, txn=None):

def get(self):
meta = self._ymeta.to_json()
cast_all(meta, float, int)
cast_all(meta, float, int) # notebook coming from Yjs has e.g. nbformat as float
cells = []
for i in range(len(self._ycells)):
cell = self.get_cell(i)
Expand All @@ -161,7 +161,7 @@ def get(self):
def set(self, value):
nb_without_cells = {key: value[key] for key in value.keys() if key != "cells"}
nb = copy.deepcopy(nb_without_cells)
cast_all(nb, int, float)
cast_all(nb, int, float) # Yjs expects numbers to be floating numbers
cells = value["cells"] or [
{
"cell_type": "code",
Expand Down

0 comments on commit edf06b7

Please sign in to comment.