Skip to content

Commit 56b5053

Browse files
committed
feat: Implements on-demand coordinate system conversion and geometry retrieval
Moves coordinate system transformation and deobliquing from the MedVol constructor to dedicated accessor methods. This allows users to retrieve data and metadata in arbitrary orientations (e.g., LPS+, ASR+) without modifying the internal RAS+ canonical state. Key changes: - Adds `get_geometry` to return spatial metadata converted to a requested coordinate system. - Adds `get_array` to return a reoriented, zero-copy view of the image data. - Implements robust anatomical coordinate system parsing for R/L, A/P, and S/I axes. - Removes `remove_obliqueness` from the constructor in favor of the `deoblique` flag in `get_geometry`.
1 parent f645528 commit 56b5053

3 files changed

Lines changed: 341 additions & 22 deletions

File tree

medvol/core.py

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
SNAP_ATOL,
1010
SIMPLEITK_AXIS_LABELS,
1111
CoordinateContext,
12+
_is_signed_permutation,
1213
affine_to_rotation,
1314
affine_to_shear,
1415
canonical_coordinate_context,
@@ -18,6 +19,8 @@
1819
deoblique_affine,
1920
decompose_affine,
2021
normalize_backend_name,
22+
parse_coordinate_system,
23+
snap_values,
2124
validate_affine,
2225
)
2326
from medvol.registry import get_backend, resolve_backend
@@ -61,16 +64,11 @@ def __init__(
6164
coordinate_system: str | None = None,
6265
backend: str | None = None,
6366
canonicalize: bool = True,
64-
remove_obliqueness: bool = False,
6567
) -> None:
66-
if remove_obliqueness and not canonicalize:
67-
raise ValueError("remove_obliqueness requires canonicalize=True.")
68-
6968
self._coordinate_context = None
7069
self._header = None
7170
self._backend = normalize_backend_name(backend)
7271
self._canonicalize = canonicalize
73-
self._remove_obliqueness = remove_obliqueness
7472

7573
if isinstance(source, (str, Path)):
7674
if coordinate_system is not None:
@@ -119,8 +117,6 @@ def _apply_orientation_policy(self) -> None:
119117
self._coordinate_context,
120118
atol=SNAP_ATOL,
121119
)
122-
if self._remove_obliqueness:
123-
self._affine = deoblique_affine(self._affine, atol=SNAP_ATOL)
124120

125121
@staticmethod
126122
def _validate_array(array: np.ndarray) -> np.ndarray:
@@ -228,6 +224,135 @@ def header(self, value: Any) -> None:
228224
def backend(self) -> str | None:
229225
return self._backend
230226

