Skip to content

Commit

Permalink
Noadd (#16)
Browse files Browse the repository at this point in the history
* refactor test

* assertions on input
  • Loading branch information
The-Ludwig authored Feb 15, 2023
1 parent c3b7da9 commit 8349658
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
19 changes: 15 additions & 4 deletions panama/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ def read_DAT(
if files is not None and glob is not None:
raise TypeError("`file` and `glob` can't both be not None")

if not additional_columns:
if drop_non_particles:
raise ValueError(
"drop_non_particles requires additional_columns to be calculated."
)
if mother_columns:
raise ValueError(
"mother_columns require additional columns to be calculated"
)

if glob is not None:
basepath = Path(glob).parent
files = list(basepath.glob(Path(glob).name))
Expand Down Expand Up @@ -409,10 +419,11 @@ def read_DAT(
) # mother is muon (and in early generation)
)

if drop_mothers:
df_particles.drop(
index=df_particles.query("is_mother == True").index.values, inplace=True
)
if drop_mothers:
df_particles.drop(
index=df_particles.query("particle_description < 0").index.values,
inplace=True,
)

# Numba version...
# df["mother_run_idx"], df["mother_event_idx"], df["mother_particle_idx"] = mother_idx_numba(df.loc[:, "is_mother"].values, df.loc[:, "run_number"].values, df.loc[:, "event_number"].values, df.loc[:, "particle_number"].values)
Expand Down
54 changes: 37 additions & 17 deletions tests/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,48 @@ def test_noparse(test_file_path=Path(__file__).parent / "files" / "DAT000000"):
assert df_np.equals(df)


def test_read_corsia_file(test_file_path=Path(__file__).parent / "files" / "DAT000000"):

df_run, df_event, df = panama.read_DAT(test_file_path, drop_non_particles=False)

with CorsikaParticleFile(test_file_path, parse_blocks=True) as cf:
def check_eq(file, df_run, df_event, particles, skip_mother=False):
with CorsikaParticleFile(file, parse_blocks=True) as cf:
num = 0
for idx, event in enumerate(cf):
assert df_event.iloc[idx]["total_energy"] == event.header["total_energy"]
for particle in event.particles:
if particle["particle_description"] < 0:
if particle["particle_description"] < 0 and skip_mother:
continue
assert df.iloc[num]["px"] == particle["px"]
assert particles.iloc[num]["px"] == particle["px"]
num += 1


def test_noadd(test_file_path=Path(__file__).parent / "files" / "DAT000000"):

try:
df_run, df_event, particles = panama.read_DAT(
test_file_path, drop_non_particles=True, additional_columns=False
)

check_eq(test_file_path, df_run, df_event, particles)
except ValueError as e:
assert "requires" in str(e)

df_run, df_event, particles = panama.read_DAT(
test_file_path, drop_non_particles=False, additional_columns=True
)

check_eq(test_file_path, df_run, df_event, particles, skip_mother=True)


def test_read_corsia_file(test_file_path=Path(__file__).parent / "files" / "DAT000000"):

df_run, df_event, df = panama.read_DAT(test_file_path, drop_non_particles=False)

check_eq(test_file_path, df_run, df_event, df, skip_mother=True)
try:
check_eq(test_file_path, df_run, df_event, df, skip_mother=False)
assert False
except AssertionError:
pass


def test_cli(tmp_path, test_file_path=Path(__file__).parent / "files" / "DAT000000"):

runner = CliRunner()
Expand All @@ -50,17 +77,10 @@ def test_cli(tmp_path, test_file_path=Path(__file__).parent / "files" / "DAT0000
assert result.exit_code == 0

particles = pd.read_hdf(tmp_path / "output.hdf5", "particles")
event_headers = pd.read_hdf(tmp_path / "output.hdf5", "event_header")
event_header = pd.read_hdf(tmp_path / "output.hdf5", "event_header")
run_header = pd.read_hdf(tmp_path / "output.hdf5", "run_header")

with CorsikaParticleFile(test_file_path, parse_blocks=True) as cf:
num = 0
for idx, event in enumerate(cf):
assert (
event_headers.iloc[idx]["total_energy"] == event.header["total_energy"]
)
for particle in event.particles:
assert particles.iloc[num]["px"] == particle["px"]
num += 1
check_eq(test_file_path, run_header, event_header, particles)


def test_spectral_index(test_file_path=Path(__file__).parent / "files" / "DAT*"):
Expand Down

0 comments on commit 8349658

Please sign in to comment.