Skip to content

Commit 1a07a83

Browse files
authored
Merge pull request #2245 from larrybradley/check-units
Add check_units (private) helper function
2 parents 1eb33bd + 7019d84 commit 1a07a83

File tree

7 files changed

+120
-80
lines changed

7 files changed

+120
-80
lines changed

photutils/detection/core.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
deprecated_positional_kwargs,
2525
deprecated_renamed_argument)
2626
from photutils.utils._misc import _get_meta
27-
from photutils.utils._quantity_helpers import process_quantities
27+
from photutils.utils._quantity_helpers import check_units
2828
from photutils.utils._repr import make_repr
2929
from photutils.utils.cutouts import _make_cutouts
3030
from photutils.utils.exceptions import NoDetectionsWarning
@@ -221,10 +221,8 @@ class StarFinderCatalogBase(metaclass=abc.ABCMeta):
221221
@deprecated_renamed_argument('peakmax', 'peak_max', '3.0', until='4.0')
222222
def __init__(self, data, xypos, kernel, *, n_brightest=None,
223223
peak_max=None):
224-
# Validate the units, but do not strip them
225-
inputs = (data, peak_max)
226-
names = ('data', 'peak_max')
227-
_ = process_quantities(inputs, names)
224+
# Validate the units
225+
check_units((data, peak_max), ('data', 'peak_max'))
228226

229227
self.data = data
230228
unit = data.unit if isinstance(data, u.Quantity) else None

photutils/detection/daofinder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from photutils.utils._convolution import _filter_data
1717
from photutils.utils._deprecation import (deprecated_positional_kwargs,
1818
deprecated_renamed_argument)
19-
from photutils.utils._quantity_helpers import isscalar, process_quantities
19+
from photutils.utils._quantity_helpers import check_units, isscalar
2020
from photutils.utils._repr import make_repr
2121
from photutils.utils.exceptions import NoDetectionsWarning
2222

