Skip to content

Commit 1e01ac3

Browse files
authored
Merge pull request #3403 from zoghbi-a/heasarc-query-by-column
2 parents 1a165c4 + 5c2ee14 commit 1e01ac3

File tree

5 files changed

+643
-55
lines changed

5 files changed

+643
-55
lines changed

CHANGES.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ esa.euclid
2020
Service fixes and enhancements
2121
------------------------------
2222

23+
heasarc
24+
^^^^^^^
25+
- Add ``query_constraints`` to allow querying of different catalog columns. [#3403]
26+
- Add support for uploading tables when using TAP directly through ``query_tap``. [#3403]
27+
- Add automatic guessing for the data host in ``download_data``. [#3403]
28+
29+
2330
esa.hubble
2431
^^^^^^^^^^
2532

astroquery/heasarc/core.py

Lines changed: 225 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
21
import os
3-
42
import shutil
53
import requests
64
import tarfile
@@ -261,7 +259,7 @@ def query_mission_cols(self, mission, *, cache=True,
261259
cols = [col.upper() for col in cols['name'] if '__' not in col]
262260
return cols
263261

264-
def query_tap(self, query, *, maxrec=None):
262+
def query_tap(self, query, *, maxrec=None, uploads=None):
265263
"""
266264
Send query to HEASARC's Xamin TAP using ADQL.
267265
Results in `~pyvo.dal.TAPResults` format.
@@ -273,6 +271,10 @@ def query_tap(self, query, *, maxrec=None):
273271
ADQL query to be executed
274272
maxrec : int
275273
maximum number of records to return
274+
uploads : dict
275+
a mapping from table names used in the query to file like
276+
objects containing a votable
277+
(e.g. a file path or `~astropy.table.Table`).
276278
277279
Returns
278280
-------
@@ -286,7 +288,130 @@ def query_tap(self, query, *, maxrec=None):
286288
"""
287289
log.debug(f'TAP query: {query}')
288290
self._saved_query = query
289-
return self.tap.search(query, language='ADQL', maxrec=maxrec)
291+
return self.tap.search(
292+
query, language='ADQL', maxrec=maxrec, uploads=uploads)
293+
294+
def _query_execute(self, catalog=None, where=None, *,
295+
get_query_payload=False, columns=None,
296+
verbose=False, maxrec=None):
297+
"""Queries some catalog using the HEASARC TAP server based on the
298+
'where' condition and returns an `~astropy.table.Table`.
299+
300+
Parameters
301+
----------
302+
catalog : str
303+
The catalog to query. To list the available catalogs, use
304+
:meth:`~astroquery.heasarc.HeasarcClass.list_catalogs`.
305+
where : str
306+
The WHERE condition to be used in the query. It must
307+
include the 'WHERE' keyword or be empty.
308+
get_query_payload : bool, optional
309+
If `True` then returns the generated ADQL query as str.
310+
Defaults to `False`.
311+
columns : str, optional
312+
Target column list with value separated by a comma(,).
313+
Use * for all the columns. The default is to return a subset
314+
of the columns that are generally the most useful.
315+
verbose : bool, optional
316+
If False, suppress vo warnings.
317+
maxrec : int, optional
318+
Maximum number of records
319+
320+
321+
Returns
322+
-------
323+
table : A `~astropy.table.Table` object.
324+
"""
325+
# if verbose is False then suppress any VOTable related warnings
326+
if not verbose:
327+
commons.suppress_vo_warnings()
328+
329+
if catalog is None:
330+
raise InvalidQueryError("catalog name is required! Use 'xray' "
331+
"to search the master X-ray catalog")
332+
333+
if where is None:
334+
where = ''
335+
336+
# __row is needed for locate_data; we add it if not already present
337+
# and remove it afterwards only if the user requested specific
338+
# columns. keep_row tracks that.
339+
keep_row = (
340+
columns in (None, '*')
341+
or isinstance(columns, str) and '__row' in columns
342+
)
343+
344+
if columns is None:
345+
columns = ', '.join(self._get_default_columns(catalog))
346+
347+
if '__row' not in columns and columns != '*':
348+
columns += ', __row'
349+
350+
if where != '' and not where.startswith(' '):
351+
where = ' ' + where.strip()
352+
adql = f'SELECT {columns} FROM {catalog}{where}'
353+
354+
if get_query_payload:
355+
return adql
356+
response = self.query_tap(query=adql, maxrec=maxrec)
357+
358+
# save the response in case we want to use it later
359+
self._last_result = response
360+
self._last_catalog_name = catalog
361+
362+
table = response.to_table()
363+
if not keep_row and '__row' in table.colnames:
364+
table.remove_column('__row')
365+
return table
366+
367+
def _parse_constraints(self, column_filters):
368+
"""Convert constraints dictionary to ADQL WHERE clause
369+
370+
Parameters
371+
----------
372+
column_filters : dict
373+
A dictionary of column constraint filters to include in the query.
374+
Each key-value pair will be translated into an ADQL condition.
375+
See `query_region` for details.
376+
377+
Returns
378+
-------
379+
conditions : list
380+
a list of ADQL conditions as str
381+
382+
"""
383+
conditions = []
384+
if column_filters is None:
385+
return conditions
386+
for key, value in column_filters.items():
387+
if isinstance(value, tuple):
388+
if (
389+
len(value) == 2
390+
and all(isinstance(v, (int, float)) for v in value)
391+
):
392+
conditions.append(
393+
f"{key} BETWEEN {value[0]} AND {value[1]}"
394+
)
395+
elif (
396+
len(value) == 2
397+
and value[0] in (">", "<", ">=", "<=")
398+
):
399+
conditions.append(f"{key} {value[0]} {value[1]}")
400+
elif isinstance(value, list):
401+
# handle list values: key IN (...)
402+
formatted = []
403+
for v in value:
404+
if isinstance(v, str):
405+
formatted.append(f"'{v}'")
406+
else:
407+
formatted.append(str(v))
408+
conditions.append(f"{key} IN ({', '.join(formatted)})")
409+
else:
410+
conditions.append(
411+
f"{key} = '{value}'"
412+
if isinstance(value, str) else f"{key} = {value}"
413+
)
414+
return conditions
290415

291416
@deprecated_renamed_argument(
292417
('mission', 'fields', 'resultmax', 'entry', 'coordsys', 'equinox',
@@ -298,8 +423,8 @@ def query_tap(self, query, *, maxrec=None):
298423
True, True, True, False)
299424
)
300425
def query_region(self, position=None, catalog=None, radius=None, *,
301-
spatial='cone', width=None, polygon=None, add_offset=False,
302-
get_query_payload=False, columns=None, cache=False,
426+
spatial='cone', width=None, polygon=None, column_filters=None,
427+
add_offset=False, get_query_payload=False, columns=None, cache=False,
303428
verbose=False, maxrec=None,
304429
**kwargs):
305430
"""Queries the HEASARC TAP server around a coordinate and returns a
@@ -335,6 +460,23 @@ def query_region(self, position=None, catalog=None, radius=None, *,
335460
outlining the polygon to search in. It can also be a list of
336461
`astropy.coordinates` object or strings that can be parsed by
337462
`astropy.coordinates.ICRS`.
463+
column_filters : dict
464+
A dictionary of column constraint filters to include in the query.
465+
Each key-value pair will be translated into an ADQL condition.
466+
- For a range query, use a tuple of two values (min, max).
467+
e.g. ``{'flux': (1e-12, 1e-10)}`` translates to
468+
``flux BETWEEN 1e-12 AND 1e-10``.
469+
- For list values, use a list of values.
470+
e.g. ``{'object_type': ['QSO', 'GALAXY']}`` translates to
471+
``object_type IN ('QSO', 'GALAXY')``.
472+
- For comparison queries, use a tuple of (operator, value),
473+
where operator is one of '=', '!=', '<', '>', '<=', '>='.
474+
e.g. ``{'magnitude': ('<', 15)}`` translates to ``magnitude < 15``.
475+
- For exact matches, use a single value (str, int, float).
476+
e.g. ``{'object_type': 'QSO'}`` translates to
477+
``object_type = 'QSO'``.
478+
The keys should correspond to valid column names in the catalog.
479+
Use `list_columns` to see the available columns.
338480
add_offset: bool
339481
If True and spatial=='cone', add a search_offset column that
340482
indicates the separation (in arcmin) between the requested
@@ -356,18 +498,10 @@ def query_region(self, position=None, catalog=None, radius=None, *,
356498
-------
357499
table : A `~astropy.table.Table` object.
358500
"""
359-
# if verbose is False then suppress any VOTable related warnings
360-
if not verbose:
361-
commons.suppress_vo_warnings()
362501

363-
if catalog is None:
364-
raise InvalidQueryError("catalog name is required! Use 'xray' "
365-
"to search the master X-ray catalog")
366-
367-
if columns is None:
368-
columns = ', '.join(self._get_default_columns(catalog))
369-
if '__row' not in columns:
370-
columns += ',__row'
502+
# if we have column_filters and no position, assume all-sky search
503+
if position is None and column_filters is not None:
504+
spatial = 'all-sky'
371505

372506
if spatial.lower() == 'all-sky':
373507
where = ''
@@ -390,9 +524,14 @@ def query_region(self, position=None, catalog=None, radius=None, *,
390524

391525
coords_str = [f'{coord.ra.deg},{coord.dec.deg}'
392526
for coord in coords_list]
393-
where = (" WHERE CONTAINS(POINT('ICRS',ra,dec),"
527+
where = ("WHERE CONTAINS(POINT('ICRS',ra,dec),"
394528
f"POLYGON('ICRS',{','.join(coords_str)}))=1")
395529
else:
530+
if position is None:
531+
raise InvalidQueryError(
532+
"position is required to for spatial='cone' (default). "
533+
"Use spatial='all-sky' For all-sky searches."
534+
)
396535
coords_icrs = parse_coordinates(position).icrs
397536
ra, dec = coords_icrs.ra.deg, coords_icrs.dec.deg
398537

@@ -401,7 +540,7 @@ def query_region(self, position=None, catalog=None, radius=None, *,
401540
radius = self.get_default_radius(catalog)
402541
elif isinstance(radius, str):
403542
radius = coordinates.Angle(radius)
404-
where = (" WHERE CONTAINS(POINT('ICRS',ra,dec),CIRCLE("
543+
where = ("WHERE CONTAINS(POINT('ICRS',ra,dec),CIRCLE("
405544
f"'ICRS',{ra},{dec},{radius.to(u.deg).value}))=1")
406545
# add search_offset for the case of cone
407546
if add_offset:
@@ -410,24 +549,33 @@ def query_region(self, position=None, catalog=None, radius=None, *,
410549
elif spatial.lower() == 'box':
411550
if isinstance(width, str):
412551
width = coordinates.Angle(width)
413-
where = (" WHERE CONTAINS(POINT('ICRS',ra,dec),"
552+
where = ("WHERE CONTAINS(POINT('ICRS',ra,dec),"
414553
f"BOX('ICRS',{ra},{dec},{width.to(u.deg).value},"
415554
f"{width.to(u.deg).value}))=1")
416555
else:
417556
raise ValueError("Unrecognized spatial query type. Must be one"
418557
" of 'cone', 'box', 'polygon', or 'all-sky'.")
419558

420-
adql = f'SELECT {columns} FROM {catalog}{where}'
421-
559+
# handle column filters
560+
if column_filters is not None:
561+
conditions = self._parse_constraints(column_filters)
562+
if len(conditions) > 0:
563+
constraints_str = ' AND '.join(conditions)
564+
if where == '':
565+
where = 'WHERE ' + constraints_str
566+
else:
567+
where += ' AND ' + constraints_str
568+
569+
table_or_query = self._query_execute(
570+
catalog=catalog, where=where,
571+
get_query_payload=get_query_payload,
572+
columns=columns, verbose=verbose,
573+
maxrec=maxrec
574+
)
422575
if get_query_payload:
423-
return adql
424-
response = self.query_tap(query=adql, maxrec=maxrec)
425-
426-
# save the response in case we want to use it later
427-
self._last_result = response
428-
self._last_catalog_name = catalog
576+
return table_or_query
577+
table = table_or_query
429578

430-
table = response.to_table()
431579
if add_offset:
432580
table['search_offset'].unit = u.arcmin
433581
if len(table) == 0:
@@ -505,18 +653,22 @@ def locate_data(self, query_result=None, catalog_name=None):
505653
if '__row' not in query_result.colnames:
506654
raise ValueError('No __row column found in query_result. '
507655
'query_result needs to be the output of '
508-
'query_region or a subset.')
656+
'query_region or a subset. try adding '
657+
'__row to the requested columns')
509658

510659
if catalog_name is None:
660+
if not hasattr(self, '_last_catalog_name'):
661+
raise ValueError('locate_data needs a catalog_name, and none '
662+
'found from a previous search. Please provide one.')
511663
catalog_name = self._last_catalog_name
512664
if not (
513665
isinstance(catalog_name, str)
514666
and catalog_name in self.tap.tables.keys()
515667
):
516668
raise ValueError(f'Unknown catalog name: {catalog_name}')
517669

518-
# datalink url
519-
dlink_url = f'{self.VO_URL}/datalink/{catalog_name}'
670+
# datalink url; use sizefiles=false to speed up the response
671+
dlink_url = f'{self.VO_URL}/datalink/{catalog_name}?sizefiles=false&'
520672
query = pyvo.dal.adhoc.DatalinkQuery(
521673
baseurl=dlink_url,
522674
id=query_result['__row'],
@@ -592,17 +744,52 @@ def enable_cloud(self, provider='aws', profile=None):
592744

593745
self.s3_client = self.s3_resource.meta.client
594746

595-
def download_data(self, links, host='heasarc', location='.'):
747+
def _guess_host(self, host):
748+
"""Guess the host to use for downloading data
749+
750+
Parameters
751+
----------
752+
host : str
753+
The host provided by the user
754+
755+
Returns
756+
-------
757+
host : str
758+
The guessed host
759+
760+
"""
761+
if host in ['heasarc', 'sciserver', 'aws']:
762+
return host
763+
elif host is not None:
764+
raise ValueError(
765+
'host has to be one of heasarc, sciserver, aws or None')
766+
767+
# host is None, so we guess
768+
if (
769+
'HOME' in os.environ
770+
and os.environ['HOME'] == '/home/idies'
771+
and os.path.exists('/FTP/')
772+
):
773+
# we are on idies, so we can use sciserver
774+
return 'sciserver'
775+
776+
for var in ['AWS_REGION', 'AWS_DEFAULT_REGION', 'AWS_ROLE_ARN']:
777+
if var in os.environ:
778+
return 'aws'
779+
return 'heasarc'
780+
781+
def download_data(self, links, *, host=None, location='.'):
596782
"""Download data products in links with a choice of getting the
597783
data from either the heasarc server, sciserver, or the cloud in AWS.
598784
599785
600786
Parameters
601787
----------
602788
links : `astropy.table.Table` or `astropy.table.Row`
603-
The result from locate_data
604-
host : str
605-
The data host. The options are: heasarc (default), sciserver, aws.
789+
A table (or row) of data links, typically the result of locate_data.
790+
host : str or None
791+
The data host. The options are: None (default), heasarc, sciserver, aws.
792+
If None, the host is guessed based on the environment.
606793
If host == 'sciserver', data is copied from the local mounted
607794
data drive.
608795
If host == 'aws', data is downloaded from Amazon S3 Open
@@ -623,8 +810,8 @@ def download_data(self, links, host='heasarc', location='.'):
623810
if isinstance(links, Row):
624811
links = links.table[[links.index]]
625812

626-
if host not in ['heasarc', 'sciserver', 'aws']:
627-
raise ValueError('host has to be one of heasarc, sciserver, aws')
813+
# guess the host if not provided
814+
host = self._guess_host(host)
628815

629816
host_column = 'access_url' if host == 'heasarc' else host
630817
if host_column not in links.colnames:

0 commit comments

Comments
 (0)