From edf06b727ef6884a3e04df95e5d398a58c1a5d9e Mon Sep 17 00:00:00 2001 From: David Brochart Date: Wed, 28 Sep 2022 10:27:03 +0200 Subject: [PATCH] Cast only if number has same value --- jupyter_ydoc/utils.py | 17 +++++++++++++---- jupyter_ydoc/ydoc.py | 6 +++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/jupyter_ydoc/utils.py b/jupyter_ydoc/utils.py index 2e1559a..9e929b1 100644 --- a/jupyter_ydoc/utils.py +++ b/jupyter_ydoc/utils.py @@ -1,17 +1,26 @@ -from typing import Dict, List, Union +from typing import Dict, List, Type, Union +INT = Type[int] +FLOAT = Type[float] -def cast_all(o: Union[List, Dict], from_type, to_type) -> Union[List, Dict]: + +def cast_all( + o: Union[List, Dict], from_type: Union[INT, FLOAT], to_type: Union[FLOAT, INT] +) -> Union[List, Dict]: if isinstance(o, list): for i, v in enumerate(o): if type(v) == from_type: - o[i] = to_type(v) + v2 = to_type(v) + if v == v2: + o[i] = v2 elif isinstance(v, (list, dict)): cast_all(v, from_type, to_type) elif isinstance(o, dict): for k, v in o.items(): if type(v) == from_type: - o[k] = to_type(v) + v2 = to_type(v) + if v == v2: + o[k] = v2 elif isinstance(v, (list, dict)): cast_all(v, from_type, to_type) return o diff --git a/jupyter_ydoc/ydoc.py b/jupyter_ydoc/ydoc.py index f5b8989..8e186fc 100644 --- a/jupyter_ydoc/ydoc.py +++ b/jupyter_ydoc/ydoc.py @@ -86,7 +86,7 @@ def __init__(self, *args, **kwargs): def get_cell(self, index: int) -> Dict[str, Any]: meta = self._ymeta.to_json() cell = self._ycells[index].to_json() - cast_all(cell, float, int) + cast_all(cell, float, int) # cells coming from Yjs have e.g. execution_count as float if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4: # strip cell IDs if we have notebook format 4.0-4.4 del cell["id"] @@ -136,7 +136,7 @@ def set_ycell(self, index: int, ycell: Y.YMap, txn=None): def get(self): meta = self._ymeta.to_json() - cast_all(meta, float, int) + cast_all(meta, float, int) # notebook coming from Yjs has e.g. nbformat as float cells = [] for i in range(len(self._ycells)): cell = self.get_cell(i) @@ -161,7 +161,7 @@ def get(self): def set(self, value): nb_without_cells = {key: value[key] for key in value.keys() if key != "cells"} nb = copy.deepcopy(nb_without_cells) - cast_all(nb, int, float) + cast_all(nb, int, float) # Yjs expects numbers to be floating numbers cells = value["cells"] or [ { "cell_type": "code",