Skip to content

Commit

Permalink
Merge pull request nugraph#3 from wkliao/scalars
Browse files Browse the repository at this point in the history
adjust for scalar variables
  • Loading branch information
vhewes authored Oct 27, 2021
2 parents 3057b47 + 454aaed commit 90a67ab
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
5 changes: 4 additions & 1 deletion numl/core/out.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def __init__(self, fname):
def save(self, obj, name):
for key, val in obj:
# set chunk sizes to val shape, so there is only one chunk per dataset
self.f.create_dataset(f"/{name}/{key}", data=val, chunks=val.shape, compression="gzip")
if (isinstance(val, torch.Tensor)) :
self.f.create_dataset(f"/{name}/{key}", data=val, chunks=val.shape, compression="gzip")
else:
self.f.create_dataset(f"/{name}/{key}", data=val)
# below is to disable data compression (and chunking)
# self.f.create_dataset(f"/{name}/{key}", data=val)

Expand Down
12 changes: 9 additions & 3 deletions numl/process/hitgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
import sys

edep1_t = 0.0
edep2_t = 0.0
hit_merge_t = 0.0
torch_t = 0.0
plane_t = 0.0
Expand Down Expand Up @@ -312,8 +314,12 @@ def process_file(out, fname, g=process_event, l=standard.semantic_label,
if profiling:
grp_size = 0
for key, val in data:
# calculate size in bytes of val, a pytorch tensor
grp_size += val.element_size() * val.nelement()
# calculate size in bytes of val
if (isinstance(val, torch.Tensor)):
# val is a pytorch tensor
grp_size += val.element_size() * val.nelement()
else:
grp_size += sys.getsizeof(val)
num_grps += 1
grp_size_sum += grp_size
if grp_size > grp_size_max : grp_size_max = grp_size
Expand Down Expand Up @@ -406,4 +412,4 @@ def process_file(out, fname, g=process_event, l=standard.semantic_label,
print("graph creation time MAX=%8.2f MIN=%8.2f" % (max_total_t[2], min_total_t[2]))
print("write to files time MAX=%8.2f MIN=%8.2f" % (max_total_t[3], min_total_t[3]))
print("total time MAX=%8.2f MIN=%8.2f" % (max_total_t[4], min_total_t[4]))
print("(MAX and MIN timings are among %d processes)" % nprocs)
print("(MAX and MIN timings are among %d processes)" % nprocs)

0 comments on commit 90a67ab

Please sign in to comment.