Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions src/access_nri_intake/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,9 @@ def _normalise_value(self, field: str, value: str | Collection[str] | Any) -> An

# list/tuple/set of strings
if isinstance(value, (list | tuple | set)):
out = set()
for v in value:
normalized = aliases_for_field.get(v, v)
out.add(v)
[out.add(n) for n in normalized] # type: ignore[func-returns-value]
out = _normalize_list_vals(
value, aliases_for_field, field, self.show_warnings
)

# If any aliasing occurred, issue a warning showing the original and aliased values
if len(out) > len(value) and self.show_warnings:
Expand Down Expand Up @@ -286,17 +284,9 @@ def _normalise_value(self, field: str, value: str | Collection[str] | Any) -> An

# list/tuple/set of strings
if isinstance(value, (list | tuple | set)):
out = set()
for v in value:
normalized = aliases_for_field.get(v, v)
if normalized != v and self.show_warnings:
warnings.warn(
message=f"Value aliasing: {field}='{v}' → {field}=['{','.join(n for n in normalized)}','{v}']",
category=UserWarning,
stacklevel=4,
)
out.add(v)
[out.add(n) for n in normalized] # type: ignore[func-returns-value]
out = _normalize_list_vals(
value, aliases_for_field, field, self.show_warnings
)
return type(value)(out)

# anything else (regex, callable, etc.) – leave untouched
Expand Down Expand Up @@ -360,6 +350,24 @@ def __dir__(self) -> list[str]:
) # pragma: no cover


def _normalize_list_vals(
    value, aliases_for_field: dict[str, Any], field: str, show_warnings: bool
) -> set[str]:
    """Expand every entry of ``value`` with its aliases for ``field``.

    Each original entry is always kept; when ``aliases_for_field`` has an
    entry for it, every alias listed there is added alongside it.

    Parameters
    ----------
    value
        Iterable of (hashable) search values supplied by the caller.
    aliases_for_field
        Mapping from a value to an iterable of its aliases.
    field
        Name of the field being searched; only used in the warning text.
    show_warnings
        When true, emit a ``UserWarning`` for each value that was aliased.

    Returns
    -------
    set[str]
        The original values together with all of their aliases.
    """
    out: set[str] = set()
    for v in value:
        # The caller's own value is always part of the result.
        out.add(v)

        # Values without an alias entry pass through untouched. Checking
        # membership (rather than comparing a ``[v]`` fallback against ``v``,
        # which is never equal) avoids warning about aliasing that did not
        # actually happen.
        if v not in aliases_for_field:
            continue

        normalized = aliases_for_field[v]
        if normalized != v and show_warnings:
            joined = ",".join(n for n in normalized)
            # NOTE(review): stacklevel=4 is passed from inside this helper,
            # which adds one frame compared with warning inline in the
            # caller - confirm the warning is still attributed to user code.
            warnings.warn(
                message=f"Value aliasing: {field}='{v}' → {field}=['{joined}','{v}']",
                category=UserWarning,
                stacklevel=4,
            )
        out.update(normalized)
    return out


# Load CMIP to ACCESS variable mappings
# NOTE(review): `tuple[str]` annotates a tuple of exactly one string; if a
# CMIP name can map to several ACCESS names this should presumably be
# `tuple[str, ...]` - confirm against `_load_cmip_mappings`.
_CMIP_TO_ACCESS_MAPPINGS: dict[str, tuple[str]] = _load_cmip_mappings()

Expand Down
45 changes: 44 additions & 1 deletion tests/test_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_list_values(self, sample_datastore_path, show_warnings):
# list nonsense
with pytest.warns(UserWarning) as warning_record:
wrapped_cat.search(variable=["ci", "cl", "tas"])
assert len(warning_record) == 1
assert len(warning_record) == 4
else:
with warnings.catch_warnings():
warnings.simplefilter("error")
Expand Down Expand Up @@ -347,6 +347,29 @@ def test_unwrap(self, sample_datastore_path):
unwrapped = wrapped_cat.unwrap()
assert unwrapped is sample_datastore_path

def test_no_fallback_string_iteration(self, sample_datastore_path):
    """String values must not be iterated character by character when the
    alias lookup for the field turns up nothing."""
    datastore = esm_datastore(
        sample_datastore_path, columns_with_iterables=["variable"]
    )
    aliased = AliasedESMCatalog(
        datastore,
        field_aliases=ESM_FIELD_ALIASES,
        value_aliases=VALUE_ALIASES,
    )

    query = {"variable": ["field_without_aliases"]}

    # Must neither raise nor iterate over the characters of the string
    normalised = aliased._normalise_kwargs(query)

    # The original value is passed through unchanged, still inside a list
    assert normalised == query


@pytest.mark.filterwarnings("ignore:Value aliasing")
class TestAliasedDataframeCatalog:
Expand Down Expand Up @@ -531,6 +554,26 @@ def test_unwrap(self, tmp_dataframe_catalog):

assert unwrapped is tmp_dataframe_catalog

def test_no_fallback_string_iteration(self, tmp_dataframe_catalog):
    """String values must not be iterated character by character when the
    alias lookup for the field turns up nothing."""
    aliased = AliasedDataframeCatalog(
        tmp_dataframe_catalog,
        field_aliases=DATAFRAME_FIELD_ALIASES,
        value_aliases=VALUE_ALIASES,
    )

    query = {"variable": ["field_without_aliases"]}

    # Must neither raise nor iterate over the characters of the string
    normalised = aliased._normalise_kwargs(query)

    # The original value is passed through unchanged, still inside a list
    assert normalised == query


@pytest.mark.parametrize(
"mock_target,side_effect",
Expand Down
Loading