diff --git a/docs/usage.ipynb b/docs/usage.ipynb index eaf69a13..1483d7f5 100644 --- a/docs/usage.ipynb +++ b/docs/usage.ipynb @@ -189,7 +189,7 @@ " expression=model.expression,\n", " parameters=model.parameter_defaults,\n", ")\n", - "intensity = LambdifiedFunction(sympy_model, backend=\"numpy\")\n", + "intensity = LambdifiedFunction(sympy_model, backend=\"jax\")\n", "data_converter = HelicityTransformer(model.adapter)\n", "phsp_sample = generate_phsp(100_000, model.adapter.reaction_info)\n", "data_sample = generate_data(\n", @@ -246,7 +246,7 @@ "source": [ "phsp_set = data_converter.transform(phsp_sample)\n", "data_set = data_converter.transform(data_sample)\n", - "data_frame = pd.DataFrame(data_set.to_pandas())\n", + "data_frame = pd.DataFrame(data_set)\n", "data_frame[\"m_12\"].hist(bins=100, alpha=0.5, density=True)\n", "indicate_masses()\n", "plt.legend();" diff --git a/docs/usage/step2.ipynb b/docs/usage/step2.ipynb index 8e511d30..4ef1278a 100644 --- a/docs/usage/step2.ipynb +++ b/docs/usage/step2.ipynb @@ -277,8 +277,8 @@ "import numpy as np\n", "import pandas as pd\n", "\n", - "data_frame = pd.DataFrame(data_set.to_pandas())\n", - "phsp_frame = pd.DataFrame(data_set.to_pandas())\n", + "data_frame = pd.DataFrame(data_set)\n", + "phsp_frame = pd.DataFrame(data_set)\n", "data_frame" ] }, diff --git a/src/tensorwaves/data/transform.py b/src/tensorwaves/data/transform.py index 49d9f8cf..7c1078c7 100644 --- a/src/tensorwaves/data/transform.py +++ b/src/tensorwaves/data/transform.py @@ -1,5 +1,6 @@ """Implementations of `.DataTransformer`.""" +import numpy as np from ampform.kinematics import EventCollection, HelicityAdapter from tensorwaves.interfaces import DataSample, DataTransformer @@ -17,4 +18,5 @@ def __init__(self, helicity_adapter: HelicityAdapter) -> None: def transform(self, dataset: DataSample) -> DataSample: events = EventCollection({int(k): v for k, v in dataset.items()}) - return self.__helicity_adapter.transform(events) + dataset = self.__helicity_adapter.transform(events) + return {key: np.array(values) for key, values in dataset.items()}