Skip to content

Commit fc0dd18

Browse files
committed
Changed StatsMixin to be a child of MathMixin, added argmin and argmax functions
1 parent fad6337 commit fc0dd18

5 files changed

Lines changed: 27 additions & 8 deletions

File tree

src/lkdata/datacube.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from .dataseries import DataSeries, BoolSeries, BitwiseSeries
1717
from .mixins import (
1818
StatsMixin,
19-
MathMixin,
2019
BoolMixin,
2120
BitwiseMixin,
2221
AggMixin,
@@ -526,7 +525,6 @@ def values(self):
526525

527526

528527
class DataCube(
529-
MathMixin,
530528
StatsMixin,
531529
Cube,
532530
):

src/lkdata/dataseries.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .mixins import (
1111
AggMixin,
1212
ConvenienceMixins,
13-
MathMixin,
1413
StatsMixin,
1514
BoolMixin,
1615
BitwiseMixin,
@@ -134,7 +133,7 @@ def stats_post_process(self, result, **kwargs):
134133
return result
135134

136135

137-
class DataSeries(MathMixin, StatsMixin, Series):
136+
class DataSeries(StatsMixin, Series):
138137
"""
139138
pandas.Series-like object with uncertainty and lightkurve functionality.
140139

src/lkdata/mixins.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
"sum",
3838
"std",
3939
"var",
40+
"argmin",
41+
"argmax",
4042
"min",
4143
"max",
4244
"prod",
@@ -777,7 +779,7 @@ def uncertainty(self, value):
777779
self._uncertainty = uncertainty
778780

779781

780-
class StatsMixin:
782+
class StatsMixin(MathMixin):
781783
"""Defines a mixin class which will let us postprocess all our pandas stats"""
782784

783785
_stats_type = "data"
@@ -795,6 +797,11 @@ def _create_stats_method(self, method_name):
795797
def _method(*args, **kwargs):
796798
axis = kwargs.pop("axis", None)
797799
np_method = getattr(np, method_name)
800+
if np_method in (np.argmin, np.argmax):
801+
if axis is None:
802+
# returning unravelled index instead of flattened index
803+
return np.unravel_index(np_method(self.array), self.array.shape)
804+
return np_method(self.to_numpy(), axis=axis)
798805
result, init_kwds = self._arithmetic(
799806
np_method, operand=None, data_axis=axis, uncertainty_axis=axis, **kwargs
800807
)

src/lkdata/seriescollection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .mixins import (
1414
AggMixin,
1515
ConvenienceMixins,
16-
MathMixin,
1716
StatsMixin,
1817
BoolMixin,
1918
BitwiseMixin,
@@ -217,7 +216,7 @@ def stats_post_process(self, result, **kwargs):
217216
return result
218217

219218

220-
class DataSeriesCollection(MathMixin, StatsMixin, SeriesCollection):
219+
class DataSeriesCollection(StatsMixin, SeriesCollection):
221220
_series_class = DataSeries
222221

223222
def __init__(

tests/test_datacube.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,28 @@ def test_setup():
3838

3939
# Test overridden pandas methods return correct shapes
4040
# Data products return tuples for and uncertainty
41-
for method_name in STATS_METHOD_NAMES:
41+
methods = [
42+
method for method in STATS_METHOD_NAMES if method not in ["argmin", "argmax"]
43+
]
44+
for method_name in methods:
4245
assert getattr(df, method_name)(axis=0)[0].shape == (nrow, ncol)
4346
assert (
4447
getattr(df[:, aperture], method_name)(axis=0)[0].shape[0] == aperture.sum()
4548
)
4649

50+
# argmin and argmax should return a single 3D index for axis=None
51+
assert len(df.argmin()) == 3
52+
assert len(df.argmax()) == 3
53+
54+
# argmin and argmax operate on axis=0 or 1 only
55+
assert df.argmin(axis=0).shape[0] == nrow * ncol
56+
assert df.argmax(axis=0).shape[0] == nrow * ncol
57+
assert df.argmin(axis=1).shape[0] == ntime
58+
assert df.argmax(axis=1).shape[0] == ntime
59+
60+
with pytest.raises(np.AxisError, match="axis 2 is out of bounds"):
61+
_ = df.argmin(axis=2)
62+
4763

4864
def test_bad_setup():
4965
# Actual data values should be irrelevant for these tests

0 commit comments

Comments
 (0)