Skip to content

Commit

Permalink
Merge pull request #986 from ToFuProject/Issue983_npint
Browse files Browse the repository at this point in the history
Get rid of `np.int`
  • Loading branch information
Didou09 authored Nov 13, 2024
2 parents e31b3af + de9d317 commit 3a6820d
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 80 deletions.
133 changes: 108 additions & 25 deletions tofu/imas2tofu/_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@
raise Exception('imas not available')


# Useful scalar types
_NINT = (np.int32, np.int64)
_INT = (int,) + _NINT
_NFLOAT = (np.float32, np.float64)
_FLOAT = (float,) + _NFLOAT
_NUMB = _INT + _FLOAT


_DSHORT = _defimas2tofu._dshort
_DCOMP = _defimas2tofu._dcomp
_DDUNITS = imas.dd_units.DataDictionaryUnits()
Expand Down Expand Up @@ -248,12 +256,20 @@ def fsig(obj, indt=None, indch=None, stack=None, dcond=dcond):
sig[jj] = sig[jj][ind[0]]

# Conditions for stacking / sqeezing sig
lc = [(stack and nsig > 1 and isinstance(sig[0], np.ndarray)
and all([ss.shape == sig[0].shape for ss in sig[1:]])),
(stack and nsig > 1
and type(sig[0]) in [int, float, np.int_, np.float64, str]),
(stack and nsig == 1
and type(sig) in [np.ndarray, list, tuple])]
lc = [
(
stack and nsig > 1 and isinstance(sig[0], np.ndarray)
and all([ss.shape == sig[0].shape for ss in sig[1:]])
),
(
stack and nsig > 1
and isinstance(sig[0], _NUMB + (str,))
),
(
stack and nsig == 1
and type(sig) in [np.ndarray, list, tuple]
),
]

if lc[0]:
sig = np.atleast_1d(np.squeeze(np.stack(sig)))
Expand Down Expand Up @@ -447,30 +463,95 @@ def _checkformat_getdata_occ(occ, ids, dids=None):


def _checkformat_getdata_indch(indch, nch):
msg = ("Arg indch must be a either:\n"
+ " - None: all channels used\n"
+ " - int: channel to use (index)\n"
+ " - array of int: channels to use (indices)\n"
+ " - array of bool: channels to use (indices)\n")
lc = [indch is None,
isinstance(indch, int),
hasattr(indch, '__iter__') and not isinstance(indch, str)]
if not any(lc):
""" Check index of channels, returns array of int indices
Parameters
----------
indch : None / int / aiterable of int or bool
Input index of channels
nch : int
Max number of channels
Raises
------
Exception
DESCRIPTION.
Returns
-------
indch : np.ndarray of int
Output indices
"""

# -----------------------------
# Initialize error msg
# -----------------------------

msg = (
"Arg indch must be a either:\n"
"\t- None: all channels used\n"
"\t- int: channel to use (index)\n"
"\t- array of int: channels to use (indices)\n"
"\t- array of bool: channels to use (indices)\n"
)

# -----------------------------
# List of acceptable conditions
# -----------------------------

lc0 = [
indch is None,
isinstance(indch, _INT),
hasattr(indch, '__iter__') and not isinstance(indch, str),
]

if not any(lc0):
raise Exception(msg)
if lc[0]:

# ------------------------
# None
# ------------------------

if lc0[0]:
# defaul to all indices
indch = np.arange(0, nch)
elif lc[1] or lc[2]:

# ------------------------
# integer
# ------------------------

elif lc0[1]:
# make numpy int array
indch = np.r_[indch].ravel()
lc = [indch.dtype == np.int, indch.dtype == np.bool]
if not any(lc):

# ------------------------
# iterable: int or bool
# ------------------------

elif lc0[2]:
# make numpy array
indch = np.r_[indch].ravel()

# get dtype
lc1 = ['int' in indch.dtype.name, 'bool' in indch.dtype.name]

if not any(lc1):
raise Exception(msg)
if lc[1]:

if lc1[1]:
# convert from bool to int
indch = np.nonzero(indch)[0]

