Skip to content

Commit

Permalink
Unified permission support for target formatter dump.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Oct 15, 2024
1 parent 5e3875c commit 976d53b
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 83 deletions.
37 changes: 25 additions & 12 deletions law/contrib/awkward/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from law.target.formatter import Formatter, PickleFormatter
from law.target.file import get_path
from law.logger import get_logger
from law.util import no_value


logger = get_logger(__name__)
Expand Down Expand Up @@ -40,18 +41,24 @@ def load(cls, path, *args, **kwargs):

@classmethod
def dump(cls, path, obj, *args, **kwargs):
path = get_path(path)
_path = get_path(path)
perm = kwargs.pop("perm", no_value)

if path.endswith((".parquet", ".parq")):
if _path.endswith((".parquet", ".parq")):
import awkward as ak
return ak.to_parquet(obj, path, *args, **kwargs)
ret = ak.to_parquet(obj, _path, *args, **kwargs)

if path.endswith(".json"):
elif _path.endswith(".json"):
import awkward as ak
return ak.to_json(obj, path, *args, **kwargs)
ret = ak.to_json(obj, _path, *args, **kwargs)

# .pickle, .pkl
return PickleFormatter.dump(path, obj, *args, **kwargs)
else: # .pickle, .pkl
ret = PickleFormatter.dump(_path, obj, *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)

return ret


class DaskAwkwardFormatter(Formatter):
Expand All @@ -78,10 +85,16 @@ def load(cls, path, *args, **kwargs):
def dump(cls, path, obj, *args, **kwargs):
import dask_awkward as dak

path = get_path(path)
_path = get_path(path)
perm = kwargs.pop("perm", no_value)

if path.endswith(".json"):
return dak.to_json(obj, path, *args, **kwargs)
if _path.endswith(".json"):
ret = dak.to_json(obj, _path, *args, **kwargs)

# .parquet, .parq
return dak.to_parquet(obj, path, *args, **kwargs)
else: # .parquet, .parq
ret = dak.to_parquet(obj, _path, *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)

return ret
6 changes: 6 additions & 0 deletions law/contrib/coffea/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from law.target.formatter import Formatter
from law.target.file import get_path
from law.logger import get_logger
from law.util import no_value


logger = get_logger(__name__)
Expand Down Expand Up @@ -43,4 +44,9 @@ def load(cls, path, *args, **kwargs):
def dump(cls, path, out, *args, **kwargs):
from coffea.util import save

perm = kwargs.pop("perm", no_value)

save(out, get_path(path), *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)
11 changes: 10 additions & 1 deletion law/contrib/hdf5/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from law.target.formatter import Formatter
from law.target.file import get_path
from law.util import no_value


class H5pyFormatter(Formatter):
Expand All @@ -27,4 +28,12 @@ def load(cls, path, *args, **kwargs):
@classmethod
def dump(cls, path, *args, **kwargs):
import h5py
return h5py.File(get_path(path), "w", *args, **kwargs)

perm = kwargs.pop("perm", no_value)

ret = h5py.File(get_path(path), "w", *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)

return ret
72 changes: 45 additions & 27 deletions law/contrib/keras/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from law.target.formatter import Formatter
from law.target.file import get_path
from law.logger import get_logger
from law.util import no_value


logger = get_logger(__name__)
Expand All @@ -23,24 +24,6 @@ class KerasModelFormatter(Formatter):
def accepts(cls, path, mode):
return get_path(path).endswith((".hdf5", ".h5", ".json", ".yaml", ".yml"))

@classmethod
def dump(cls, path, model, *args, **kwargs):
path = get_path(path)

# the method for saving the model depends on the file extension
if path.endswith(".json"):
with open(path, "w") as f:
f.write(model.to_json(*args, **kwargs))
return

if path.endswith((".yml", ".yaml")):
with open(path, "w") as f:
f.write(model.to_yaml(*args, **kwargs))
return

# .hdf5, .h5, bundle
return model.save(path, *args, **kwargs)

@classmethod
def load(cls, path, *args, **kwargs):
import keras
Expand All @@ -59,6 +42,29 @@ def load(cls, path, *args, **kwargs):
# .hdf5, .h5, bundle
return keras.models.load_model(path, *args, **kwargs)

@classmethod
def dump(cls, path, model, *args, **kwargs):
_path = get_path(path)
perm = kwargs.pop("perm", no_value)

# the method for saving the model depends on the file extension
ret = None
if _path.endswith(".json"):
with open(_path, "w") as f:
f.write(model.to_json(*args, **kwargs))

elif _path.endswith((".yml", ".yaml")):
with open(_path, "w") as f:
f.write(model.to_yaml(*args, **kwargs))

else: # .hdf5, .h5, bundle
ret = model.save(_path, *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)

return ret


class KerasWeightsFormatter(Formatter):

Expand All @@ -68,14 +74,21 @@ class KerasWeightsFormatter(Formatter):
def accepts(cls, path, mode):
return get_path(path).endswith((".hdf5", ".h5"))

@classmethod
def dump(cls, path, model, *args, **kwargs):
return model.save_weights(get_path(path), *args, **kwargs)

