diff --git a/src/scida/__init__.py b/src/scida/__init__.py index aae8221f..a6a5dcd9 100644 --- a/src/scida/__init__.py +++ b/src/scida/__init__.py @@ -10,5 +10,6 @@ from scida.customs.gadgetstyle.dataset import GadgetStyleSnapshot, SwiftSnapshot from scida.customs.gizmo.dataset import GizmoSnapshot from scida.customs.gizmo.series import GizmoSimulation +from scida.customs.rockstar.dataset import RockstarCatalog logging.basicConfig(stream=sys.stdout, level=logging.INFO) diff --git a/src/scida/configfiles/units/illustris.yaml b/src/scida/configfiles/units/illustris.yaml index fc2b69af..6f87d6ae 100644 --- a/src/scida/configfiles/units/illustris.yaml +++ b/src/scida/configfiles/units/illustris.yaml @@ -41,7 +41,7 @@ fields: GFM_MetalsTagged: GFM_WindDMVelDisp: km / s GFM_WindHostHaloMass: code_mass - InternalEnergy: (km / s)^2 + InternalEnergy: (km / s) ^ 2 InternalEnergyOld: (km / s)^2 Machnumber: MagneticField: (h / a^2) * (code_pressure)^(1/2) diff --git a/src/scida/configfiles/units/rockstar.yaml b/src/scida/configfiles/units/rockstar.yaml new file mode 100644 index 00000000..c292c3ca --- /dev/null +++ b/src/scida/configfiles/units/rockstar.yaml @@ -0,0 +1,64 @@ +# Rockstar source says: +# > chars += fprintf(output, "#Units: Masses in Msun / h\n" +# > "#Units: Positions in Mpc / h (comoving)\n" +# > "#Units: Velocities in km / s (physical, peculiar)\n" +# > "#Units: Halo Distances, Lengths, and Radii in kpc / h (comoving)\n" +# > "#Units: Angular Momenta in (Msun/h) * (Mpc/h) * km/s (physical)\n" +# > "#Units: Spins are dimensionless\n"); +# > if (np) +# > chars += fprintf(output, "#Units: Total energy in (Msun/h)*(km/s)^2" +# > " (physical)\n"" + + +units: + # h: 0.7 hubble factor, needs to be set by snapshot + # a: 1.0 # scale factor, needs to be set by snapshot + ckpc: a * kpc + cMpc: a * Mpc + rockstar_mass: Msun / h + rockstar_length: ckpc / h # positions in: cMpc / h + rockstar_velocity: km / s +fields: + _all: + # accrete.rate ? + # accrete.rate.100Myr ? + # accrete.rate.tdyn ? + # am.phantom ? + # am.progenitor.main ? + # axis.b_div.by_a ? + # axis.c_div.by_a ? + descendant.snapshot: + host.distance: rockstar_length + host.velocity: rockstar_velocity + host.velocity.rad: rockstar_velocity + host.velocity.tan: rockstar_velocity + id: + id.to.index: + infall.first.mass: rockstar_mass + infall.first.snapshot: + infall.first.vel.circ.max: rockstar_velocity + major.merger.snapshot: + mass: rockstar_mass + mass.180m: rockstar_mass + mass.200c: rockstar_mass + mass.200m: rockstar_mass + mass.500c: rockstar_mass + mass.bound: rockstar_mass + mass.half.snapshot: + mass.lowres: rockstar_mass + mass.peak: rockstar_mass + mass.peak.snapshot: + mass.vir: rockstar_mass + position: cMpc / h + position.offset: cMpc / h + progenitor.number: + radius: rockstar_length + scale.radius: rockstar_length + scale.radius.klypin: rockstar_length + spin.bullock: + spin.peebles: + tree.index: + vel.circ.max: rockstar_velocity + vel.circ.peak: rockstar_velocity + vel.std: rockstar_velocity + velocity.offset: rockstar_velocity diff --git a/src/scida/convenience.py b/src/scida/convenience.py index 92fa0a49..0dca0ec5 100644 --- a/src/scida/convenience.py +++ b/src/scida/convenience.py @@ -224,19 +224,6 @@ def load( msg = "Dataset is identified as '%s' via _determine_type." % cls log.debug(msg) - # determine additional mixins not set by class - mixins = [] - if unitfile: - if not units: - units = True - kwargs["unitfile"] = unitfile - - if units: - mixins.append(UnitMixin) - kwargs["units"] = units - - kwargs["overwritecache"] = overwrite - # any identifying metadata? classtype = "dataset" if issubclass(cls, DatasetSeries): @@ -256,6 +243,21 @@ def load( if force_class is not None: cls = force_class + # determine additional mixins not set by class + mixins = [] + if hasattr(cls, "_unitfile"): + unitfile = cls._unitfile + if unitfile: + if not units: + units = True + kwargs["unitfile"] = unitfile + + if units: + mixins.append(UnitMixin) + kwargs["units"] = units + + kwargs["overwritecache"] = overwrite + # we append since unit mixin is added outside of this func right now metadata_raw = dict() if classtype == "dataset": diff --git a/src/scida/customs/rockstar/dataset.py b/src/scida/customs/rockstar/dataset.py index 61019577..bcaeac74 100644 --- a/src/scida/customs/rockstar/dataset.py +++ b/src/scida/customs/rockstar/dataset.py @@ -2,10 +2,13 @@ from typing import Union from scida.interface import Dataset +from scida.interfaces.mixins import CosmologyMixin from scida.io import load_metadata_all -class RockstarCatalog(Dataset): +class RockstarCatalog(CosmologyMixin, Dataset): + _unitfile = "units/rockstar.yaml" + def __init__(self, path, **kwargs) -> None: super().__init__(path, **kwargs) @@ -33,11 +36,7 @@ def validate_path(cls, path: Union[str, os.PathLike], *args, **kwargs) -> bool: if possibly_valid: metadata_raw_all = load_metadata_all(path, **kwargs) - attrs = metadata_raw_all["attrs"] - - # lets use the "feature" that rockstar adds no attributes (?) but uses scalar datasets - if len(attrs) != 0: - return False + # attrs = metadata_raw_all["attrs"] # usually rockstar has some cosmology parameters attached datasets = metadata_raw_all["datasets"] diff --git a/src/scida/helpers_hdf5.py b/src/scida/helpers_hdf5.py index 69277066..99f98fee 100644 --- a/src/scida/helpers_hdf5.py +++ b/src/scida/helpers_hdf5.py @@ -30,7 +30,7 @@ def get_dtype(obj): return None -def walk_group(obj, tree, get_attrs=False): +def walk_group(obj, tree, get_attrs=False, scalar_to_attr=True): if len(tree) == 0: tree.update(**dict(attrs={}, groups=[], datasets=[])) if get_attrs and len(obj.attrs) > 0: @@ -38,6 +38,8 @@ def walk_group(obj, tree, get_attrs=False): if isinstance(obj, (h5py.Dataset, zarr.Array)): dtype = get_dtype(obj) tree["datasets"].append([obj.name, obj.shape, dtype]) + if scalar_to_attr and len(obj.shape) == 0: + tree["attrs"][obj.name] = obj[()] elif isinstance(obj, (h5py.Group, zarr.Group)): tree["groups"].append(obj.name) # continue the walk for k, o in obj.items(): diff --git a/src/scida/interfaces/mixins/cosmology.py b/src/scida/interfaces/mixins/cosmology.py index eb6f421d..d33241ff 100644 --- a/src/scida/interfaces/mixins/cosmology.py +++ b/src/scida/interfaces/mixins/cosmology.py @@ -47,6 +47,7 @@ def get_cosmology_from_rawmetadata(metadata_raw): import astropy.units as u from astropy.cosmology import FlatLambdaCDM + # gadgetstyle aliasdict = dict(h=["HubbleParam"], om0=["Omega0"], ob0=["OmegaBaryon"]) cparams = dict(h=None, om0=None, ob0=None) for grp in ["Parameters", "Header"]: @@ -56,6 +57,14 @@ def get_cosmology_from_rawmetadata(metadata_raw): for alias in aliasdict[p]: cparams[p] = metadata_raw.get("/" + grp, {}).get(alias, cparams[p]) + # rockstar + if cparams["h"] is None and "/cosmology:hubble" in metadata_raw: + cparams["h"] = metadata_raw["/cosmology:hubble"] + if cparams["om0"] is None and "/cosmology:omega_matter" in metadata_raw: + cparams["om0"] = metadata_raw["/cosmology:omega_matter"] + if cparams["ob0"] is None and "/cosmology:omega_baryon" in metadata_raw: + cparams["ob0"] = metadata_raw["/cosmology:omega_baryon"] + h, om0, ob0 = cparams["h"], cparams["om0"], cparams["ob0"] if None in [h, om0]: log.info("Cannot infer cosmology.") diff --git a/src/scida/interfaces/mixins/units.py b/src/scida/interfaces/mixins/units.py index d13b7e92..e7da4621 100644 --- a/src/scida/interfaces/mixins/units.py +++ b/src/scida/interfaces/mixins/units.py @@ -202,7 +202,7 @@ def __init__(self, *args, **kwargs): dsprops = c["data"][self.hints["dsname"]] missing_units = dsprops.get("missing_units", missing_units) unitfile = dsprops.get("unitfile", "") - unitfile = kwargs.pop("unitfile", unitfile) + unitfile = kwargs.pop("unitfile", unitfile) # passed kwarg takes precedence unitdefs = get_config_fromfile("units/general.yaml").get("units", {}) if unitfile != "": unithints = get_config_fromfile(unitfile) diff --git a/tests/customs/test_fire2.py b/tests/customs/test_fire2.py index 011db971..584e4b31 100644 --- a/tests/customs/test_fire2.py +++ b/tests/customs/test_fire2.py @@ -17,6 +17,23 @@ def test_series(testdatapath): print(ds.info()) check_firesnap(ds) + rh = ds.data["rockstar_halo"] + + print(rh.info()) + for k, v in rh.items(): + print(k, v) + + # Rockstar source says: + # > chars += fprintf(output, "#Units: Masses in Msun / h\n" + # > "#Units: Positions in Mpc / h (comoving)\n" + # > "#Units: Velocities in km / s (physical, peculiar)\n" + # > "#Units: Halo Distances, Lengths, and Radii in kpc / h (comoving)\n" + # > "#Units: Angular Momenta in (Msun/h) * (Mpc/h) * km/s (physical)\n" + # > "#Units: Spins are dimensionless\n"); + # > if (np) + # > chars += fprintf(output, "#Units: Total energy in (Msun/h)*(km/s)^2" + # > " (physical)\n"" + def check_firesnap(obj: GizmoSnapshot): assert isinstance(obj, GizmoSnapshot) @@ -72,4 +89,5 @@ def check_firesnap(obj: GizmoSnapshot): def test_snapshot(testdatapath): """Test loading of a full snapshot""" obj: GizmoSnapshot = load(testdatapath) + print(obj.info()) check_firesnap(obj) diff --git a/tests/customs/test_rockstar.py b/tests/customs/test_rockstar.py index ef63f91f..20124512 100644 --- a/tests/customs/test_rockstar.py +++ b/tests/customs/test_rockstar.py @@ -1,12 +1,13 @@ from scida.convenience import load +from scida.customs.rockstar.dataset import RockstarCatalog from tests.testdata_properties import require_testdata_path @require_testdata_path("rockstar") -def test_load_check_seriestype(testdatapath): +def test_load(testdatapath): # this test intentionally ignores all additional hints from config files # catch_exception = True allows us better control for debugging obj = load(testdatapath) print(obj.info()) print(obj.data.keys()) - pass + assert isinstance(obj, RockstarCatalog)