Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cull duplicate dataURIs for MAST in download_products #2497

Merged
merged 7 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ linelists.cdms
- Fix issues with the line name parser and the line data parser; the original
implementation was incomplete and upstream was not fully documented. [#2385, #2411]

mast
^^^^

- Cull duplicate downloads for the same dataURI in ``Observations.download_products()``
and duplicate URIs in ``Observations.get_cloud_uris``. [#2497]

oac
^^^

Expand Down
12 changes: 10 additions & 2 deletions astroquery/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

__all__ = ['TimeoutError', 'InvalidQueryError', 'RemoteServiceError',
'TableParseError', 'LoginError', 'ResolverError',
'NoResultsWarning', 'LargeQueryWarning', 'InputWarning',
'AuthenticationWarning', 'MaxResultsWarning', 'CorruptDataWarning']
'NoResultsWarning', 'DuplicateResultsWarning', 'LargeQueryWarning',
'InputWarning', 'AuthenticationWarning', 'MaxResultsWarning',
'CorruptDataWarning']


class TimeoutError(Exception):
Expand Down Expand Up @@ -67,6 +68,13 @@ class NoResultsWarning(AstropyWarning):
pass


class DuplicateResultsWarning(AstropyWarning):
"""
Astroquery warning class to be issued when a query returns no result.
"""
pass


class LargeQueryWarning(AstropyWarning):
"""
Astroquery warning class to be issued when a query is larger than
Expand Down
34 changes: 32 additions & 2 deletions astroquery/mast/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import astropy.units as u
import astropy.coordinates as coord

from astropy.table import Table, Row, vstack, MaskedColumn
from astropy.table import Table, Row, unique, vstack, MaskedColumn
from astroquery import log

from astropy.utils import deprecated
Expand All @@ -31,7 +31,7 @@
from ..utils import commons, async_to_sync
from ..utils.class_or_instance import class_or_instance
from ..exceptions import (TimeoutError, InvalidQueryError, RemoteServiceError,
ResolverError, MaxResultsWarning,
ResolverError, MaxResultsWarning, DuplicateResultsWarning,
NoResultsWarning, InputWarning, AuthenticationWarning)

from . import conf, utils
Expand Down Expand Up @@ -714,6 +714,9 @@ def download_products(self, products, *, download_dir=None,

products = vstack(product_lists)

# Remove duplicate products
products = self._remove_duplicate_products(products)

# apply filters
products = self.filter_products(products, mrp_only=mrp_only, **filters)

Expand Down Expand Up @@ -765,6 +768,9 @@ def get_cloud_uris(self, data_products, *, include_bucket=True, full_url=False):
raise RemoteServiceError('Please enable anonymous cloud access by calling `enable_cloud_dataset` method. '
'See MAST Labs documentation for an example: https://mast-labs.stsci.io/#example-data-access-with-astroquery-observations')

# Remove duplicate products
data_products = self._remove_duplicate_products(data_products)

return self._cloud_connection.get_cloud_uri_list(data_products, include_bucket, full_url)

def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
Expand Down Expand Up @@ -800,6 +806,30 @@ def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
# Query for product URIs
return self._cloud_connection.get_cloud_uri(data_product, include_bucket, full_url)

def _remove_duplicate_products(self, data_products):
"""
Removes duplicate data products that have the same dataURI.

Parameters
----------
data_products : `~astropy.table.Table`
Table containing products to be checked for duplicates.

Returns
-------
unique_products : `~astropy.table.Table`
Table containing products with unique dataURIs.

"""
number = len(data_products)
unique_products = unique(data_products, keys="dataURI")
number_unique = len(unique_products)
if number_unique < number:
warnings.warn(f"{number - number_unique} of {number} products were duplicates."
f"Only downloading {number_unique} unique product(s).", DuplicateResultsWarning)

return unique_products


@async_to_sync
class MastClass(MastQueryWithLogin):
Expand Down
36 changes: 33 additions & 3 deletions astroquery/mast/tests/test_mast_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from astroquery import mast

from ..utils import ResolverError
from ...exceptions import InvalidQueryError, MaxResultsWarning, NoResultsWarning, RemoteServiceError
from ...exceptions import (InvalidQueryError, MaxResultsWarning, NoResultsWarning,
DuplicateResultsWarning, RemoteServiceError)


OBSID = '1647157'
Expand Down Expand Up @@ -274,7 +275,7 @@ def test_observations_download_products(self, tmpdir):
assert os.path.isfile(row['Local Path'])

# just get the curl script
result = mast.Observations.download_products(test_obs[0]["obsid"],
result = mast.Observations.download_products(test_obs_id[0]["obsid"],
download_dir=str(tmpdir),
curl_flag=True,
productType=["SCIENCE"],
Expand All @@ -283,12 +284,41 @@ def test_observations_download_products(self, tmpdir):
assert os.path.isfile(result['Local Path'][0])

# check for row input
result1 = mast.Observations.get_product_list(test_obs[0]["obsid"])
result1 = mast.Observations.get_product_list(test_obs_id[0]["obsid"])
result2 = mast.Observations.download_products(result1[0])
assert isinstance(result2, Table)
assert os.path.isfile(result2['Local Path'][0])
assert len(result2) == 1

def test_observations_download_products_no_duplicates(tmpdir):

# Pull products for a JWST NIRSpec MSA observation with 6 known
# duplicates of the MSA configuration file, propID=2736
products = mast.Observations.get_product_list("87602009")

# Filter out everything but the MSA config file
mask = np.char.find(products["dataURI"], "_msa.fits") != -1
products = products[mask]

assert len(products) == 6

# Download the product
with pytest.warns(DuplicateResultsWarning):
manifest = mast.Observations.download_products(products,
download_dir=str(tmpdir))

# Check that it downloads the MSA config file only once
assert len(manifest) == 1

# enable access to public AWS S3 bucket
mast.Observations.enable_cloud_dataset()

# Check duplicate cloud URIs as well
with pytest.warns(DuplicateResultsWarning):
uris = mast.Observations.get_cloud_uris(products)

assert len(uris) == 1

def test_observations_download_file(self, tmpdir):

# enabling cloud connection
Expand Down