88import pytest
99from numpy .testing import assert_equal
1010
11- from photutils .utils ._quantity_helpers import isscalar , process_quantities
11+ from photutils .utils ._quantity_helpers import (check_units , isscalar ,
12+ process_quantities )
1213
1314
1415@pytest .mark .parametrize ('all_units' , [False , True ])
@@ -30,59 +31,75 @@ def test_units(all_units):
3031 assert arrs2 == arrs
3132
3233
34+ def test_process_quantities_all_none ():
35+ """
36+ Test that process_quantities with all None inputs returns None
37+ unit.
38+ """
39+ values , unit = process_quantities ([None , None ], ['a' , 'b' ])
40+ assert values == [None , None ]
41+ assert unit is None
42+
43+
44+ def test_isscalar ():
45+ """
46+ Test isscalar with scalar and array inputs.
47+ """
48+ assert isscalar (1 )
49+ assert isscalar (1.0 * u .m )
50+ assert not isscalar ([1 , 2 , 3 ])
51+ assert not isscalar ([1 , 2 , 3 ] * u .m )
52+
53+
54+ def test_inputs ():
55+ """
56+ Test that mismatched values and names lengths raises ValueError.
57+ """
58+ match = 'The number of values must match the number of names'
59+ with pytest .raises (ValueError , match = match ):
60+ process_quantities ([1 , 2 , 3 ], ['a' , 'b' ])
61+ with pytest .raises (ValueError , match = match ):
62+ check_units ([1 , 2 , 3 ], ['a' , 'b' ])
63+
64+
65+ def test_check_units ():
66+ """
67+ Test check_units for unit consistency checking.
68+ """
69+ # Valid: same units
70+ check_units ((np .ones (3 ) * u .Jy , np .ones (3 ) * u .Jy ), ('a' , 'b' ))
71+
72+ # Valid: no units
73+ check_units ((np .ones (3 ), np .ones (3 )), ('a' , 'b' ))
74+
75+ # Valid: with None values
76+ check_units ((np .ones (3 ) * u .Jy , None ), ('a' , 'b' ))
77+
78+
3379def test_mixed_units ():
3480 """
35- Test that process_quantities with mixed units raises ValueError.
81+ Test that check_units with mixed units raises ValueError.
3682 """
3783 arrs = (np .ones (3 ) * u .Jy , np .ones (3 ) * u .km )
3884 names = ('a' , 'b' )
3985
4086 match = 'must all have the same units'
4187 with pytest .raises (ValueError , match = match ):
42- _ , _ = process_quantities (arrs , names )
88+ check_units (arrs , names )
4389
4490 arrs = (np .ones (3 ) * u .Jy , np .ones (3 ))
4591 names = ('a' , 'b' )
4692 with pytest .raises (ValueError , match = match ):
47- _ , _ = process_quantities (arrs , names )
93+ check_units (arrs , names )
4894
4995 unit = u .Jy
5096 arrs = (np .ones (3 ) * unit , np .ones (3 ), np .ones (3 ) * unit )
5197 names = ('a' , 'b' , 'c' )
5298 with pytest .raises (ValueError , match = match ):
53- _ , _ = process_quantities (arrs , names )
99+ check_units (arrs , names )
54100
55101 unit = u .Jy
56102 arrs = (np .ones (3 ) * unit , np .ones (3 ), np .ones (3 ) * u .km )
57103 names = ('a' , 'b' , 'c' )
58104 with pytest .raises (ValueError , match = match ):
59- _ , _ = process_quantities (arrs , names )
60-
61-
62- def test_process_quantities_all_none ():
63- """
64- Test that process_quantities with all None inputs returns None
65- unit.
66- """
67- values , unit = process_quantities ([None , None ], ['a' , 'b' ])
68- assert values == [None , None ]
69- assert unit is None
70-
71-
72- def test_inputs ():
73- """
74- Test that mismatched values and names lengths raises ValueError.
75- """
76- match = 'The number of values must match the number of names'
77- with pytest .raises (ValueError , match = match ):
78- _ , _ = process_quantities ([1 , 2 , 3 ], ['a' , 'b' ])
79-
80-
81- def test_isscalar ():
82- """
83- Test isscalar with scalar and array inputs.
84- """
85- assert isscalar (1 )
86- assert isscalar (1.0 * u .m )
87- assert not isscalar ([1 , 2 , 3 ])
88- assert not isscalar ([1 , 2 , 3 ] * u .m )
105+ check_units (arrs , names )
0 commit comments