From e404be8d770d0e292665539604d6e0b1619d07a5 Mon Sep 17 00:00:00 2001
From: Chris Byrohl <9221545+cbyrohl@users.noreply.github.com>
Date: Sat, 5 Aug 2023 00:01:21 +1000
Subject: [PATCH] Test docs (#65)

* fix some doc code

* add tests for markdown docs
---
 CHANGELOG.md                             |  5 ++
 docs/derived_fields.md                   | 32 +++++++---
 docs/faq.md                              |  8 +--
 docs/halocatalogs.md                     |  2 +-
 docs/units.md                            |  7 +--
 src/scida/customs/arepo/dataset.py       |  2 +-
 src/scida/customs/gadgetstyle/dataset.py |  2 +-
 tests/test_docs.py                       | 79 ++++++++++++++++++++++++
 8 files changed, 119 insertions(+), 18 deletions(-)
 create mode 100644 tests/test_docs.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ecfd0797..78e54cdf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Added
+- test docs
+
+### Fixed
+- various bugs related to dask operations using pint quantities.
 
 ## [0.2.2] - 2023-07-11
 
diff --git a/docs/derived_fields.md b/docs/derived_fields.md
index 5a4044bb..b87b7c9c 100644
--- a/docs/derived_fields.md
+++ b/docs/derived_fields.md
@@ -9,9 +9,9 @@ There are two ways to create new derived fields. For quick analysis, we can simp
 
 ``` py
 from scida import load
-ds = load("somedataset") # (1)!
+ds = load("TNG50-4_snapshot") # (1)!
 gas = ds.data['gas']
-kineticenergy = 0.5*gas['Masses']*gas['Velocities']**2
+kineticenergy = 0.5*gas['Masses']*(gas['Velocities']**2).sum(axis=1)
 ```
 
 1. In this example, we assume a dataset, such as the 'TNG50\_snapshot' test data set, that has its fields (*Masses*, *Velocities*) nested by particle type (*gas*)
@@ -33,15 +33,16 @@ For this purpose, **field recipes** are available. An example of such recipe is
 
 ``` py
-import dask.array as da
+import numpy as np
 from scida import load
 
-ds = load("somedataset")
+ds = load("TNG50-4_snapshot")
 
-@snap.register_field("stars") # (1)!
+@ds.register_field("stars") # (1)!
 def VelMag(arrs, **kwargs):
+    import dask.array as da
     vel = arrs['Velocities']
-    return np.sqrt( vel[:,0]**2 + vel[:,1]**2 + vel[:,2]**2 )
+    return da.sqrt(vel[:,0]**2 + vel[:,1]**2 + vel[:,2]**2)
 ```
 
 1. Here, *stars* is the name of the **field container** the field should be added to. The field will now be available as ds\['stars'\]\['VelMag'\]
 
@@ -69,7 +70,24 @@ def Volume(arrs, **kwargs):
     return arrs["Masses"]/arrs["Density"]
 
 @fielddefs.register_field("all") # (3)!
+def GroupDistance3D(arrs, snap=None):
+    """Returns distance to hosting group center. Returns rubbish if not actually associated with a group."""
+    import dask.array as da
+    boxsize = snap.header["BoxSize"]
+    pos_part = arrs["Coordinates"]
+    groupid = arrs["GroupID"]
+    if hasattr(groupid, "magnitude"):
+        groupid = groupid.magnitude
+        boxsize *= snap.ureg("code_length")
+    pos_cat = snap.data["Group"]["GroupPos"][groupid]
+    dist3 = pos_part-pos_cat
+    dist3 = da.where(dist3>boxsize/2.0, boxsize-dist3, dist3)
+    dist3 = da.where(dist3<=-boxsize/2.0, boxsize+dist3, dist3) # PBC
+    return dist3
+
+@fielddefs.register_field("all")
 def GroupDistance(arrs, snap=None):
+    import dask.array as da
     dist3 = arrs["GroupDistance3D"]
     dist = da.sqrt((dist3**2).sum(axis=1))
     dist = da.where(arrs["GroupID"]==-1, np.nan, dist) # set unbound gas to nan
@@ -83,7 +101,7 @@ def GroupDistance(arrs, snap=None):
 Finally, we just need to import the *fielddefs* object (if we have defined it in another file) and merge them with a dataset that we loaded:
 
 ``` py
-ds = load("snapshot")
+ds = load("TNG50-4_snapshot")
 ds.data.merge(fielddefs)
 ```
 
diff --git a/docs/faq.md b/docs/faq.md
index 2540fac2..aa39381e 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -10,9 +10,9 @@ Please note that all fields within a container are expected to have the same sha
 
 ``` py
 from scida import load
-import da.array as da
-ds = load('simname')
-array = da.zeros_like(ds.data["PartType0"]["Density"][:,0])
+import dask.array as da
+ds = load('TNG50-4_snapshot')
+array = da.zeros_like(ds.data["PartType0"]["Density"])
 ds.data['PartType0']["zerofield"] = array
 ```
 
@@ -20,7 +20,7 @@ As we operate with dask, make sure to cast your array accordingly. For example,
 Alternatively, if you have another dataset loaded, you can assign fields from one to another:
 
 ``` py
-ds2 = load('simname2')
+ds2 = load('TNG50-4_snapshot')
 ds.data['PartType0']["NewDensity"] = ds2.data['PartType0']["Density"]
 ```
 
diff --git a/docs/halocatalogs.md b/docs/halocatalogs.md
index 6449adbc..79ff14f4 100644
--- a/docs/halocatalogs.md
+++ b/docs/halocatalogs.md
@@ -7,7 +7,7 @@ Currently, we support the usual FOF/Subfind combination and format. Their presen
 
 ``` py
 from scida import load
-ds = load("snapshot") # (1)!
+ds = load("TNG50-4_snapshot") # (1)!
 ```
 
 1. In this example, we assume a dataset, such as the 'TNG50\_snapshot' test data set, that has its fields (*Masses*, *Velocities*) nested by particle type (*gas*)
diff --git a/docs/units.md b/docs/units.md
index b7bef659..3abd0ada 100644
--- a/docs/units.md
+++ b/docs/units.md
@@ -2,7 +2,7 @@
 
 ## Loading data with units
 
-Loading data sets with 
+Loading data sets with
 
 ``` py
 from scida import load
@@ -16,6 +16,7 @@ Units are introduced via the [pint](https://pint.readthedocs.io/en/stable/) pack
 
 ``` pycon
+>>> gas = ds.data["PartType0"]
 >>> gas["Coordinates"]
 dask.array
 ```
@@ -34,7 +35,7 @@ We can change units for evaluation as desired:
 
 ``` pycon
 >>> coords = gas["Coordinates"]
->>> coords.to("cm") 
+>>> coords.to("cm")
 >>> # here the default system is cgs, thus we get the same result from
 >>> coords.to_base_units()
 dask.array
@@ -76,5 +77,3 @@ dask.array
 >>> coords.to("halfmeter")[0].compute()
 array([6.64064027e+23, 2.23858253e+24, 1.94176712e+24])
 ```
-
-
diff --git a/src/scida/customs/arepo/dataset.py b/src/scida/customs/arepo/dataset.py
index 31da16ae..a818fca6 100644
--- a/src/scida/customs/arepo/dataset.py
+++ b/src/scida/customs/arepo/dataset.py
@@ -237,7 +237,7 @@ def discover_catalog(self):
             self.catalog = candidate
             break
 
-    def register_field(self, parttype: str, name: str = None, construct: bool = True):
+    def register_field(self, parttype: str, name: str = None, construct: bool = False):
         """
         Register a field.
         Parameters
diff --git a/src/scida/customs/gadgetstyle/dataset.py b/src/scida/customs/gadgetstyle/dataset.py
index 7eb88aba..1593e903 100644
--- a/src/scida/customs/gadgetstyle/dataset.py
+++ b/src/scida/customs/gadgetstyle/dataset.py
@@ -16,7 +16,7 @@ class GadgetStyleSnapshot(Dataset):
     def __init__(self, path, chunksize="auto", virtualcache=True, **kwargs) -> None:
         """We define gadget-style snapshots as nbody/hydrodynamical simulation snapshots
         that follow the common /PartType0, /PartType1 grouping scheme."""
-        self.boxsize = np.full(3, np.nan)
+        self.boxsize = np.nan
         super().__init__(path, chunksize=chunksize, virtualcache=virtualcache, **kwargs)
 
         defaultattributes = ["config", "header", "parameters"]
diff --git a/tests/test_docs.py b/tests/test_docs.py
new file mode 100644
index 00000000..423b6984
--- /dev/null
+++ b/tests/test_docs.py
@@ -0,0 +1,79 @@
+import pathlib
+
+import pytest
+
+# unfortunately cannot use existing solution, so have to write our own
+# (https://github.com/nschloe/pytest-codeblocks/issues/77)
+
+ignore_files = ["largedatasets.md", "visualization.md"]
+
+
+class DocFile:
+    def __init__(self, path: pathlib.Path):
+        self.path = path
+        self.lines = []
+        self.codeblocks = []
+        self.extract_codeblocks()
+
+    def extract_codeblocks(self):
+        with open(self.path, "r") as f:
+            lines = f.readlines()
+        # check line for codeblock start/end
+        lidx = [i for i, k in enumerate(lines) if k.startswith("```")]
+        print(lidx)
+        # assert that there are an even number of codeblock start/end
+        assert len(lidx) % 2 == 0
+        cblines = []
+        for i in range(0, len(lidx), 2):
+            start = lidx[i]  # includes codeblock start line
+            end = lidx[i + 1]
+            # we check for the type. if it's 'pycon', we need to remove the '>>>' prompt
+            blocktype = lines[start].strip().replace("```", "").strip()
+            blocktype = blocktype.split()[0].strip()
+            if blocktype == "py" or blocktype == "python":
+                cblines.append(lines[start + 1 : end])
+            elif blocktype == "pycon":
+                cblines.append(
+                    [k[4:] for k in lines[start + 1 : end] if k.startswith(">>>")]
+                )
+            elif blocktype == "bash":
+                # not python; ignore
+                pass
+            else:
+                raise ValueError("Unknown codeblock type: %s" % blocktype)
+        self.codeblocks = ["".join(cbl) for cbl in cblines]
+
+    def evaluate(self):
+        # evaluate all at once (for now)
+        code = "\n".join(self.codeblocks)
+        print("Evaluating code:")
+        print(code)
+        exec(code)
+
+
+def get_docfiles():
+    fixturename = "mddocfile"
+    docfiles = []
+    paths = []
+    ids = []
+
+    docpath = pathlib.Path(__file__).parent.parent / "docs"
+    for p in docpath.glob("*.md"):
+        name = p.name
+        if name in ignore_files:
+            continue
+        print("Evaluating %s" % p)
+        # read lines
+        docfile = DocFile(p)
+        docfiles.append(docfile)
+        paths.append(p)
+        ids.append(name)
+
+    params = docfiles
+    return pytest.mark.parametrize(fixturename, params, ids=ids)
+
+
+@pytest.mark.xfail
+@get_docfiles()
+def test_docs(mddocfile):
+    mddocfile.evaluate()
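
Reviewer note (not part of the patch): the quickest way to see what the new `DocFile` helper does is to point it at a throwaway markdown file. The sketch below is hypothetical — it assumes it is run from the repository root so that `tests/test_docs.py` is importable, and the scratch file name `tmp.md` is arbitrary:

``` py
import pathlib

from tests.test_docs import DocFile  # helper added by this patch

# Assemble a throwaway markdown file; the fences are built from pieces so
# that this snippet can itself sit inside a fenced code block.
fence = "`" * 3
md = "\n".join(
    [
        fence + " py",
        "x = 1 + 1",
        fence,
        "",
        fence + " pycon",
        ">>> y = 2 * x",
        ">>> y",
        "4",
        fence,
        "",
    ]
)
scratch = pathlib.Path("tmp.md")  # hypothetical scratch file
scratch.write_text(md)

doc = DocFile(scratch)
# 'py' blocks are taken verbatim; 'pycon' blocks keep only the '>>>' lines
# with the prompt stripped, so both come out as plain executable python.
print(doc.codeblocks)  # ['x = 1 + 1\n', 'y = 2 * x\ny\n']
doc.evaluate()  # exec()s the concatenated blocks; raises if a snippet is broken
scratch.unlink()
```

Since `test_docs` is decorated with `@pytest.mark.xfail`, snippets that fail (for instance because the `TNG50-4_snapshot` test dataset is not available on the machine) are reported as expected failures rather than breaking the suite.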