@@ -219,7 +219,7 @@ def __init__(self, threshold, fwhm, ratio=1.0, theta=0.0,
219219
# Validate the units, but do not strip them
220220
inputs = (threshold, peak_max)
221221
names = ('threshold', 'peak_max')
222-
_ = process_quantities(inputs, names)
222+
check_units(inputs, names)
223223

224224
if not isscalar(fwhm):
225225
msg = 'fwhm must be a scalar value'
@@ -399,7 +399,7 @@ def find_stars(self, data, mask=None):
399399
# Validate the units, but do not strip them
400400
inputs = (data, self.threshold, self.peak_max)
401401
names = ('data', 'threshold', 'peak_max')
402-
_ = process_quantities(inputs, names)
402+
check_units(inputs, names)
403403

404404
cat = self._get_raw_catalog(data, mask=mask)
405405
if cat is None:
@@ -476,7 +476,7 @@ def __init__(self, data, convolved_data, xypos, threshold, kernel, *,
476476
# Validate the units, but do not strip them
477477
inputs = (data, convolved_data, threshold, peak_max)
478478
names = ('data', 'convolved_data', 'threshold', 'peak_max')
479-
_ = process_quantities(inputs, names)
479+
check_units(inputs, names)
480480

481481
super().__init__(data, xypos, kernel,
482482
n_brightest=n_brightest,

photutils/detection/irafstarfinder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from photutils.utils._convolution import _filter_data
1717
from photutils.utils._deprecation import (deprecated_positional_kwargs,
1818
deprecated_renamed_argument)
19-
from photutils.utils._quantity_helpers import isscalar, process_quantities
19+
from photutils.utils._quantity_helpers import check_units, isscalar
2020
from photutils.utils._repr import make_repr
2121
from photutils.utils.exceptions import NoDetectionsWarning
2222

@@ -184,7 +184,7 @@ def __init__(self, threshold, fwhm, sigma_radius=1.5,
184184
# Validate the units, but do not strip them
185185
inputs = (threshold, peak_max)
186186
names = ('threshold', 'peak_max')
187-
_ = process_quantities(inputs, names)
187+
check_units(inputs, names)
188188

189189
if not isscalar(fwhm):
190190
msg = 'fwhm must be a scalar value'
@@ -364,7 +364,7 @@ def find_stars(self, data, mask=None):
364364
"""
365365
inputs = (data, self.threshold, self.peak_max)
366366
names = ('data', 'threshold', 'peak_max')
367-
_ = process_quantities(inputs, names)
367+
check_units(inputs, names)
368368

369369
cat = self._get_raw_catalog(data, mask=mask)
370370
if cat is None:
@@ -434,7 +434,7 @@ def __init__(self, data, convolved_data, xypos, kernel, *,
434434
# Validate the units, but do not strip them
435435
inputs = (data, convolved_data, peak_max)
436436
names = ('data', 'convolved_data', 'peak_max')
437-
_ = process_quantities(inputs, names)
437+
check_units(inputs, names)
438438

439439
super().__init__(data, xypos, kernel,
440440
n_brightest=n_brightest,

photutils/detection/starfinder.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from photutils.utils._convolution import _filter_data
1414
from photutils.utils._deprecation import (deprecated_positional_kwargs,
1515
deprecated_renamed_argument)
16-
from photutils.utils._quantity_helpers import process_quantities
16+
from photutils.utils._quantity_helpers import check_units
1717
from photutils.utils._repr import make_repr
1818
from photutils.utils.exceptions import NoDetectionsWarning
1919

@@ -88,10 +88,8 @@ class StarFinder(StarFinderBase):
8888
def __init__(self, threshold, kernel, min_separation=None,
8989
exclude_border=False, n_brightest=None, peak_max=None):
9090

91-
# Validate the units, but do not strip them
92-
inputs = (threshold, peak_max)
93-
names = ('threshold', 'peak_max')
94-
_ = process_quantities(inputs, names)
91+
# Validate the units
92+
check_units((threshold, peak_max), ('threshold', 'peak_max'))
9593

9694
self.threshold = threshold
9795

@@ -210,10 +208,9 @@ def find_stars(self, data, mask=None):
210208
`None` is returned if no stars are found or no stars meet
211209
the peak_max criteria.
212210
"""
213-
# Validate the units, but do not strip them
214-
inputs = (data, self.threshold, self.peak_max)
215-
names = ('data', 'threshold', 'peak_max')
216-
_ = process_quantities(inputs, names)
211+
# Validate the units
212+
check_units((data, self.threshold, self.peak_max),
213+
('data', 'threshold', 'peak_max'))
217214

218215
cat = self._get_raw_catalog(data, mask=mask)
219216
if cat is None:

photutils/segmentation/detect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from photutils.utils._deprecation import deprecated_renamed_argument
1616
from photutils.utils._parameters import (SigmaClipSentinelDefault,
1717
create_default_sigmaclip)
18-
from photutils.utils._quantity_helpers import process_quantities
18+
from photutils.utils._quantity_helpers import check_units, process_quantities
1919
from photutils.utils._stats import nanmean, nanstd
2020
from photutils.utils.exceptions import NoDetectionsWarning
2121

@@ -366,7 +366,7 @@ class to detect and deblend sources in a single step.
366366
cmap=segm.make_cmap(seed=1234))
367367
plt.tight_layout()
368368
"""
369-
_ = process_quantities((data, threshold), ('data', 'threshold'))
369+
check_units((data, threshold), ('data', 'threshold'))
370370

371371
if (n_pixels <= 0) or (int(n_pixels) != n_pixels):
372372
msg = f'n_pixels must be a positive integer, got {n_pixels!r}'

photutils/utils/_quantity_helpers.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,9 @@
77
import numpy as np
88

99

10-
def process_quantities(values, names):
10+
def check_units(values, names):
1111
"""
12-
Check and remove units of input values.
13-
14-
If any of the input values have units then they all must have
15-
units and the units must be the same.
16-
17-
The returned values are the input values with units removed and the
18-
unit.
12+
Check that input values have consistent units.
1913
2014
Parameters
2115
----------
@@ -27,18 +21,14 @@ def process_quantities(values, names):
2721
2822
Returns
2923
-------
30-
values : list of scalar or `~numpy.ndarray`
31-
A list of values, where units have been removed.
32-
33-
unit : `~astropy.unit.Unit`
34-
The common unit for the input values. `None` will be returned
35-
if all the input values do not have units (including when all
36-
values are `None`).
24+
units : set
25+
The set of distinct units across all non-`None` values.
3726
3827
Raises
3928
------
4029
ValueError
41-
If the input values do not all have the same units.
30+
If the number of values does not match the number of names,
31+
or if the input values do not all have the same units.
4232
"""
4333
if len(values) != len(names):
4434
msg = 'The number of values must match the number of names.'
@@ -47,9 +37,9 @@ def process_quantities(values, names):
4737
all_units = {name: getattr(arr, 'unit', None)
4838
for arr, name in zip(values, names, strict=True)
4939
if arr is not None}
50-
unit = set(all_units.values())
40+
units = set(all_units.values())
5141

52-
if len(unit) > 1:
42+
if len(units) > 1:
5343
param_names = list(all_units.keys())
5444
msg = [f'The inputs {param_names} must all have the same units:']
5545
indent = ' ' * 4
@@ -61,13 +51,51 @@ def process_quantities(values, names):
6151
msg = '\n'.join(msg)
6252
raise ValueError(msg)
6353

64-
# When all values are None, all_units is empty; return unchanged
54+
return units
55+
56+
57+
def process_quantities(values, names):
58+
"""
59+
Check and remove units of input values.
60+
61+
If any of the input values have units then they all must have
62+
units and the units must be the same.
63+
64+
The returned values are the input values with units removed and the
65+
unit.
66+
67+
Parameters
68+
----------
69+
values : list of scalar, `~numpy.ndarray`, or `~astropy.units.Quantity`
70+
A list of values.
71+
72+
names : list of str
73+
A list of names corresponding to the input ``values``.
74+
75+
Returns
76+
-------
77+
values : list of scalar or `~numpy.ndarray`
78+
A list of values, where units have been removed.
79+
80+
unit : `~astropy.unit.Unit`
81+
The common unit for the input values. `None` will be returned
82+
if all the input values do not have units (including when all
83+
values are `None`).
84+
85+
Raises
86+
------
87+
ValueError
88+
If the input values do not all have the same units.
89+
"""
90+
units = check_units(values, names)
91+
92+
# When all values are None, the units set is empty; return unchanged
6593
# with unit=None
66-
if len(unit) == 0:
94+
if len(units) == 0:
6795
return values, None
6896

6997
# Extract the unit and remove it from the return values
70-
unit = unit.pop()
98+
unit = units.pop()
7199
if unit is not None:
72100
values = [val.value if val is not None else val for val in values]
73101

photutils/utils/tests/test_quantity_helpers.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import pytest
99
from numpy.testing import assert_equal
1010

11-
from photutils.utils._quantity_helpers import isscalar, process_quantities
11+
from photutils.utils._quantity_helpers import (check_units, isscalar,
12+
process_quantities)
1213

1314

1415
@pytest.mark.parametrize('all_units', [False, True])
@@ -30,59 +31,75 @@ def test_units(all_units):
3031
assert arrs2 == arrs
3132

3233

34+
def test_process_quantities_all_none():
35+
"""
36+
Test that process_quantities with all None inputs returns None
37+
unit.
38+
"""
39+
values, unit = process_quantities([None, None], ['a', 'b'])
40+
assert values == [None, None]
41+
assert unit is None
42+
43+
44+
def test_isscalar():
45+
"""
46+
Test isscalar with scalar and array inputs.
47+
"""
48+
assert isscalar(1)
49+
assert isscalar(1.0 * u.m)
50+
assert not isscalar([1, 2, 3])
51+
assert not isscalar([1, 2, 3] * u.m)
52+
53+
54+
def test_inputs():
55+
"""
56+
Test that mismatched values and names lengths raises ValueError.
57+
"""
58+
match = 'The number of values must match the number of names'
59+
with pytest.raises(ValueError, match=match):
60+
process_quantities([1, 2, 3], ['a', 'b'])
61+
with pytest.raises(ValueError, match=match):
62+
check_units([1, 2, 3], ['a', 'b'])
63+
64+
65+
def test_check_units():
66+
"""
67+
Test check_units for unit consistency checking.
68+
"""
69+
# Valid: same units
70+
check_units((np.ones(3) * u.Jy, np.ones(3) * u.Jy), ('a', 'b'))
71+
72+
# Valid: no units
73+
check_units((np.ones(3), np.ones(3)), ('a', 'b'))
74+
75+
# Valid: with None values
76+
check_units((np.ones(3) * u.Jy, None), ('a', 'b'))
77+
78+
3379
def test_mixed_units():
3480
"""
35-
Test that process_quantities with mixed units raises ValueError.
81+
Test that check_units with mixed units raises ValueError.
3682
"""
3783
arrs = (np.ones(3) * u.Jy, np.ones(3) * u.km)
3884
names = ('a', 'b')
3985

4086
match = 'must all have the same units'
4187
with pytest.raises(ValueError, match=match):
42-
_, _ = process_quantities(arrs, names)
88+
check_units(arrs, names)
4389

4490
arrs = (np.ones(3) * u.Jy, np.ones(3))
4591
names = ('a', 'b')
4692
with pytest.raises(ValueError, match=match):
47-
_, _ = process_quantities(arrs, names)
93+
check_units(arrs, names)
4894

4995
unit = u.Jy
5096
arrs = (np.ones(3) * unit, np.ones(3), np.ones(3) * unit)
5197
names = ('a', 'b', 'c')
5298
with pytest.raises(ValueError, match=match):
53-
_, _ = process_quantities(arrs, names)
99+
check_units(arrs, names)
54100

55101
unit = u.Jy
56102
arrs = (np.ones(3) * unit, np.ones(3), np.ones(3) * u.km)
57103
names = ('a', 'b', 'c')
58104
with pytest.raises(ValueError, match=match):
59-
_, _ = process_quantities(arrs, names)
60-
61-
62-
def test_process_quantities_all_none():
63-
"""
64-
Test that process_quantities with all None inputs returns None
65-
unit.
66-
"""
67-
values, unit = process_quantities([None, None], ['a', 'b'])
68-
assert values == [None, None]
69-
assert unit is None
70-
71-
72-
def test_inputs():
73-
"""
74-
Test that mismatched values and names lengths raises ValueError.
75-
"""
76-
match = 'The number of values must match the number of names'
77-
with pytest.raises(ValueError, match=match):
78-
_, _ = process_quantities([1, 2, 3], ['a', 'b'])
79-
80-
81-
def test_isscalar():
82-
"""
83-
Test isscalar with scalar and array inputs.
84-
"""
85-
assert isscalar(1)
86-
assert isscalar(1.0 * u.m)
87-
assert not isscalar([1, 2, 3])
88-
assert not isscalar([1, 2, 3] * u.m)
105+
check_units(arrs, names)

0 commit comments

Comments
 (0)