# safety check
if not np.all((indch >= 0) & (indch < nch)):
msg = ("Some channel indices are out of scope!\n"
+ "\t- nch: {}\n".format(nch)
+ "\t- indch: {}".format(indch))
msg = (
"Some channel indices are out of scope!\n"
f"\t- nch: {nch}\n"
f"\t- indch: {indch}"
)
raise Exception(msg)

return indch


Expand Down Expand Up @@ -513,8 +594,10 @@ def _check_data(data, pos=None, nan=None, isclose=None, empty=None):
# All values larger than 1e30 are default imas values => nan
if nan is True:
for ii in range(0, len(data)):
c0 = (isinstance(data[ii], np.ndarray)
and data[ii].dtype in [np.float64, np.int_])
c0 = (
isinstance(data[ii], np.ndarray)
and data.dtype in _NUMB
)
if c0 is True:
# Make sure to test only non-nan to avoid warning
ind = (~np.isnan(data[ii])).nonzero()
Expand Down
54 changes: 40 additions & 14 deletions tofu/imas2tofu/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""


# Built-ins
import sys
import os
Expand All @@ -18,11 +19,13 @@
import warnings
import traceback


# Standard
import numpy as np
import matplotlib as mpl
import datetime as dtm


# tofu
pfe = os.path.join(os.path.expanduser('~'), '.tofu', '_imas2tofu_def.py')
if os.path.isfile(pfe):
Expand Down Expand Up @@ -50,13 +53,15 @@
from . import _comp_toobjects as _comp_toobjects
from . import _comp_mesh as _comp_mesh


# imas
try:
import imas
from imas import imasdef
except Exception as err:
raise Exception('imas not available')


