From 7a3d8d686b08493cbb885bcb822ea2dc77b73b0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20B=C3=A9drine?= Date: Fri, 7 Jul 2023 13:44:05 +0200 Subject: [PATCH] Updated doc and fixed missing cases in open_fits --- vip_hci/fits/fits.py | 48 ++++++++++++++++++++----------------- vip_hci/objects/postproc.py | 1 - 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/vip_hci/fits/fits.py b/vip_hci/fits/fits.py index c9853a7b..3d027ac7 100644 --- a/vip_hci/fits/fits.py +++ b/vip_hci/fits/fits.py @@ -33,7 +33,8 @@ def open_fits( fitsfilename : string or pathlib.Path Name of the fits file or ``pathlib.Path`` object n : int, optional - It chooses which HDU to open. Default is the first one. + It chooses which HDU to open. Default is the first one. If n is equal + to -2, opens and returns all extensions. header : bool, optional Whether to return the header along with the data or not. precision : numpy dtype, optional @@ -54,12 +55,15 @@ def open_fits( Returns ------- - hdulist : hdulist - [memmap=True] FITS file ``n`` hdulist. - data : numpy ndarray - [memmap=False] Array containing the frames of the fits-cube. - header : dict + hdulist : HDU or HDUList + [memmap=True] FITS file ``n`` hdulist. If n equals -2, returns the whole + hdulist. + data : numpy ndarray or list of numpy ndarrays + [memmap=False] Array containing the frames of the fits-cube. If n + equals -2, returns a list of all arrays. + header : dict or list of dict [memmap=False, header=True] Dictionary containing the fits header. + If n equals -2, returns a list of all dictionnaries. """ fitsfilename = str(fitsfilename) @@ -74,11 +78,13 @@ def open_fits( if n == ALL_FITS: data_list = [] header_list = [] + if return_memmap: + return hdulist + for index, element in enumerate(hdulist): data, header = _return_data_fits( hdulist=hdulist, index=index, - return_memmap=return_memmap, precision=precision, verbose=verbose, ) @@ -96,10 +102,12 @@ def open_fits( return data_list # Opening only a specified extension else: + if return_memmap: + return hdulist[n] + data, header = _return_data_fits( hdulist=hdulist, index=n, - return_memmap=return_memmap, precision=precision, verbose=verbose, ) @@ -113,7 +121,6 @@ def open_fits( def _return_data_fits( hdulist: HDUList, index: int, - return_memmap: bool, precision=np.float32, verbose: bool = True, ): @@ -127,19 +134,16 @@ def _return_data_fits( index : int The wanted index to extract. """ - if return_memmap: - return hdulist[index] - else: - data = hdulist[index].data - data = np.array(data, dtype=precision) - header = hdulist[index].header - - if verbose: - print( - f"Fits HDU-{index} data successfully loaded, header available. " - f"Data shape: {data.shape}" - ) - return data, header + data = hdulist[index].data + data = np.array(data, dtype=precision) + header = hdulist[index].header + + if verbose: + print( + f"Fits HDU-{index} data successfully loaded, header available. " + f"Data shape: {data.shape}" + ) + return data, header def byteswap_array(array): diff --git a/vip_hci/objects/postproc.py b/vip_hci/objects/postproc.py index 27513188..17e8c855 100644 --- a/vip_hci/objects/postproc.py +++ b/vip_hci/objects/postproc.py @@ -221,7 +221,6 @@ def results_to_fits(self, filepath: str) -> None: " a session with the function `register_session`." ) - # Note: unfinished def fits_to_results(self, filepath: str, session_id: int = ALL_FITS) -> None: """ Load all configurations from a fits file.