227+
def get_geometry(
228+
self,
229+
coordinate_system: str = "RAS+",
230+
*,
231+
deoblique: bool = False,
232+
) -> dict:
233+
"""Return geometry converted to the requested coordinate system.
234+
235+
The internal state is not modified — only the returned values are
236+
converted. Requires ``canonicalize=True`` (raises ``ValueError``
237+
otherwise).
238+
239+
Args:
240+
coordinate_system: Target coordinate system, e.g. ``"RAS+"``,
241+
``"LPS+"``, ``"ASR+"``. Must contain exactly
242+
``spatial_ndim`` anatomical letters (R/L, A/P, S/I), each
243+
from a different anatomical axis, with an optional trailing
244+
``"+"``.
245+
deoblique: If ``True``, strip off-diagonal entries from the
246+
returned affine (diagonal affine, keeps origin). Equivalent
247+
to calling ``get_geometry(deoblique=False)`` and then
248+
removing the oblique component.
249+
250+
Returns:
251+
Dict with keys:
252+
253+
* ``"affine"`` — (ndim+1)×(ndim+1) affine in the target system.
254+
* ``"spacing"`` — always-positive voxel spacing (column norms).
255+
* ``"origin"`` — world coordinates of voxel (0, 0, …, 0).
256+
* ``"direction"`` — unit-column direction cosine matrix.
257+
* ``"coordinate_system"`` — the *coordinate_system* argument.
258+
* ``"oblique"`` — ``True`` when the spatial direction block is
259+
not a signed permutation matrix (i.e. the image is oblique).
260+
261+
Raises:
262+
ValueError: If ``canonicalize=False`` or the coordinate context
263+
is unknown.
264+
"""
265+
if not self._canonicalize:
266+
raise ValueError("get_geometry requires canonicalize=True.")
267+
if self._coordinate_context is None:
268+
raise ValueError(
269+
"get_geometry requires a known coordinate context. "
270+
"Load from a file with a recognised coordinate system."
271+
)
272+
273+
spatial_ndim = self._coordinate_context.anatomical_ndim
274+
axis_order, flips = parse_coordinate_system(coordinate_system, spatial_ndim)
275+
signs = [-1 if f else 1 for f in flips]
276+
277+
A = self._affine
278+
shape = self._array.shape
279+
280+
# ── Step 1: permute and sign spatial columns (data-axis transform) ──
281+
# Read from original A; write to A_mid so we never clobber a source col.
282+
A_mid = A.copy()
283+
for m in range(spatial_ndim):
284+
A_mid[:, m] = signs[m] * A[:, axis_order[m]]
285+
# Adjust translation column for flipped axes:
286+
# flip on axis m maps voxel i → (N-1-i), shifting the origin to the far corner.
287+
for m in range(spatial_ndim):
288+
if flips[m]:
289+
A_mid[:, -1] += A[:, axis_order[m]] * (shape[axis_order[m]] - 1)
290+
291+
# ── Step 2: permute and sign spatial rows (world-basis transform) ──
292+
A_final = A_mid.copy()
293+
for m in range(spatial_ndim):
294+
A_final[m, :] = signs[m] * A_mid[axis_order[m], :]
295+
296+
A_final = snap_values(A_final)
297+
298+
if deoblique:
299+
A_final = deoblique_affine(A_final)
300+
301+
spacing, origin, direction = decompose_affine(A_final)
302+
303+
# Oblique iff the spatial direction block is not a signed permutation.
304+
spatial_dir = direction[:spatial_ndim, :spatial_ndim]
305+
is_oblique = not _is_signed_permutation(spatial_dir)
306+
307+
return {
308+
"affine": A_final,
309+
"spacing": spacing,
310+
"origin": origin,
311+
"direction": direction,
312+
"coordinate_system": coordinate_system,
313+
"oblique": is_oblique,
314+
}
315+
316+
def get_array(self, coordinate_system: str = "RAS+") -> np.ndarray:
317+
"""Return the array converted to the requested coordinate system.
318+
319+
Returns a zero-copy NumPy view — no interpolation is performed.
320+
Only axis permutations and flips are applied. Requires
321+
``canonicalize=True`` (raises ``ValueError`` otherwise).
322+
323+
Args:
324+
coordinate_system: Target coordinate system string (see
325+
``get_geometry`` for the accepted format).
326+
327+
Returns:
328+
NumPy array in the requested axis order and orientation.
329+
Non-spatial axes (e.g. time for 4-D images) are appended
330+
unchanged at the end.
331+
332+
Raises:
333+
ValueError: If ``canonicalize=False`` or the coordinate context
334+
is unknown.
335+
"""
336+
if not self._canonicalize:
337+
raise ValueError("get_array requires canonicalize=True.")
338+
if self._coordinate_context is None:
339+
raise ValueError(
340+
"get_array requires a known coordinate context."
341+
)
342+
343+
spatial_ndim = self._coordinate_context.anatomical_ndim
344+
axis_order, flips = parse_coordinate_system(coordinate_system, spatial_ndim)
345+
346+
# Non-spatial axes (e.g. time) stay at the end, in their original order.
347+
full_axis_order = list(axis_order) + list(range(spatial_ndim, self.ndims))
348+
result = np.transpose(self._array, full_axis_order)
349+
350+
for m, flip in enumerate(flips):
351+
if flip:
352+
result = np.flip(result, axis=m)
353+
354+
return result
355+
231356
def save(self, filepath: str | Path, *, backend: str | None = None) -> None:
232357
resolved_backend = resolve_backend(filepath, backend)
233358
get_backend(resolved_backend).save(Path(filepath), self)

medvol/geometry.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,85 @@ def is_homogeneous_row(row: Iterable[float], *, atol: float = SNAP_ATOL) -> bool
374374
expected = np.zeros_like(row_array)
375375
expected[-1] = 1.0
376376
return np.allclose(row_array, expected, atol=atol)
377+
378+
379+
# Maps each anatomical letter to (ras_axis_index, flip).
380+
# ras_axis_index: which of the 3 RAS axes (R=0, A=1, S=2) this letter belongs to.
381+
# flip: True when the letter denotes the negative direction of that axis (L, P, I).
382+
_AXIS_LETTER_MAP: dict[str, tuple[int, bool]] = {
383+
"R": (0, False),
384+
"L": (0, True),
385+
"A": (1, False),
386+
"P": (1, True),
387+
"S": (2, False),
388+
"I": (2, True),
389+
}
390+
391+
392+
def parse_coordinate_system(
393+
cs_str: str,
394+
spatial_ndim: int = 3,
395+
) -> tuple[list[int], list[bool]]:
396+
"""Parse a coordinate system string into (axis_order, flips) relative to RAS+.
397+
398+
Args:
399+
cs_str: Coordinate system string, e.g. "LPS+", "ASR+", "RAS".
400+
Must contain exactly ``spatial_ndim`` anatomical letters
401+
(R/L, A/P, S/I), each from a different anatomical axis, with an
402+
optional trailing "+" or "−".
403+
spatial_ndim: Number of expected spatial axes (2 or 3).
404+
405+
Returns:
406+
axis_order: ``axis_order[m]`` is the index (0=R/L, 1=A/P, 2=S/I) of the
407+
RAS+ axis that corresponds to output axis ``m``.
408+
flips: ``flips[m]`` is True when output axis ``m`` runs in the negative
409+
direction of the corresponding RAS+ axis (e.g. L, P, I).
410+
411+
Raises:
412+
ValueError: For unknown letters, wrong length, or duplicate axes.
413+
"""
414+
cs = cs_str.rstrip("+-").upper()
415+
if len(cs) != spatial_ndim:
416+
raise ValueError(
417+
f"Coordinate system {cs_str!r} must have {spatial_ndim} anatomical "
418+
f"letters (got {len(cs)})."
419+
)
420+
421+
axis_order: list[int] = []
422+
flips: list[bool] = []
423+
used_axes: set[int] = set()
424+
425+
for letter in cs:
426+
if letter not in _AXIS_LETTER_MAP:
427+
raise ValueError(
428+
f"Unknown axis letter {letter!r} in coordinate system {cs_str!r}. "
429+
"Valid letters: R, L, A, P, S, I."
430+
)
431+
ras_axis, flip = _AXIS_LETTER_MAP[letter]
432+
if ras_axis in used_axes:
433+
raise ValueError(
434+
f"Duplicate anatomical axis in coordinate system {cs_str!r}."
435+
)
436+
used_axes.add(ras_axis)
437+
axis_order.append(ras_axis)
438+
flips.append(flip)
439+
440+
return axis_order, flips
441+
442+
443+
def _is_signed_permutation(matrix: np.ndarray, *, atol: float = SNAP_ATOL) -> bool:
444+
"""Return True if *matrix* is approximately a signed permutation matrix.
445+
446+
A signed permutation matrix has exactly one entry ±1 per row and per column
447+
and zeros elsewhere (up to *atol*).
448+
"""
449+
m = np.asarray(matrix, dtype=float)
450+
abs_m = np.abs(m)
451+
# Every entry must be close to 0 or 1.
452+
if not np.all((abs_m < atol) | (np.abs(abs_m - 1.0) < atol)):
453+
return False
454+
# Each row and each column must have exactly one nonzero entry.
455+
return bool(
456+
np.all(np.sum(abs_m > atol, axis=0) == 1)
457+
and np.all(np.sum(abs_m > atol, axis=1) == 1)
458+
)

0 commit comments

Comments
 (0)