@classmethod
def load(cls, path, model, *args, **kwargs):
return model.load_weights(get_path(path), *args, **kwargs)

@classmethod
def dump(cls, path, model, *args, **kwargs):
perm = kwargs.pop("perm", no_value)

ret = model.save_weights(get_path(path), *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)

return ret


class TFKerasModelFormatter(Formatter):

Expand All @@ -86,18 +99,23 @@ def accepts(cls, path, mode):
return False

@classmethod
def dump(cls, path, model, *args, **kwargs):
def load(cls, path, *args, **kwargs):
# deprecation warning until v0.1
logger.warning_once("law.contrib.keras.TFKerasModelFormatter is deprecated, please use "
"law.contrib.tensorflow.TFKerasModelFormatter (named 'tf_keras_model') instead")

model.save(get_path(path), *args, **kwargs)
import tensorflow as tf
return tf.keras.models.load_model(get_path(path), *args, **kwargs)

@classmethod
def load(cls, path, *args, **kwargs):
def dump(cls, path, model, *args, **kwargs):
# deprecation warning until v0.1
logger.warning_once("law.contrib.keras.TFKerasModelFormatter is deprecated, please use "
"law.contrib.tensorflow.TFKerasModelFormatter (named 'tf_keras_model') instead")

import tensorflow as tf
return tf.keras.models.load_model(get_path(path), *args, **kwargs)
perm = kwargs.pop("perm", no_value)

model.save(get_path(path), *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)
10 changes: 7 additions & 3 deletions law/contrib/matplotlib/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


from law.target.formatter import Formatter
from law.target.file import get_path, FileSystemTarget
from law.target.file import get_path
from law.util import no_value


class MatplotlibFormatter(Formatter):
Expand All @@ -22,6 +23,9 @@ def accepts(cls, path, mode):

@classmethod
def dump(cls, path, fig, *args, **kwargs):
perm = kwargs.pop("perm", no_value)

fig.savefig(get_path(path), *args, **kwargs)
if isinstance(path, FileSystemTarget):
path.chmod(path.fs.default_file_perm)

if perm != no_value:
cls.chmod(path, perm)
13 changes: 9 additions & 4 deletions law/contrib/numpy/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from law.target.formatter import Formatter
from law.target.file import get_path
from law.logger import get_logger
from law.util import no_value


logger = get_logger(__name__)
Expand All @@ -35,11 +36,12 @@ def load(cls, path, *args, **kwargs):
def dump(cls, path, *args, **kwargs):
import numpy as np

path = get_path(path)
_path = get_path(path)
perm = kwargs.pop("perm", no_value)

if path.endswith(".txt"):
if _path.endswith(".txt"):
func = np.savetxt
elif path.endswith(".npz"):
elif _path.endswith(".npz"):
compress_flag = "savez_compressed"
compress = False
if compress_flag in kwargs:
Expand All @@ -52,4 +54,7 @@ def dump(cls, path, *args, **kwargs):
else:
func = np.save

func(path, *args, **kwargs)
func(_path, *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)
10 changes: 9 additions & 1 deletion law/contrib/pyarrow/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from law.target.formatter import Formatter
from law.target.file import get_path
from law.logger import get_logger
from law.util import no_value


logger = get_logger(__name__)
Expand Down Expand Up @@ -48,4 +49,11 @@ def load(cls, path, *args, **kwargs):
def dump(cls, path, obj, *args, **kwargs):
import pyarrow.parquet as pq

return pq.write_table(obj, get_path(path), *args, **kwargs)
perm = kwargs.pop("perm", no_value)

ret = pq.write_table(obj, get_path(path), *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)

return ret
23 changes: 21 additions & 2 deletions law/contrib/root/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from law.target.formatter import Formatter
from law.target.file import get_path
from law.util import no_value

from law.contrib.root.util import import_ROOT

Expand Down Expand Up @@ -95,7 +96,14 @@ def dump(cls, path, arr, *args, **kwargs):
ROOT = import_ROOT() # noqa: F841
import root_numpy

return root_numpy.array2root(arr, get_path(path), *args, **kwargs)
perm = kwargs.pop("perm", None)

ret = root_numpy.array2root(arr, get_path(path), *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)

return ret


class ROOTPandasFormatter(Formatter):
Expand All @@ -119,7 +127,14 @@ def dump(cls, path, df, *args, **kwargs):
# importing root_pandas adds the to_root() method to data frames
import root_pandas # noqa: F401

return df.to_root(get_path(path), *args, **kwargs)
perm = kwargs.pop("perm", None)

ret = df.to_root(get_path(path), *args, **kwargs)

if perm != no_value:
cls.chmod(path, perm)

return ret


class UprootFormatter(Formatter):
Expand Down Expand Up @@ -147,12 +162,16 @@ def dump(cls, path, mode="recreate", **kwargs):
raise ValueError("unknown uproot writing mode: {}".format(mode))
fn = getattr(uproot, mode.lower())

perm = kwargs.pop("perm", None)

# create the file object and yield it
f = fn(get_path(path), **kwargs)
try:
yield f
finally:
try:
f.file.close()
if perm is not None:
cls.chmod(path, perm)
except:
pass
Loading

0 comments on commit 976d53b

Please sign in to comment.