From 35eed50d010b3d5d3d7067fd8020ba40585a1c88 Mon Sep 17 00:00:00 2001 From: erogluorhan Date: Tue, 10 Feb 2026 21:58:24 -0700 Subject: [PATCH 1/4] Add remap/spatial_coords_remap.py and update remap functions accordingly --- uxarray/remap/bilinear.py | 18 +- uxarray/remap/inverse_distance_weighted.py | 18 +- uxarray/remap/nearest_neighbor.py | 16 +- uxarray/remap/spatial_coords_remap.py | 385 +++++++++++++++++++++ uxarray/remap/utils.py | 23 +- 5 files changed, 424 insertions(+), 36 deletions(-) create mode 100644 uxarray/remap/spatial_coords_remap.py diff --git a/uxarray/remap/bilinear.py b/uxarray/remap/bilinear.py index 1ebfd0fb2..52261d44b 100644 --- a/uxarray/remap/bilinear.py +++ b/uxarray/remap/bilinear.py @@ -29,7 +29,7 @@ def _bilinear( source: UxDataArray | UxDataset, destination_grid: Grid, - destination_dim: str = "n_face", + remap_to: str = "faces", ) -> np.ndarray: """Bilinear Remapping between two grids, mapping data that resides on the corner nodes, edge centers, or face centers on the source grid to the @@ -39,8 +39,8 @@ def _bilinear( --------- source_uxda : UxDataArray Source UxDataArray - remap_to : str, default="nodes" - Location of where to map data, either "nodes", "edge centers", or "face centers" + remap_to : str, default="faces" + Which grid element receives the remapped values, either "nodes", "edges", or "faces" Returns ------- @@ -49,7 +49,7 @@ def _bilinear( """ # ensure array is a np.ndarray - _assert_dimension(destination_dim) + _assert_dimension(remap_to) # Ensure the destination grid is normalized destination_grid.normalize_cartesian_coordinates() @@ -70,12 +70,12 @@ def _bilinear( dual = source.uxgrid.get_dual() # get destination coordinate pairs - point_xyz = _prepare_points(destination_grid, destination_dim) + point_xyz = _prepare_points(destination_grid, remap_to) weights, indices = _barycentric_weights( point_xyz=point_xyz, dual=dual, - data_size=getattr(destination_grid, f"n_{KDTREE_DIM_MAP[destination_dim]}"), + data_size=getattr(destination_grid, f"n_{KDTREE_DIM_MAP[remap_to]}"), source_grid=ds.uxgrid, ) @@ -87,8 +87,8 @@ def _bilinear( inds, w = indices, weights # pack indices & weights into tiny DataArrays: - indexer = xr.DataArray(inds, dims=[LABEL_TO_COORD[destination_dim], "k"]) - weight_da = xr.DataArray(w, dims=[LABEL_TO_COORD[destination_dim], "k"]) + indexer = xr.DataArray(inds, dims=[LABEL_TO_COORD[remap_to], "k"]) + weight_da = xr.DataArray(w, dims=[LABEL_TO_COORD[remap_to], "k"]) # gather the k neighbor values: da_k = da.isel({source_dim: indexer}, ignore_grid=True) @@ -103,7 +103,7 @@ def _bilinear( remapped_vars[name] = da ds_remapped = _construct_remapped_ds( - source, remapped_vars, destination_grid, destination_dim + source, remapped_vars, destination_grid, remap_to ) return ds_remapped[name] if is_da else ds_remapped diff --git a/uxarray/remap/inverse_distance_weighted.py b/uxarray/remap/inverse_distance_weighted.py index 487cfe09f..e1dd6e14f 100644 --- a/uxarray/remap/inverse_distance_weighted.py +++ b/uxarray/remap/inverse_distance_weighted.py @@ -52,7 +52,7 @@ def _idw_weights(distances, power): def _inverse_distance_weighted_remap( source: UxDataArray | UxDataset, destination_grid: Grid, - destination_dim: str = "n_face", + remap_to: str = "faces", power: int = 2, k: int = 8, ): @@ -68,8 +68,8 @@ def _inverse_distance_weighted_remap( The data to be remapped. destination_grid : Grid The UXarray grid instance on which to interpolate data. - destination_dim : {'n_node', 'n_edge', 'n_face'}, default='n_face' - The spatial dimension on `destination_grid` to receive interpolated values. + remap_to : {'nodes', 'edges', 'faces'}, default='faces' + Which grid element receives the remapped values, either "nodes", "edges", or "faces" power : int, default=2 Exponent in the inverse-distance weighting function. Larger values emphasize closer neighbors. @@ -88,9 +88,9 @@ def _inverse_distance_weighted_remap( """ # Fall back onto nearest neighbor if k == 1: - return _nearest_neighbor_remap(source, destination_grid, destination_dim) + return _nearest_neighbor_remap(source, destination_grid, remap_to) - _assert_dimension(destination_dim) + _assert_dimension(remap_to) # Perform remapping on a UxDataset ds, is_da, name = _to_dataset(source) @@ -106,7 +106,7 @@ def _inverse_distance_weighted_remap( ds.uxgrid, destination_grid, src_dim, - destination_dim, + remap_to, k=k, return_distances=True, ) @@ -123,8 +123,8 @@ def _inverse_distance_weighted_remap( inds, w = indices_weights_map[source_dim] # pack indices & weights into tiny DataArrays: - indexer = xr.DataArray(inds, dims=[LABEL_TO_COORD[destination_dim], "k"]) - weight_da = xr.DataArray(w, dims=[LABEL_TO_COORD[destination_dim], "k"]) + indexer = xr.DataArray(inds, dims=[LABEL_TO_COORD[remap_to], "k"]) + weight_da = xr.DataArray(w, dims=[LABEL_TO_COORD[remap_to], "k"]) # gather the k neighbor values: da_k = da.isel({source_dim: indexer}, ignore_grid=True) @@ -139,7 +139,7 @@ def _inverse_distance_weighted_remap( remapped_vars[name] = da ds_remapped = _construct_remapped_ds( - source, remapped_vars, destination_grid, destination_dim + source, remapped_vars, destination_grid, remap_to ) return ds_remapped[name] if is_da else ds_remapped diff --git a/uxarray/remap/nearest_neighbor.py b/uxarray/remap/nearest_neighbor.py index 7e4fc197e..30fa16c75 100644 --- a/uxarray/remap/nearest_neighbor.py +++ b/uxarray/remap/nearest_neighbor.py @@ -75,7 +75,7 @@ def _nearest_neighbor_query( def _nearest_neighbor_remap( source: UxDataArray | UxDataset, destination_grid: Grid, - destination_dim: str = "n_face", + remap_to: str = "faces", ): """ Apply nearest-neighbor remapping from a UXarray object onto another grid. @@ -88,15 +88,15 @@ def _nearest_neighbor_remap( The data array or dataset to be remapped. destination_grid : Grid The UXarray Grid instance to which data will be remapped. - destination_dim : str, default='n_face' - The spatial dimension on the destination grid ('n_node', 'n_edge', 'n_face'). + remap_to : str, default='faces' + Which grid element receives the remapped values, either 'nodes', 'edges', 'faces'). Returns ------- UxDataArray or UxDataset A new UXarray object with data values assigned to `destination_grid`. """ - _assert_dimension(destination_dim) + _assert_dimension(remap_to) # Perform remapping on a UxDataset ds, is_da, name = _to_dataset(source) @@ -106,9 +106,7 @@ def _nearest_neighbor_remap( # Build Nearest Neighbor Index Arrays indices_map: dict[str, np.ndarray] = { - src_dim: _nearest_neighbor_query( - ds.uxgrid, destination_grid, src_dim, destination_dim - ) + src_dim: _nearest_neighbor_query(ds.uxgrid, destination_grid, src_dim, remap_to) for src_dim in dims_to_remap } remapped_vars = {} @@ -122,7 +120,7 @@ def _nearest_neighbor_remap( indexer = xr.DataArray( indices, dims=[ - LABEL_TO_COORD[destination_dim], + LABEL_TO_COORD[remap_to], ], ) @@ -134,7 +132,7 @@ def _nearest_neighbor_remap( remapped_vars[name] = da ds_remapped = _construct_remapped_ds( - source, remapped_vars, destination_grid, destination_dim + source, remapped_vars, destination_grid, remap_to ) return ds_remapped[name] if is_da else ds_remapped diff --git a/uxarray/remap/spatial_coords_remap.py b/uxarray/remap/spatial_coords_remap.py new file mode 100644 index 000000000..afec55246 --- /dev/null +++ b/uxarray/remap/spatial_coords_remap.py @@ -0,0 +1,385 @@ +import warnings +from typing import Dict, Literal, Optional, Tuple + +import xarray as xr + +from uxarray.core.dataarray import UxDataArray +from uxarray.grid.grid import Grid + + +class SpatialCoordsRemapper: + """Ensures remapping spatial coordinates between the source and destination grid for the remapping functions. + It may include remapping of values, renaming, and removal of some of the coordinates with respect to the + dimensions of source data & coordinates and the `remap_to` selection.""" + + # CF attributes that indicate coordinate type + CF_LAT_ATTRS = ["latitude", "projection_y_coordinate"] + CF_LON_ATTRS = ["longitude", "projection_x_coordinate"] + + # CF units that indicate coordinate type + CF_LAT_UNITS = ["degrees_north", "degree_north", "degree_n"] + CF_LON_UNITS = ["degrees_east", "degree_east", "degree_e"] + + def __init__( + self, + source: UxDataArray, + destination_grid: Grid, + remap_to: Literal["nodes", "faces", "edges"], + ): + """ + Initialize spatial coordinate remapper for UXarray's remapping functions. + + Parameters + ---------- + source : UxDataArray + Source data array that is being remapped to the `destination_grid`. + destination_grid : Grid + Destination grid that `source` is being remapped to. + remap_to : str + Which grid element receives the remapped values, either 'nodes', 'faces', or 'edges'. + """ + + if source is None: + raise ValueError( + "`source` must be provided for spatial coordinates remapping." + ) + + if destination_grid is None: + raise ValueError( + "`destination_grid` must be provided for spatial coordinates remapping." + ) + + self.destination_grid = destination_grid + self.source = source + self.remap_to = remap_to + + def _get_destination_grid_coords(self) -> Dict[str, xr.DataArray]: + """ + Get the spatial coordinates of the destination grid corresponding to `remap_to`. + + Returns + ------- + Dict[str, xr.DataArray] + Dictionary with 'lon' and 'lat' coordinate arrays + """ + if self.remap_to == "nodes": + return { + "lon": self.destination_grid.node_lon, + "lat": self.destination_grid.node_lat, + } + elif self.remap_to == "faces": + return { + "lon": self.destination_grid.face_lon, + "lat": self.destination_grid.face_lat, + } + elif self.remap_to == "edges": + return { + "lon": self.destination_grid.edge_lon, + "lat": self.destination_grid.edge_lat, + } + else: + raise ValueError( + f"Unknown `remap_to`: {self.remap_to}. Must be either 'nodes', 'faces', or 'edges'." + ) + + def _find_source_coords(self) -> Dict[str, Dict[str, Tuple[str, str]]]: + """ + Find spatial coordinate variables in `source` by checking their attributes, units, and axes. + + Returns + ------- + Dict[str, Dict[str, Tuple[str, str]]] + Nested dictionary structure: + - First level keys: dimension names (e.g., 'n_face', 'n_node', 'n_edge') + - Second level keys: spatial identifier ('lat' or 'lon') + - Values: (coord_var_name, standard_name) tuples + + Example output would look like: + { + 'n_face': { + 'lat': ('Mesh2_face_y', 'latitude'), + 'lon': ('Mesh2_face_x', 'longitude') + }, + 'n_node': { + 'lat': ('Mesh2_node_y', 'latitude'), + 'lon': ('Mesh2_node_x', 'longitude') + } + } + """ + + coords_by_dim = {} + + # Check all coordinates in `source` + for coord_name in self.source.coords: + coord = self.source.coords[coord_name] + + # Skip if in rare case this coordinate doesn't have dimensions or has multiple dimensions + if not hasattr(coord, "dims") or len(coord.dims) != 1: + continue + + dim_name = coord.dims[0] + + # Determine if this is a spatial coordinate by checking attributes + is_spatial = False + coord_type = None # will be 'lat' or 'lon' later + + if hasattr(coord, "attrs"): + # Check `standard_name` first + if "standard_name" in coord.attrs: + std_name = coord.attrs["standard_name"].lower() + if std_name in self.CF_LAT_ATTRS: + is_spatial = True + coord_type = "lat" + elif std_name in self.CF_LON_ATTRS: + is_spatial = True + coord_type = "lon" + + # Check units if standard_name didn't work + if not is_spatial and "units" in coord.attrs: + units = coord.attrs["units"].lower() + if any(u in units for u in self.CF_LAT_UNITS): + is_spatial = True + coord_type = "lat" + elif any(u in units for u in self.CF_LON_UNITS): + is_spatial = True + coord_type = "lon" + + # Check axis attribute as last chance + if not is_spatial and "axis" in coord.attrs: + axis = coord.attrs["axis"].upper() + if axis == "Y": + is_spatial = True + coord_type = "lat" + elif axis == "X": + is_spatial = True + coord_type = "lon" + + # If a spatial coord is found and `coord_type` is identified in `source` + if is_spatial and coord_type: + # Initialize `coords_by_dim` that will be returned at the end + if dim_name not in coords_by_dim: + coords_by_dim[dim_name] = {} + + # Store the coordinate variable + standard_name = coord.attrs.get("standard_name", coord_type) + coords_by_dim[dim_name][coord_type] = (coord_name, standard_name) + + return coords_by_dim + + def _get_element_type_from_dimension(self, dim_name: str) -> Optional[str]: + """ + Determine element type (i.e. 'nodes', 'faces', or 'edges') from dimension name. + + Parameters + ---------- + dim_name : str + Dimension name (e.g., 'n_face', 'nMesh2_face', etc.) + + Returns + ------- + Optional[str] + Element type ('nodes', 'faces', 'edges') or None + """ + dim_lower = dim_name.lower() + if "face" in dim_lower: + return "faces" + elif "node" in dim_lower: + return "nodes" + elif "edge" in dim_lower: + return "edges" + return None + + def _rename_coord_for_new_dimension( + self, original_name: str, old_element: str, new_element: str + ) -> str: + """ + Rename a coordinate variable when changing from one element type to another, which occurs when the `remap_to` + element does not match the `source` dimension. + + Parameters + ---------- + original_name : str + Original coordinate variable name + old_element : str + Old element type ('nodes', 'faces', 'edges') + new_element : str + New element type ('nodes', 'faces', 'edges') + + Returns + ------- + str + New coordinate name with element type updated + """ + # Map plural to singular + element_type_to_coord_name_string = { + "nodes": "node", + "faces": "face", + "edges": "edge", + } + + old_coord_name_string = element_type_to_coord_name_string[old_element] + new_coord_name_string = element_type_to_coord_name_string[new_element] + + # Try to replace the old element name in the coordinate name + # Handle both singular and plural forms + new_name = original_name + + # Case-sensitive replacements + # e.g. "*face*" -> "*node*" + new_name = new_name.replace(old_coord_name_string, new_coord_name_string) + # e.g. "*faces*" -> "*nodes*" + new_name = new_name.replace(old_element, new_element) + # e.g. "*FACE*" -> "*NODE*" + new_name = new_name.replace( + old_coord_name_string.upper(), new_coord_name_string.upper() + ) + # e.g. "*FACES*" -> "*NODES*" + new_name = new_name.replace(old_element.upper(), new_element.upper()) + # e.g. "*Face*" -> "*Node*" + new_name = new_name.replace( + old_coord_name_string.capitalize(), new_coord_name_string.capitalize() + ) + # e.g. "*Faces*" -> "*Nodes*" + new_name = new_name.replace(old_element.capitalize(), new_element.capitalize()) + + return new_name + + def construct_output_coords(self) -> Dict[str, xr.DataArray]: + """ + Construct spatial coordinates for the remapped output by finding spatial coordinate variables, if any, + in the source data and employing a logic as follows: + + Logic: + ------ + If `remap_to` matches the `source` dimension (e.g. `source` on face centers` and `remap_to="faces"` etc.) + - Swap values of spatial coords, which are defined on the same dimension as `source`, with + values of the corresponding coords from `destination_grid` + - Remove spatial coords defined on different dimensions than `source` and display a warning about it + + Else (if `remap_to` doesn't match `source` dim (e.g. `source` on face centers but `remap_to="nodes"` etc.)) + - Swap values of spatial coords, which are defined on the same dimension as `source`, with + values of the coords from `destination_grid` that are defined on the `remap_to` dimension. + - Rename these coords to reflect new element type (e.g. 'face_x' → 'node_x') + - Remove other spatial coords and display a warning about it + + Returns + ------- + Dict[str, xr.DataArray] + Dictionary mapping output coordinate variables to their new values + """ + + # Find spatial coordinate variables in `source` by checking their attributes and organize them by dimension + coord_vars_by_dim = self._find_source_coords() + + if not coord_vars_by_dim: + warnings.warn( + "No spatial coordinate variables found in `source`.", + UserWarning, + stacklevel=2, + ) + return {} + + # Get the dimension that `source` is defined on + source_dims = list(self.source.dims) + if len(source_dims) == 0: + raise ValueError("Source data has no dimensions") + + # Find the primary spatial dimension (should be n_face, n_node, or n_edge) + source_spatial_dim = None + for dim in source_dims: + if self._get_element_type_from_dimension(dim) is not None: + source_spatial_dim = dim + break + + if source_spatial_dim is None: + raise ValueError( + f"Could not identify spatial dimension in `source` dims: {source_dims}" + ) + + source_element_type = self._get_element_type_from_dimension(source_spatial_dim) + + # Get destination grid values for the remap_to element + dest_grid_coords = self._get_destination_grid_coords() + + output_coords = {} + removed_coords = [] + renamed_coords = [] + + # Logic for the remapped spatial coords construction starts here + # If `remap_to` matches `source` dimension + if source_element_type == self.remap_to: + # Swap coords on matching dimension + if source_spatial_dim in coord_vars_by_dim: + for coord_type in ["lat", "lon"]: + if coord_type in coord_vars_by_dim[source_spatial_dim]: + source_coord_name, std_name = coord_vars_by_dim[ + source_spatial_dim + ][coord_type] + + out_name = source_coord_name + + # Assign destination grid values + output_coords[out_name] = dest_grid_coords[coord_type] + + # Remove coords on other dimensions (actually, exclude them from the output and keep track of them) + for dim_name, coords_dict in coord_vars_by_dim.items(): + if dim_name != source_spatial_dim: + for coord_type, (coord_name, _) in coords_dict.items(): + removed_coords.append(coord_name) + + if removed_coords: + warnings.warn( + f"Removing spatial coordinates on non-matching dimensions: {removed_coords}.", + UserWarning, + stacklevel=2, + ) + + # `remap_to` differs from `source` dimension + else: + warnings.warn( + f"`source` dimention:('{source_spatial_dim}') but remapped to ('{self.remap_to}'). Coords " + f"will be swapped to '{self.remap_to}' dimension and renamed accordingly.", + UserWarning, + stacklevel=2, + ) + + # Swap and rename (as needed) coords from source dimension + if source_spatial_dim in coord_vars_by_dim: + for coord_type in ["lat", "lon"]: + if coord_type in coord_vars_by_dim[source_spatial_dim]: + source_coord_name, std_name = coord_vars_by_dim[ + source_spatial_dim + ][coord_type] + + # Rename to reflect new element type + out_name = self._rename_coord_for_new_dimension( + source_coord_name, source_element_type, self.remap_to + ) + if out_name != source_coord_name: + renamed_coords.append((source_coord_name, out_name)) + + # Assign destination grid values on remap_to dimension + output_coords[out_name] = dest_grid_coords[coord_type] + + # Remove coords on other dimensions (actually, exclude them from the output and keep track of them) + for dim_name, coords_dict in coord_vars_by_dim.items(): + if dim_name != source_spatial_dim: + for coord_type, (coord_name, _) in coords_dict.items(): + removed_coords.append(coord_name) + + if renamed_coords: + for old, new in renamed_coords: + warnings.warn( + f"Renamed coordinate '{old}' → '{new}' due to dimension change.", + UserWarning, + stacklevel=2, + ) + + if removed_coords: + warnings.warn( + f"Removing spatial coordinates on non-matching dimensions: {removed_coords}.", + UserWarning, + stacklevel=2, + ) + + return output_coords diff --git a/uxarray/remap/utils.py b/uxarray/remap/utils.py index c60a9c517..621840314 100644 --- a/uxarray/remap/utils.py +++ b/uxarray/remap/utils.py @@ -1,5 +1,3 @@ -from copy import deepcopy - import numpy as np import uxarray.core.dataset @@ -57,7 +55,7 @@ def _assert_dimension(dim): raise ValueError(f"Invalid spatial dimension: {dim!r}") -def _construct_remapped_ds(source, remapped_vars, destination_grid, destination_dim): +def _construct_remapped_ds(source, remapped_vars, destination_grid, remap_to): """ Construct a new UxDataset from remapped data variables and updated coordinates. @@ -69,22 +67,29 @@ def _construct_remapped_ds(source, remapped_vars, destination_grid, destination_ Mapping of variable names to their remapped DataArrays. destination_grid : Grid The UXarray grid instance representing the new topology. - destination_dim : str - The spatial dimension name (e.g., 'n_face') for the destination grid. + remap_to : str + Which grid element receives the remapped values, either "nodes", "edges", or "faces" Returns ------- UxDataset A new dataset containing only the remapped variables and retained coordinates. """ - destination_coords = deepcopy(source.coords) - if destination_dim in destination_coords: - del destination_coords[destination_dim] + + from uxarray.remap.spatial_coords_remap import SpatialCoordsRemapper + + # Ensure handling of spatial coordinates between `source` and `destination_grid` for the remapped output + # with respect to the dimensions of data & coordinate and the `remap_to` selection. See the class + # definition and functions for detailed information + coords_remapper = SpatialCoordsRemapper( + source=source, destination_grid=destination_grid, remap_to=remap_to + ) + output_coords = coords_remapper.construct_output_coords() ds_remapped = uxarray.core.dataset.UxDataset( data_vars=remapped_vars, uxgrid=destination_grid, - coords=destination_coords, + coords=output_coords, ) return ds_remapped From 4a981259312162c6fafb31d178def6281f2b981f Mon Sep 17 00:00:00 2001 From: erogluorhan Date: Wed, 11 Feb 2026 12:56:05 -0700 Subject: [PATCH 2/4] Finalize the spatial coords remapper --- uxarray/remap/spatial_coords_remap.py | 132 ++++++++------------------ uxarray/remap/utils.py | 4 +- 2 files changed, 44 insertions(+), 92 deletions(-) diff --git a/uxarray/remap/spatial_coords_remap.py b/uxarray/remap/spatial_coords_remap.py index afec55246..516b3ec8b 100644 --- a/uxarray/remap/spatial_coords_remap.py +++ b/uxarray/remap/spatial_coords_remap.py @@ -82,32 +82,24 @@ def _get_destination_grid_coords(self) -> Dict[str, xr.DataArray]: f"Unknown `remap_to`: {self.remap_to}. Must be either 'nodes', 'faces', or 'edges'." ) - def _find_source_coords(self) -> Dict[str, Dict[str, Tuple[str, str]]]: + def _find_source_coords(self) -> Dict[str, Tuple[str, str]]: """ Find spatial coordinate variables in `source` by checking their attributes, units, and axes. Returns ------- - Dict[str, Dict[str, Tuple[str, str]]] - Nested dictionary structure: - - First level keys: dimension names (e.g., 'n_face', 'n_node', 'n_edge') - - Second level keys: spatial identifier ('lat' or 'lon') - - Values: (coord_var_name, standard_name) tuples + Dict[str, Tuple[str, str]] + Dictionary with keys as spatial identifiers ('lat' or 'lon') and values as + (coord_var_name, standard_name) tuples Example output would look like: { - 'n_face': { - 'lat': ('Mesh2_face_y', 'latitude'), - 'lon': ('Mesh2_face_x', 'longitude') - }, - 'n_node': { - 'lat': ('Mesh2_node_y', 'latitude'), - 'lon': ('Mesh2_node_x', 'longitude') - } + 'lat': ('Mesh2_face_y', 'latitude'), + 'lon': ('Mesh2_face_x', 'longitude') } """ - coords_by_dim = {} + source_coords = {} # Check all coordinates in `source` for coord_name in self.source.coords: @@ -117,8 +109,6 @@ def _find_source_coords(self) -> Dict[str, Dict[str, Tuple[str, str]]]: if not hasattr(coord, "dims") or len(coord.dims) != 1: continue - dim_name = coord.dims[0] - # Determine if this is a spatial coordinate by checking attributes is_spatial = False coord_type = None # will be 'lat' or 'lon' later @@ -156,15 +146,11 @@ def _find_source_coords(self) -> Dict[str, Dict[str, Tuple[str, str]]]: # If a spatial coord is found and `coord_type` is identified in `source` if is_spatial and coord_type: - # Initialize `coords_by_dim` that will be returned at the end - if dim_name not in coords_by_dim: - coords_by_dim[dim_name] = {} - # Store the coordinate variable standard_name = coord.attrs.get("standard_name", coord_type) - coords_by_dim[dim_name][coord_type] = (coord_name, standard_name) + source_coords[coord_type] = (coord_name, standard_name) - return coords_by_dim + return source_coords def _get_element_type_from_dimension(self, dim_name: str) -> Optional[str]: """ @@ -247,20 +233,17 @@ def _rename_coord_for_new_dimension( def construct_output_coords(self) -> Dict[str, xr.DataArray]: """ Construct spatial coordinates for the remapped output by finding spatial coordinate variables, if any, - in the source data and employing a logic as follows: + in `source` and employing a logic as follows: Logic: ------ If `remap_to` matches the `source` dimension (e.g. `source` on face centers` and `remap_to="faces"` etc.) - - Swap values of spatial coords, which are defined on the same dimension as `source`, with - values of the corresponding coords from `destination_grid` - - Remove spatial coords defined on different dimensions than `source` and display a warning about it + - Swap values of spatial coords with values of the corresponding coords from `destination_grid` Else (if `remap_to` doesn't match `source` dim (e.g. `source` on face centers but `remap_to="nodes"` etc.)) - - Swap values of spatial coords, which are defined on the same dimension as `source`, with - values of the coords from `destination_grid` that are defined on the `remap_to` dimension. + - Swap values of spatial coords with values of the coords from `destination_grid` that are + defined on the `remap_to` dimension. - Rename these coords to reflect new element type (e.g. 'face_x' → 'node_x') - - Remove other spatial coords and display a warning about it Returns ------- @@ -268,10 +251,10 @@ def construct_output_coords(self) -> Dict[str, xr.DataArray]: Dictionary mapping output coordinate variables to their new values """ - # Find spatial coordinate variables in `source` by checking their attributes and organize them by dimension - coord_vars_by_dim = self._find_source_coords() + # Find spatial coordinate variables in `source` by checking their attributes + source_coords = self._find_source_coords() - if not coord_vars_by_dim: + if not source_coords: warnings.warn( "No spatial coordinate variables found in `source`.", UserWarning, @@ -302,70 +285,46 @@ def construct_output_coords(self) -> Dict[str, xr.DataArray]: dest_grid_coords = self._get_destination_grid_coords() output_coords = {} - removed_coords = [] - renamed_coords = [] # Logic for the remapped spatial coords construction starts here # If `remap_to` matches `source` dimension if source_element_type == self.remap_to: # Swap coords on matching dimension - if source_spatial_dim in coord_vars_by_dim: - for coord_type in ["lat", "lon"]: - if coord_type in coord_vars_by_dim[source_spatial_dim]: - source_coord_name, std_name = coord_vars_by_dim[ - source_spatial_dim - ][coord_type] - - out_name = source_coord_name - - # Assign destination grid values - output_coords[out_name] = dest_grid_coords[coord_type] - - # Remove coords on other dimensions (actually, exclude them from the output and keep track of them) - for dim_name, coords_dict in coord_vars_by_dim.items(): - if dim_name != source_spatial_dim: - for coord_type, (coord_name, _) in coords_dict.items(): - removed_coords.append(coord_name) - - if removed_coords: - warnings.warn( - f"Removing spatial coordinates on non-matching dimensions: {removed_coords}.", - UserWarning, - stacklevel=2, - ) + for coord_type in ["lat", "lon"]: + if coord_type in source_coords: + source_coord_name, std_name = source_coords[coord_type] + out_name = source_coord_name + + # Assign destination grid values + output_coords[out_name] = dest_grid_coords[coord_type] # `remap_to` differs from `source` dimension else: warnings.warn( - f"`source` dimention:('{source_spatial_dim}') but remapped to ('{self.remap_to}'). Coords " - f"will be swapped to '{self.remap_to}' dimension and renamed accordingly.", + f"Coordinates handling as part of remapping: `source` has the dimension:" + f"('{source_spatial_dim}') but is being remapped to ('{self.remap_to}'). Therefore, " + f"coordinate values will be swapped to the '{self.remap_to}' coordinates from " + f"`destination_grid` and renamed accordingly.", UserWarning, stacklevel=2, ) + renamed_coords = [] + # Swap and rename (as needed) coords from source dimension - if source_spatial_dim in coord_vars_by_dim: - for coord_type in ["lat", "lon"]: - if coord_type in coord_vars_by_dim[source_spatial_dim]: - source_coord_name, std_name = coord_vars_by_dim[ - source_spatial_dim - ][coord_type] - - # Rename to reflect new element type - out_name = self._rename_coord_for_new_dimension( - source_coord_name, source_element_type, self.remap_to - ) - if out_name != source_coord_name: - renamed_coords.append((source_coord_name, out_name)) - - # Assign destination grid values on remap_to dimension - output_coords[out_name] = dest_grid_coords[coord_type] - - # Remove coords on other dimensions (actually, exclude them from the output and keep track of them) - for dim_name, coords_dict in coord_vars_by_dim.items(): - if dim_name != source_spatial_dim: - for coord_type, (coord_name, _) in coords_dict.items(): - removed_coords.append(coord_name) + for coord_type in ["lat", "lon"]: + if coord_type in source_coords: + source_coord_name, std_name = source_coords[coord_type] + + # Rename to reflect new element type + out_name = self._rename_coord_for_new_dimension( + source_coord_name, source_element_type, self.remap_to + ) + if out_name != source_coord_name: + renamed_coords.append((source_coord_name, out_name)) + + # Assign destination grid values on remap_to dimension + output_coords[out_name] = dest_grid_coords[coord_type] if renamed_coords: for old, new in renamed_coords: @@ -375,11 +334,4 @@ def construct_output_coords(self) -> Dict[str, xr.DataArray]: stacklevel=2, ) - if removed_coords: - warnings.warn( - f"Removing spatial coordinates on non-matching dimensions: {removed_coords}.", - UserWarning, - stacklevel=2, - ) - return output_coords diff --git a/uxarray/remap/utils.py b/uxarray/remap/utils.py index 621840314..cefcb606f 100644 --- a/uxarray/remap/utils.py +++ b/uxarray/remap/utils.py @@ -79,8 +79,8 @@ def _construct_remapped_ds(source, remapped_vars, destination_grid, remap_to): from uxarray.remap.spatial_coords_remap import SpatialCoordsRemapper # Ensure handling of spatial coordinates between `source` and `destination_grid` for the remapped output - # with respect to the dimensions of data & coordinate and the `remap_to` selection. See the class - # definition and functions for detailed information + # with respect to the source dimension and `remap_to` selection. See the class definition and functions + # for detailed information coords_remapper = SpatialCoordsRemapper( source=source, destination_grid=destination_grid, remap_to=remap_to ) From e673079da444f58b8d9e7b5d9579bb79f2c3f18b Mon Sep 17 00:00:00 2001 From: erogluorhan Date: Wed, 11 Feb 2026 12:56:34 -0700 Subject: [PATCH 3/4] Add tests --- test/test_remap.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/test_remap.py b/test/test_remap.py index 186faae0d..836fd517e 100644 --- a/test/test_remap.py +++ b/test/test_remap.py @@ -265,3 +265,45 @@ def test_b_quadrilateral(gridpath, datasetpath): out = uxds['var2'].remap.bilinear(destination_grid=dest) assert out.size > 0 + +def test_b_coords_remap_to_faces(gridpath): + """Bilinear remap should change the array when remap_to != source.""" + mesh_path = gridpath("mpas", "QU", "mesh.QU.1920km.151026.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + dest = ux.open_grid(gridpath("ugrid", "geoflow-small", "grid.nc")) + + uxda_with_coords = ux.core.UxDataArray( + data=uxds["latCell"], + uxgrid=uxds.uxgrid, + coords={"Mesh2_face_lat": uxds.uxgrid.face_lat, + "Mesh_Face_lon": uxds.uxgrid.face_lon, + } + ) + + da_remap_b = uxda_with_coords.remap.bilinear( + destination_grid=dest, remap_to="faces" + ) + + assert (da_remap_b.Mesh_Face_lon.size == dest.face_lon.size) + assert np.array_equal(da_remap_b.Mesh_Face_lon.values, dest.face_lon.values) + +def test_b_coords_remap_to_nodes(gridpath): + """Bilinear remap should change the array when remap_to != source.""" + mesh_path = gridpath("mpas", "QU", "mesh.QU.1920km.151026.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + dest = ux.open_grid(gridpath("ugrid", "geoflow-small", "grid.nc")) + + uxda_with_coords = ux.core.UxDataArray( + data=uxds["latCell"], + uxgrid=uxds.uxgrid, + coords={"Mesh2_face_lat": uxds.uxgrid.face_lat, + "Mesh_Face_lon": uxds.uxgrid.face_lon, + } + ) + + da_remap_b = uxda_with_coords.remap.bilinear( + destination_grid=dest, remap_to="nodes" + ) + + assert (da_remap_b.Mesh_Node_lon.size == dest.node_lon.size) + assert np.array_equal(da_remap_b.Mesh_Node_lon.values, dest.node_lon.values) From 69a4f80523b52c5e1408ae0f132ae6f1043a61b1 Mon Sep 17 00:00:00 2001 From: erogluorhan Date: Wed, 11 Feb 2026 16:10:22 -0700 Subject: [PATCH 4/4] Add handling of cartesian xa and y coords in spatial_coords_remap --- uxarray/conventions/ugrid.py | 6 +++ uxarray/remap/spatial_coords_remap.py | 73 ++++++++++++++++----------- 2 files changed, 49 insertions(+), 30 deletions(-) diff --git a/uxarray/conventions/ugrid.py b/uxarray/conventions/ugrid.py index 707c1dd79..f0f19de0b 100644 --- a/uxarray/conventions/ugrid.py +++ b/uxarray/conventions/ugrid.py @@ -84,12 +84,14 @@ "standard_name": "x", "long name": "Cartesian x location of the corner nodes of each face", "units": "meters", + "axis": "X", } NODE_Y_ATTRS = { "standard_name": "y", "long name": "Cartesian y location of the corner nodes of each face", "units": "meters", + "axis": "Y", } NODE_Z_ATTRS = { @@ -104,12 +106,14 @@ "standard_name": "x", "long name": "Cartesian x location of the center of each edge", "units": "meters", + "axis": "X", } EDGE_Y_ATTRS = { "standard_name": "y", "long name": "Cartesian y location of the center of each edge", "units": "meters", + "axis": "Y", } EDGE_Z_ATTRS = { @@ -124,12 +128,14 @@ "standard_name": "x", "long name": "Cartesian x location of the center of each face", "units": "meters", + "axis": "X", } FACE_Y_ATTRS = { "standard_name": "y", "long name": "Cartesian y location of the center of each face", "units": "meters", + "axis": "Y", } FACE_Z_ATTRS = { diff --git a/uxarray/remap/spatial_coords_remap.py b/uxarray/remap/spatial_coords_remap.py index 516b3ec8b..3dcb1aca5 100644 --- a/uxarray/remap/spatial_coords_remap.py +++ b/uxarray/remap/spatial_coords_remap.py @@ -6,20 +6,27 @@ from uxarray.core.dataarray import UxDataArray from uxarray.grid.grid import Grid +COORD_TYPES = { + "LON": "lon", + "LAT": "lat", + "CART_X": "X", + "CART_Y": "Y", +} + +# CF attributes that indicate coordinate type +CF_LAT_ATTRS = ["latitude", "projection_y_coordinate"] +CF_LON_ATTRS = ["longitude", "projection_x_coordinate"] + +# CF units that indicate coordinate type +CF_LAT_UNITS = ["degrees_north", "degree_north", "degree_n"] +CF_LON_UNITS = ["degrees_east", "degree_east", "degree_e"] + class SpatialCoordsRemapper: """Ensures remapping spatial coordinates between the source and destination grid for the remapping functions. It may include remapping of values, renaming, and removal of some of the coordinates with respect to the dimensions of source data & coordinates and the `remap_to` selection.""" - # CF attributes that indicate coordinate type - CF_LAT_ATTRS = ["latitude", "projection_y_coordinate"] - CF_LON_ATTRS = ["longitude", "projection_x_coordinate"] - - # CF units that indicate coordinate type - CF_LAT_UNITS = ["degrees_north", "degree_north", "degree_n"] - CF_LON_UNITS = ["degrees_east", "degree_east", "degree_e"] - def __init__( self, source: UxDataArray, @@ -64,18 +71,24 @@ def _get_destination_grid_coords(self) -> Dict[str, xr.DataArray]: """ if self.remap_to == "nodes": return { - "lon": self.destination_grid.node_lon, - "lat": self.destination_grid.node_lat, + COORD_TYPES["LON"]: self.destination_grid.node_lon, + COORD_TYPES["LAT"]: self.destination_grid.node_lat, + COORD_TYPES["CART_X"]: self.destination_grid.node_x, + COORD_TYPES["CART_Y"]: self.destination_grid.node_y, } elif self.remap_to == "faces": return { - "lon": self.destination_grid.face_lon, - "lat": self.destination_grid.face_lat, + COORD_TYPES["LON"]: self.destination_grid.face_lon, + COORD_TYPES["LAT"]: self.destination_grid.face_lat, + COORD_TYPES["CART_X"]: self.destination_grid.face_x, + COORD_TYPES["CART_Y"]: self.destination_grid.face_y, } elif self.remap_to == "edges": return { - "lon": self.destination_grid.edge_lon, - "lat": self.destination_grid.edge_lat, + COORD_TYPES["LON"]: self.destination_grid.edge_lon, + COORD_TYPES["LAT"]: self.destination_grid.edge_lat, + COORD_TYPES["CART_X"]: self.destination_grid.edge_x, + COORD_TYPES["CART_Y"]: self.destination_grid.edge_y, } else: raise ValueError( @@ -117,32 +130,32 @@ def _find_source_coords(self) -> Dict[str, Tuple[str, str]]: # Check `standard_name` first if "standard_name" in coord.attrs: std_name = coord.attrs["standard_name"].lower() - if std_name in self.CF_LAT_ATTRS: + if std_name in CF_LAT_ATTRS: is_spatial = True - coord_type = "lat" - elif std_name in self.CF_LON_ATTRS: + coord_type = COORD_TYPES["LAT"] + elif std_name in CF_LON_ATTRS: is_spatial = True - coord_type = "lon" + coord_type = COORD_TYPES["LON"] # Check units if standard_name didn't work if not is_spatial and "units" in coord.attrs: units = coord.attrs["units"].lower() - if any(u in units for u in self.CF_LAT_UNITS): + if any(u in units for u in CF_LAT_UNITS): is_spatial = True - coord_type = "lat" - elif any(u in units for u in self.CF_LON_UNITS): + coord_type = COORD_TYPES["LAT"] + elif any(u in units for u in CF_LON_UNITS): is_spatial = True - coord_type = "lon" + coord_type = COORD_TYPES["LON"] # Check axis attribute as last chance if not is_spatial and "axis" in coord.attrs: axis = coord.attrs["axis"].upper() - if axis == "Y": + if axis == COORD_TYPES["CART_Y"]: is_spatial = True - coord_type = "lat" - elif axis == "X": + coord_type = COORD_TYPES["CART_Y"] + elif axis == COORD_TYPES["CART_X"]: is_spatial = True - coord_type = "lon" + coord_type = COORD_TYPES["CART_X"] # If a spatial coord is found and `coord_type` is identified in `source` if is_spatial and coord_type: @@ -290,13 +303,13 @@ def construct_output_coords(self) -> Dict[str, xr.DataArray]: # If `remap_to` matches `source` dimension if source_element_type == self.remap_to: # Swap coords on matching dimension - for coord_type in ["lat", "lon"]: + for coord_type in COORD_TYPES.values(): if coord_type in source_coords: source_coord_name, std_name = source_coords[coord_type] out_name = source_coord_name # Assign destination grid values - output_coords[out_name] = dest_grid_coords[coord_type] + output_coords[out_name] = dest_grid_coords[coord_type].variable # `remap_to` differs from `source` dimension else: @@ -312,7 +325,7 @@ def construct_output_coords(self) -> Dict[str, xr.DataArray]: renamed_coords = [] # Swap and rename (as needed) coords from source dimension - for coord_type in ["lat", "lon"]: + for coord_type in COORD_TYPES.values(): if coord_type in source_coords: source_coord_name, std_name = source_coords[coord_type] @@ -324,7 +337,7 @@ def construct_output_coords(self) -> Dict[str, xr.DataArray]: renamed_coords.append((source_coord_name, out_name)) # Assign destination grid values on remap_to dimension - output_coords[out_name] = dest_grid_coords[coord_type] + output_coords[out_name] = dest_grid_coords[coord_type].variable if renamed_coords: for old, new in renamed_coords: