Skip to content

Commit c3c56d8

Browse files
authored
Merge pull request #591 from astrofrog/cleanup-api
2 parents 32a17fb + 7d5a0d9 commit c3c56d8

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

60 files changed

+4819
-4530
lines changed

reproject/_array_utils.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
import numpy as np
2+
from dask_image.ndinterp import map_coordinates as dask_image_map_coordinates
3+
from dask_image.ndinterp import spline_filter
4+
from scipy.ndimage import spline_filter as scipy_spline_filter
5+
6+
__all__ = ["map_coordinates", "dask_map_coordinates", "sample_array_edges", "ArrayWrapper"]
7+
8+
9+
def find_chunk_shape(shape, max_chunk_size=None):
    """
    Given the shape of an n-dimensional array, and the maximum number of
    elements in a chunk, return the largest chunk shape to use for iteration.

    This currently assumes the optimal chunk shape to return is for C-contiguous
    arrays.

    Parameters
    ----------
    shape : iterable
        The shape of the n-dimensional array.
    max_chunk_size : int, optional
        The maximum number of elements per chunk.
    """

    # No size limit means a single chunk covering the whole array.
    if max_chunk_size is None:
        return tuple(shape)

    # Work from the fastest-varying (last) axis outwards, keeping track of
    # how many elements we can still afford per chunk.
    budget = max_chunk_size
    chunks_reversed = []

    for axis_len in reversed(shape):
        if budget > axis_len:
            # We can afford the full axis; divide the budget accordingly.
            chunks_reversed.append(axis_len)
            budget //= axis_len
        else:
            # The remaining budget is exhausted on this axis; every slower
            # axis gets a chunk length of 1.
            chunks_reversed.append(budget)
            budget = 1

    chunks_reversed.reverse()
    return tuple(chunks_reversed)
41+
42+
43+
def iterate_chunks(shape, *, max_chunk_size):
    """
    Given a data shape and a maximum chunk size, iteratively yield tuples of
    slice objects that can be used to slice the array chunk by chunk.

    Parameters
    ----------
    shape : iterable
        The shape of the n-dimensional array.
    max_chunk_size : int
        The maximum number of elements per chunk.
    """

    from itertools import product

    # An empty array has no chunks to yield.
    if np.prod(shape) == 0:
        return

    chunk_shape = find_chunk_shape(shape, max_chunk_size)

    ndim = len(chunk_shape)

    # The previous implementation advanced a manual counter (with the first
    # dimension varying fastest) guarded by a misleading lexicographic list
    # comparison; the loop actually terminated via an explicit break. We
    # instead iterate the chunk start positions directly. itertools.product
    # varies its *last* argument fastest, so we pass the per-dimension ranges
    # reversed and flip each result back to preserve the original yield order.
    starts_per_dim = [range(0, shape[i], chunk_shape[i]) for i in range(ndim)]

    for reversed_starts in product(*starts_per_dim[::-1]):
        starts = reversed_starts[::-1]
        yield tuple(
            slice(starts[i], min(starts[i] + chunk_shape[i], shape[i])) for i in range(ndim)
        )
87+
88+
89+
def at_least_float32(array):
    """
    Return ``array`` unchanged if it is already a float of at least 32 bits,
    otherwise return a float32 copy of it.
    """
    needs_cast = array.dtype.kind != "f" or array.dtype.itemsize < 4
    return array.astype(np.float32) if needs_cast else array
94+
95+
96+
def memory_efficient_access(array, chunk):
    """
    Access ``array[chunk]``, re-opening memory-mapped arrays to limit caching.

    If we access a number of chunks from a memory-mapped array, memory usage
    will increase and could crash e.g. dask.distributed workers, so for
    C-contiguous memmaps we read the chunk through a short-lived memmap
    instead of the long-lived one.
    """
    if not (isinstance(array, np.memmap) and array.flags.c_contiguous):
        return array[chunk]

    # Open a fresh, read-only view onto the same file so the pages cached by
    # this access are not pinned by the original memmap object.
    fresh_view = np.memmap(
        array.filename,
        mode="r",
        dtype=array.dtype,
        shape=array.shape,
        offset=array.offset,
    )
    return fresh_view[chunk]
111+
112+
113+
def _clip_coords(image, coords):
114+
115+
shape = image.shape
116+
117+
coords = coords.copy()
118+
for i in range(coords.shape[0]):
119+
coords[i][(coords[i] < 0) & (coords[i] >= -0.5)] = 0
120+
coords[i][(coords[i] < shape[i] - 0.5) & (coords[i] >= shape[i] - 1)] = shape[i] - 1
121+
122+
return coords
123+
124+
125+
def dask_map_coordinates(image, coords, output=None, **kwargs):
    """
    Thin wrapper around dask-image's map_coordinates which ensures that we can
    interpolate right to the edge of the image, and also implements the
    ``output`` keyword argument.

    Parameters
    ----------
    image : dask array
        The array to interpolate.
    coords : ndarray
        Pixel coordinates with shape ``(image.ndim, n_points)``.
    output : ndarray, optional
        Array in which to store the interpolated values.
    **kwargs
        Passed through to dask-image's ``map_coordinates`` (e.g. ``order``,
        ``cval``).
    """

    cval = kwargs.get("cval", 0.0)

    # scipy and dask-image default to cubic spline interpolation when order is
    # not given, so fall back to 3 here rather than crashing with a KeyError
    # when the caller does not pass order explicitly.
    order = kwargs.get("order", 3)

    original_shape = image.shape

    # We copy the coordinates, snapping positions in the outer half of the
    # border pixels to the border pixel centers so we can interpolate right
    # to the edge of the image.
    coords = _clip_coords(image, coords)

    if output is None:
        output = np.ones(coords.shape[1]) * cval
    else:
        output[:] = cval

    # At the time of writing, dask-image is not able to correctly handle
    # prefiltering, instead doing it per-chunk which can give subtly different
    # results, so for spline orders we prefilter the whole image ourselves.
    if order >= 2:
        try:
            image = spline_filter(image, order=order, mode="constant")
        except ValueError as exc:
            # If arrays are too small, spline_filter can fail, so we catch this
            # case and call the scipy version if so
            if "The overlapping depth" in str(exc):
                image = scipy_spline_filter(image, order=order, mode="constant")
            else:
                raise exc

    # dask-image's map_coordinates will crash if NaN values are passed in
    # coords, so we filter these out (this is a good idea anyway for performance)
    keep = ~np.any(np.isnan(coords), axis=0)

    # At the time of writing, dask-image's map_coordinates prefilter is False
    # by default, we hard-code this here to guard against any changes in
    # default
    output[keep] = dask_image_map_coordinates(
        image, coords[:, keep], prefilter=False, **kwargs
    ).compute()

    # Any coordinate outside the outer half of the border pixels maps to cval.
    reset = np.zeros(coords.shape[1], dtype=bool)

    for i in range(coords.shape[0]):
        reset |= coords[i] < -0.5
        reset |= coords[i] > original_shape[i] - 0.5

    output[reset] = cval

    return output
177+
178+
179+
def map_coordinates(
    image, coords, max_chunk_size=None, output=None, optimize_memory=False, **kwargs
):
    """
    Chunk-aware wrapper around `scipy.ndimage.map_coordinates`.

    Parameters
    ----------
    image : ndarray
        The array to interpolate (possibly a memory-mapped array).
    coords : ndarray
        Pixel coordinates with shape ``(image.ndim, n_points)``.
    max_chunk_size : int, optional
        The maximum number of elements per chunk when iterating over the
        image in chunks (the non-native-dtype / spline-order path).
    output : ndarray, optional
        Array in which to store the interpolated values.
    optimize_memory : bool, optional
        If `True`, access image chunks through `memory_efficient_access`,
        which re-opens memory-mapped inputs per chunk.
    **kwargs
        Passed through to `scipy.ndimage.map_coordinates` (e.g. ``order``,
        ``cval``).
    """

    # In the built-in scipy map_coordinates, the values are defined at the
    # center of the pixels. This means that map_coordinates does not
    # correctly treat pixels that are in the outer half of the outer pixels.
    # We solve this by resetting any coordinates that are in the outer half of
    # the border pixels to be at the center of the border pixels. We used to
    # instead pad the array but this was not memory efficient as it ended up
    # producing a copy of the output array.

    # In addition, map_coordinates is not efficient when given big-endian Numpy
    # arrays as it will then make a copy, which is an issue when dealing with
    # memory-mapped FITS files that might be larger than memory. Therefore, for
    # big-endian arrays, we operate in chunks with a size smaller or equal to
    # max_chunk_size.

    # The optimize_memory option isn't used right not by the rest of reproject
    # but it is a mode where if we are in a memory-constrained environment, we
    # re-create memmaps for individual chunks to avoid caching the whole array.
    # We need to decide how to expose this to users.

    # TODO: check how this should behave on a big-endian system.

    from scipy.ndimage import map_coordinates as scipy_map_coordinates

    original_shape = image.shape

    # We copy the coordinates array as we then modify it in-place below to clip
    # to the edges of the array.

    coords = _clip_coords(image, coords)

    # If the data type is native and we are not doing spline interpolation,
    # then scipy_map_coordinates deals properly with memory maps, so we can use
    # it without chunking. Otherwise, we need to iterate over data chunks.
    if image.dtype.isnative and "order" in kwargs and kwargs["order"] <= 1:
        values = scipy_map_coordinates(at_least_float32(image), coords, output=output, **kwargs)
    else:
        if output is None:
            # Start from NaN so points not covered by any chunk stay NaN.
            output = np.repeat(np.nan, coords.shape[1])

        values = output

        # Mask of the coordinates that fall inside the current chunk.
        include = np.ones(coords.shape[1], dtype=bool)

        # Padding around each chunk so interpolation near chunk edges sees
        # enough neighboring pixels (larger for spline orders > 1).
        if "order" in kwargs and kwargs["order"] <= 1:
            padding = 1
        else:
            padding = 10

        for chunk in iterate_chunks(image.shape, max_chunk_size=max_chunk_size):

            # Select only the coordinates that fall inside this (unpadded) chunk.
            include[...] = True
            for idim, slc in enumerate(chunk):
                include[(coords[idim] < slc.start) | (coords[idim] >= slc.stop)] = False

            if not np.any(include):
                continue

            chunk = list(chunk)

            # Adjust chunks to add padding
            for idim, slc in enumerate(chunk):
                start = max(0, slc.start - padding)
                stop = min(original_shape[idim], slc.stop + padding)
                chunk[idim] = slice(start, stop)

            chunk = tuple(chunk)

            # Shift the selected coordinates into the padded chunk's frame.
            coords_subset = coords[:, include].copy()
            for idim, slc in enumerate(chunk):
                coords_subset[idim, :] -= slc.start

            if optimize_memory:
                image_subset = memory_efficient_access(image, chunk)
            else:
                image_subset = image[chunk]

            output[include] = scipy_map_coordinates(
                at_least_float32(image_subset), coords_subset, **kwargs
            )

    # Any coordinate outside the outer half of the border pixels maps to cval.
    reset = np.zeros(coords.shape[1], dtype=bool)

    for i in range(coords.shape[0]):
        reset |= coords[i] < -0.5
        reset |= coords[i] > original_shape[i] - 0.5

    values[reset] = kwargs.get("cval", 0.0)

    return values
271+
272+
273+
def sample_array_edges(shape, *, n_samples):
    """
    Sample positions along every edge of an N-dimensional array of the given
    shape, using ``n_samples`` points per edge (including the vertices).

    For each dimension we take the points along that dimension for every
    combination of the other vertices, then deduplicate. Returns an array of
    shape ``(ndim, n_points)`` where ``n_points`` is the number of unique
    sampled positions.
    """
    shape = np.array(shape)
    ndim = len(shape)
    axis_bits = 2 ** np.arange(ndim)

    edge_samples = []
    for axis in range(ndim):
        for corner in range(2**ndim):
            # Pick one corner of the array: -0.5 or shape - 0.5 along each axis,
            # encoded by the bits of ``corner``.
            corner_position = -0.5 + shape * ((corner & axis_bits) > 0).astype(int)
            samples = np.broadcast_to(corner_position, (n_samples, ndim)).copy()
            # Sweep along the current axis from one face to the other.
            samples[:, axis] = np.linspace(-0.5, shape[axis] - 0.5, n_samples)
            edge_samples.append(samples)

    return np.unique(np.vstack(edge_samples), axis=0).T
290+
291+
292+
class ArrayWrapper:
    """
    Minimal wrapper around an array-like object, exposing its ``ndim``,
    ``shape`` and ``dtype`` attributes and forwarding item access to it.
    """

    def __init__(self, array):
        self._array = array
        # Mirror the metadata attributes consumers read directly.
        for attr in ("ndim", "shape", "dtype"):
            setattr(self, attr, getattr(array, attr))

    def __getitem__(self, item):
        return self._array[item]

0 commit comments

Comments
 (0)