Skip to content

Commit 067c6d2

Browse files
authored
Merge pull request #3454 from snbianco/obs-product-batch
Implement batching in `Observations.get_product_list` and add `batch_size` parameter
2 parents dd9f00d + 8068e86 commit 067c6d2

File tree

8 files changed

+126
-63
lines changed

8 files changed

+126
-63
lines changed

CHANGES.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ mast
5252

5353
- Raise informative error if ``MastMissions`` query radius is too large. [#3447]
5454

55+
- Add ``batch_size`` parameter to ``MastMissions.get_product_list``, ``Observations.get_product_list``,
56+
and ``utils.resolve_object`` to allow controlling the number of items sent in each batch request to the server.
57+
This can help avoid timeouts or connection errors for large requests. [#3454]
58+
5559
jplspec
5660
^^^^^^^
5761

astroquery/mast/missions.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,14 @@ def _extract_products(self, response):
9797
list
9898
A list of products extracted from the response.
9999
"""
100-
def normalize_products(products):
101-
"""
102-
Normalize the products list to ensure it is flat and not nested.
103-
"""
100+
combined = []
101+
for resp in response:
102+
products = resp.json().get('products', [])
103+
# Flatten if nested
104104
if products and isinstance(products[0], list):
105-
return products[0]
106-
return products
107-
108-
if isinstance(response, list): # multiple async responses from batching
109-
combined = []
110-
for resp in response:
111-
products = normalize_products(resp.json().get('products', []))
112-
combined.extend(products)
113-
return combined
114-
else: # single response
115-
return normalize_products(response.json().get('products', []))
105+
products = products[0]
106+
combined.extend(products)
107+
return combined
116108

117109
def _parse_result(self, response, *, verbose=False): # Used by the async_to_sync decorator functionality
118110
"""
@@ -417,7 +409,7 @@ def query_object_async(self, objectname, *, radius=3*u.arcmin, limit=5000, offse
417409
select_cols=select_cols, **criteria)
418410

419411
@class_or_instance
420-
def get_product_list_async(self, datasets):
412+
def get_product_list_async(self, datasets, *, batch_size=1000):
421413
"""
422414
Given a dataset ID or list of dataset IDs, returns a list of associated data products.
423415
@@ -428,6 +420,9 @@ def get_product_list_async(self, datasets):
428420
datasets : str, list, `~astropy.table.Row`, `~astropy.table.Column`, `~astropy.table.Table`
429421
Row/Table of MastMissions query results (e.g. output from `query_object`)
430422
or single/list of dataset ID(s).
423+
batch_size : int, optional
424+
Default 1000. Number of dataset IDs to include in each batch request to the server.
425+
If you experience timeouts or connection errors, consider lowering this value.
431426
432427
Returns
433428
-------
@@ -439,8 +434,8 @@ def get_product_list_async(self, datasets):
439434
if isinstance(datasets, Table) or isinstance(datasets, Row):
440435
dataset_kwd = self.get_dataset_kwd()
441436
if not dataset_kwd:
442-
log.warning('Please input dataset IDs as a string, list of strings, or `~astropy.table.Column`.')
443-
return None
437+
raise InvalidQueryError(f'Dataset keyword not found for mission "{self.mission}". Please input '
438+
'dataset IDs as a string, list of strings, or `~astropy.table.Column`.')
444439

445440
# Extract dataset IDs based on input type and mission
446441
if isinstance(datasets, Table):
@@ -466,17 +461,17 @@ def get_product_list_async(self, datasets):
466461
results = utils._batched_request(
467462
datasets,
468463
params={},
469-
max_batch=1000,
464+
max_batch=batch_size,
470465
param_key="dataset_ids",
471466
request_func=lambda p: self._service_api_connection.missions_request_async(self.service, p),
472467
extract_func=lambda r: [r], # missions_request_async already returns one result
473468
desc=f"Fetching products for {len(datasets)} unique datasets"
474469
)
475470

476-
# Return a list of responses only if multiple requests were made
477-
return results[0] if len(results) == 1 else results
471+
# Return a list of responses
472+
return results
478473

479-
def get_unique_product_list(self, datasets):
474+
def get_unique_product_list(self, datasets, *, batch_size=1000):
480475
"""
481476
Given a dataset ID or list of dataset IDs, returns a list of associated data products with unique
482477
filenames.
@@ -486,13 +481,16 @@ def get_unique_product_list(self, datasets):
486481
datasets : str, list, `~astropy.table.Row`, `~astropy.table.Column`, `~astropy.table.Table`
487482
Row/Table of MastMissions query results (e.g. output from `query_object`)
488483
or single/list of dataset ID(s).
484+
batch_size : int, optional
485+
Default 1000. Number of dataset IDs to include in each batch request to the server.
486+
If you experience timeouts or connection errors, consider lowering this value.
489487
490488
Returns
491489
-------
492490
unique_products : `~astropy.table.Table`
493491
Table containing products with unique URIs.
494492
"""
495-
products = self.get_product_list(datasets)
493+
products = self.get_product_list(datasets, batch_size=batch_size)
496494
unique_products = utils.remove_duplicate_products(products, 'filename')
497495
if len(unique_products) < len(products):
498496
log.info("To return all products, use `MastMissions.get_product_list`")

astroquery/mast/observations.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def _filter_ffi_observations(self, observations):
504504
return obs_table[mask]
505505

506506
@class_or_instance
507-
def get_product_list_async(self, observations):
507+
def get_product_list_async(self, observations, *, batch_size=500):
508508
"""
509509
Given a "Product Group Id" (column name obsid) returns a list of associated data products.
510510
Note that obsid is NOT the same as obs_id, and inputting obs_id values will result in
@@ -518,31 +518,50 @@ def get_product_list_async(self, observations):
518518
Row/Table of MAST query results (e.g. output from `query_object`)
519519
or single/list of MAST Product Group Id(s) (obsid).
520520
See description `here <https://masttest.stsci.edu/api/v0/_c_a_o_mfields.html>`__.
521+
batch_size : int, optional
522+
Default 500. Number of obsids to include in each batch request to the server.
523+
If you experience timeouts or connection errors, consider lowering this value.
521524
522525
Returns
523526
-------
524527
response : list of `~requests.Response`
528+
A list of asynchronous response objects, one per batch request.
525529
"""
526-
527-
# getting the obsid list
530+
# Getting the obsids as a list
528531
if np.isscalar(observations):
529-
observations = np.array([observations])
530-
if isinstance(observations, Table) or isinstance(observations, Row):
532+
observations = [observations]
533+
elif isinstance(observations, (Row, Table)):
531534
# Filter out TESS FFIs and TICA FFIs
532535
# Can only perform filtering on Row or Table because of access to `target_name` field
533536
observations = self._filter_ffi_observations(observations)
534-
observations = observations['obsid']
535-
if isinstance(observations, list):
536-
observations = np.array(observations)
537-
538-
observations = observations[observations != ""]
539-
if observations.size == 0:
540-
raise InvalidQueryError("Observation list is empty, no associated products.")
541-
542-
service = self._caom_products
543-
params = {'obsid': ','.join(observations)}
544-
545-
return self._portal_api_connection.service_request_async(service, params)
537+
observations = observations['obsid'].tolist()
538+
539+
# Clean and validate
540+
observations = [str(obs).strip() for obs in observations]
541+
observations = [obs for obs in observations if obs]
542+
if not observations:
543+
raise InvalidQueryError('Observation list is empty, no associated products.')
544+
545+
# Define a helper to join obsids for each batch request
546+
def _request_joined_obsid(params):
547+
"""Join batched obsid list into comma-separated string and send async request."""
548+
pp = dict(params)
549+
vals = pp.get('obsid', [])
550+
pp['obsid'] = ','.join(map(str, vals))
551+
return self._portal_api_connection.service_request_async(self._caom_products, pp)[0]
552+
553+
# Perform batched requests
554+
results = utils._batched_request(
555+
items=observations,
556+
params={},
557+
max_batch=batch_size,
558+
param_key='obsid',
559+
request_func=_request_joined_obsid,
560+
extract_func=lambda r: [r],
561+
desc=f'Fetching products for {len(observations)} unique observations'
562+
)
563+
564+
return results
546565

547566
def filter_products(self, products, *, mrp_only=False, extension=None, **filters):
548567
"""
@@ -1029,7 +1048,7 @@ def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
10291048
# Query for product URIs
10301049
return self._cloud_connection.get_cloud_uri(data_product, include_bucket, full_url)
10311050

1032-
def get_unique_product_list(self, observations):
1051+
def get_unique_product_list(self, observations, *, batch_size=500):
10331052
"""
10341053
Given a "Product Group Id" (column name obsid), returns a list of associated data products with
10351054
unique dataURIs. Note that obsid is NOT the same as obs_id, and inputting obs_id values will result in
@@ -1041,13 +1060,16 @@ def get_unique_product_list(self, observations):
10411060
Row/Table of MAST query results (e.g. output from `query_object`)
10421061
or single/list of MAST Product Group Id(s) (obsid).
10431062
See description `here <https://masttest.stsci.edu/api/v0/_c_a_o_mfields.html>`__.
1063+
batch_size : int, optional
1064+
Default 500. Number of obsids to include in each batch request to the server.
1065+
If you experience timeouts or connection errors, consider lowering this value.
10441066
10451067
Returns
10461068
-------
10471069
unique_products : `~astropy.table.Table`
10481070
Table containing products with unique dataURIs.
10491071
"""
1050-
products = self.get_product_list(observations)
1072+
products = self.get_product_list(observations, batch_size=batch_size)
10511073
unique_products = utils.remove_duplicate_products(products, 'dataURI')
10521074
if len(unique_products) < len(products):
10531075
log.info("To return all products, use `Observations.get_product_list`")

astroquery/mast/tests/test_mast.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,21 +305,21 @@ def test_missions_query_criteria(patch_post):
305305
def test_missions_get_product_list_async(patch_post):
306306
# String input
307307
result = mast.MastMissions.get_product_list_async('Z14Z0104T')
308-
assert isinstance(result, MockResponse)
308+
assert isinstance(result, list)
309309

310310
# List input
311311
in_datasets = ['Z14Z0104T', 'Z14Z0102T']
312312
result = mast.MastMissions.get_product_list_async(in_datasets)
313-
assert isinstance(result, MockResponse)
313+
assert isinstance(result, list)
314314

315315
# Table input
316316
datasets = mast.MastMissions.query_object("M101", radius=".002 deg")
317317
result = mast.MastMissions.get_product_list_async(datasets[:3])
318-
assert isinstance(result, MockResponse)
318+
assert isinstance(result, list)
319319

320320
# Row input
321321
result = mast.MastMissions.get_product_list_async(datasets[0])
322-
assert isinstance(result, MockResponse)
322+
assert isinstance(result, list)
323323

324324
# Unsupported data type for datasets
325325
with pytest.raises(TypeError) as err_type:
@@ -331,6 +331,11 @@ def test_missions_get_product_list_async(patch_post):
331331
mast.MastMissions.get_product_list_async([' '])
332332
assert 'Dataset list is empty' in str(err_empty.value)
333333

334+
# No dataset keyword
335+
with pytest.raises(InvalidQueryError, match='Dataset keyword not found for mission "invalid"'):
336+
missions = mast.MastMissions(mission='invalid')
337+
missions.get_product_list_async(Table({'a': [1, 2, 3]}))
338+
334339

335340
def test_missions_get_product_list(patch_post):
336341
# String input
@@ -825,6 +830,10 @@ def test_observations_get_product_list(patch_post):
825830
result = mast.Observations.get_product_list(in_obsids)
826831
assert isinstance(result, Table)
827832

833+
# Error if no valid obsids are found
834+
with pytest.raises(InvalidQueryError, match='Observation list is empty'):
835+
mast.Observations.get_product_list([' '])
836+
828837

829838
def test_observations_filter_products(patch_post):
830839
products = mast.Observations.get_product_list('2003738726')

astroquery/mast/tests/test_mast_remote.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -199,19 +199,25 @@ def test_missions_get_product_list_async(self):
199199

200200
# Table as input
201201
responses = MastMissions.get_product_list_async(datasets[:3])
202-
assert isinstance(responses, Response)
202+
assert isinstance(responses, list)
203203

204204
# Row as input
205205
responses = MastMissions.get_product_list_async(datasets[0])
206-
assert isinstance(responses, Response)
206+
assert isinstance(responses, list)
207207

208208
# String as input
209209
responses = MastMissions.get_product_list_async(datasets[0]['sci_data_set_name'])
210-
assert isinstance(responses, Response)
210+
assert isinstance(responses, list)
211211

212212
# Column as input
213213
responses = MastMissions.get_product_list_async(datasets[:3]['sci_data_set_name'])
214-
assert isinstance(responses, Response)
214+
assert isinstance(responses, list)
215+
216+
# Batching
217+
responses = MastMissions.get_product_list_async(datasets[:4], batch_size=2)
218+
assert isinstance(responses, list)
219+
assert len(responses) == 2
220+
assert isinstance(responses[0], Response)
215221

216222
# Unsupported data type for datasets
217223
with pytest.raises(TypeError) as err_type:
@@ -248,14 +254,13 @@ def test_missions_get_product_list(self, capsys):
248254
assert isinstance(result, Table)
249255
assert (result['dataset'] == 'IBKH03020').all()
250256

251-
# Test batching by creating a list of 1001 different strings
252-
# This won't return any results, but will test the batching
253-
dataset_list = [f'{i}' for i in range(1001)]
254-
result = MastMissions.get_product_list(dataset_list)
257+
# Test batching
258+
result_batch = MastMissions.get_product_list(datasets[:2], batch_size=1)
255259
out, _ = capsys.readouterr()
256-
assert isinstance(result, Table)
257-
assert len(result) == 0
258-
assert 'Fetching products for 1001 unique datasets in 2 batches' in out
260+
assert isinstance(result_batch, Table)
261+
assert len(result_batch) == len(result_table)
262+
assert set(result_batch['filename']) == set(result_table['filename'])
263+
assert 'Fetching products for 2 unique datasets in 2 batches' in out
259264

260265
def test_missions_get_unique_product_list(self, caplog):
261266
# Check that no rows are filtered out when all products are unique
@@ -593,7 +598,11 @@ def test_observations_get_product_list_async(self):
593598
responses = Observations.get_product_list_async(observations[0:4])
594599
assert isinstance(responses, list)
595600

596-
def test_observations_get_product_list(self):
601+
# Batching
602+
responses = Observations.get_product_list_async(observations[0:4], batch_size=2)
603+
assert isinstance(responses, list)
604+
605+
def test_observations_get_product_list(self, capsys):
597606
observations = Observations.query_criteria(objectname='M8', obs_collection=['K2', 'IUE'])
598607
test_obs_id = str(observations[0]['obsid'])
599608
mult_obs_ids = str(observations[0]['obsid']) + ',' + str(observations[1]['obsid'])
@@ -626,6 +635,14 @@ def test_observations_get_product_list(self):
626635
assert len(obs_collection) == 1
627636
assert obs_collection[0] == 'IUE'
628637

638+
# Test batching
639+
result_batch = Observations.get_product_list(observations[:2], batch_size=1)
640+
out, _ = capsys.readouterr()
641+
assert isinstance(result_batch, Table)
642+
assert len(result_batch) == len(result1)
643+
assert set(result_batch['productFilename']) == set(filenames1)
644+
assert 'Fetching products for 2 unique observations in 2 batches' in out
645+
629646
def test_observations_get_product_list_tess_tica(self, caplog):
630647
# Get observations and products with both TESS and TICA FFIs
631648
obs = Observations.query_criteria(target_name=['TESS FFI', 'TICA FFI', '429031146'])

astroquery/mast/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _batched_request(
148148
return extract_func(resp)
149149

150150

151-
def resolve_object(objectname, *, resolver=None, resolve_all=False):
151+
def resolve_object(objectname, *, resolver=None, resolve_all=False, batch_size=30):
152152
"""
153153
Resolves one or more object names to a position on the sky.
154154
@@ -164,6 +164,9 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False):
164164
resolve_all : bool, optional
165165
If True, will try to resolve the object name using all available resolvers ("NED", "SIMBAD").
166166
Default is False.
167+
batch_size : int, optional
168+
Default 30. Number of object names to include in each batch request to the server.
169+
If you experience timeouts or connection errors, consider lowering this value.
167170
168171
Returns
169172
-------
@@ -230,7 +233,7 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False):
230233
results = _batched_request(
231234
object_names,
232235
params,
233-
max_batch=30,
236+
max_batch=batch_size,
234237
param_key="name",
235238
request_func=lambda p: _simple_request("http://mastresolver.stsci.edu/Santa-war/query", p),
236239
extract_func=lambda r: r.json().get("resolvedCoordinate") or [],

docs/mast/mast_missions.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,16 @@ Each observation returned from a MAST query can have one or more associated data
203203
one or more datasets or dataset IDs, the `~astroquery.mast.MastMissionsClass.get_product_list` function
204204
will return a `~astropy.table.Table` containing the associated data products.
205205

206+
`~astroquery.mast.MastMissionsClass.get_product_list` also includes an optional ``batch_size`` parameter,
207+
which controls how many datasets are sent to the MAST service per request. This can be useful for managing
208+
memory usage or avoiding timeouts when requesting product lists for large numbers of datasets.
209+
If not provided, ``batch_size`` defaults to 1000.
210+
206211
.. doctest-remote-data::
207212
>>> datasets = missions.query_criteria(sci_pep_id=12451,
208213
... sci_instrume='ACS',
209214
... sci_hlsp='>1')
210-
>>> products = missions.get_product_list(datasets[:2])
215+
>>> products = missions.get_product_list(datasets[:2], batch_size=1000)
211216
>>> print(products[:5]) # doctest: +IGNORE_OUTPUT
212217
product_key access dataset ... category size type
213218
---------------------------- ------ --------- ... ---------- --------- -------

0 commit comments

Comments
 (0)