__all__ = [
'check_units_IMASvsDSHORT',
'MultiIDSLoader',
Expand All @@ -70,6 +75,14 @@
_ROOT = _ROOT[:_ROOT.index('tofu')+len('tofu')]


# Useful scalar types
_NINT = (np.int32, np.int64)
_INT = (int,) + _NINT
_NFLOAT = (np.float32, np.float64)
_FLOAT = (float,) + _NFLOAT
_NUMB = _INT + _FLOAT


#############################################################
# Preliminary units check
#############################################################
Expand Down Expand Up @@ -360,7 +373,7 @@ def _get_diddids(cls, dids, defidd=None):
v = dids[k]

# Check / format occ and deduce nocc
assert type(dids[k]['occ']) in [int, list]
assert isinstance(dids[k]['occ'], _INT + (list,)), type(dids[k]['occ'])
dids[k]['occ'] = np.r_[dids[k]['occ']].astype(int)
dids[k]['nocc'] = dids[k]['occ'].size
v = dids[k]
Expand Down Expand Up @@ -935,7 +948,9 @@ def _checkformat_idd(
defidd = cls._defidd

if lc[0]:
assert type(shot) in [int,np.int_]
# check type is int or numpy int
assert isinstance(shot, _INT), type(shot)

params = dict(
shot=int(shot), run=run, refshot=refshot, refrun=refrun,
user=user, database=database, version=version,
Expand Down Expand Up @@ -1201,8 +1216,8 @@ def _checkformat_ids(

if occ is None:
occ = 0
lc = [type(occ) in [int, np.int], hasattr(occ, '__iter__')]
assert any(lc)
lc = [isinstance(occ, _INT), hasattr(occ, '__iter__')]
assert any(lc), occ

if lc[0]:
occ = [np.r_[occ].astype(int) for _ in range(nids)]
Expand Down Expand Up @@ -1482,13 +1497,12 @@ def _checkformat_getdata_indt(self, indt):
msg += " - None: all channels used\n"
msg += " - int: times to use (index)\n"
msg += " - array of int: times to use (indices)"
lc = [type(indt) is None, type(indt) is int, hasattr(indt,'__iter__')]
lc = [type(indt) is None, isinstance(indt, _INT), hasattr(indt,'__iter__')]
if not any(lc):
raise Exception(msg)
if lc[1] or lc[2]:
indt = np.r_[indt].rave()
lc = [indt.dtype == np.int]
if not any(lc):
if indt.dtype not in _NINT:
raise Exception(msg)
assert np.all(indt>=0)
return indt
Expand Down Expand Up @@ -1671,18 +1685,26 @@ def get_lidsidd_shotExp(self, lidsok,
dids=self._dids, didd=self._didd)

def _get_t0(self, t0=None, ind=None):

if ind is None:
ind = False
assert ind is False or isinstance(ind, int)
assert ind is False or isinstance(ind, _INT)

if t0 is None:
t0 = _defimas2tofu._T0

elif t0 != False:
if type(t0) in [int, float, np.int_, np.float64]:

if isinstance(t0, _NUMB):
t0 = float(t0)

elif type(t0) is str:
t0 = t0.strip()
c0 = (len(t0.split('.')) <= 2
and all([ss.isdecimal() for ss in t0.split('.')]))
c0 = (
len(t0.split('.')) <= 2
and all([ss.isdecimal() for ss in t0.split('.')])
)

if 'pulse_schedule' in self._dids.keys():
events = self.get_data(
dsig={'pulse_schedule': ['events_names',
Expand Down Expand Up @@ -1712,6 +1734,7 @@ def _get_t0(self, t0=None, ind=None):
t0 = False
else:
t0 = False

if t0 is False:
msg = "t0 set to False because could not be interpreted !"
warnings.warn(msg)
Expand Down Expand Up @@ -2316,21 +2339,24 @@ def get_tlim(
"""
names, times = None, None
c0 = (isinstance(tlim, list)
and all([type(tt) in [float, int, np.float64, np.int_]
for tt in tlim]))
c0 = (
isinstance(tlim, list)
and all([isinstance(tt, _NUMB) for tt in tlim])
)
if not c0 and 'pulse_schedule' in self._dids.keys():
try:
names, times = self.get_events(verb=False, returnas=tuple)
except Exception as err:
msg = (str(err)
+ "\nEvents not loaded from ids pulse_schedule!")
warnings.warn(msg)

if 'pulse_schedule' in self._dids.keys():
idd = self._dids['pulse_schedule']['idd']
Exp = self._didd[idd]['params']['database']
else:
Exp = None

return _comp_toobjects.data_checkformat_tlim(t, tlim=tlim,
names=names, times=times,
indevent=indevent,
Expand Down
8 changes: 5 additions & 3 deletions tofu/imas2tofu/_mat2ids2calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
+ " => Maybe corrupted data ?\n")


_LTYPES = (int, float, np.integer, np.float64)


# ####################################################
# Utility
# ####################################################
Expand All @@ -34,8 +37,7 @@ def _get_indtlim(t, tlim=None, shot=None, out=bool):
tlim = [-np.inf, np.inf]
else:
assert len(tlim) == 2
ls = [int, float, np.int64, np.float64] # , str
assert all([tt is None or type(tt) in ls for tt in tlim])
assert all([tt is None or isinstance(tt, _LTYPES) for tt in tlim])
tlim = list(tlim)
for (ii, sgn) in [(0, -1.), (1, 1.)]:
if tlim[ii] is None:
Expand Down Expand Up @@ -203,4 +205,4 @@ def get_data_from_matids(input_pfe=None, tlim=None,
if 'brem' in return_fields:
dout['brem'] = _physics.compute_bremzeff(Te=Te, ne=ne,
zeff=zeff, lamb=lamb)[0]
return dout
return dout
4 changes: 2 additions & 2 deletions tofu/nist2tofu/_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
}


_LTYPES = [int, float, np.int_, np.float64]
_LTYPES = (int, float, np.integer, np.float64)


_DCERTIFICATES_BUNDLE = {
Expand Down Expand Up @@ -221,7 +221,7 @@ def _get_totalurl(
if v0 is None:
dlamb[k0] = ''
else:
c0 = type(v0) in _LTYPES
c0 = isinstance(v0, _LTYPES)
if not c0:
msg = (
"Arg {} must be a float!\n".format(k0)
Expand Down
Loading

0 comments on commit 3a6820d

Please sign in to comment.