Source code for mapchete_eo.array.convert

from typing import List, Optional, Union

import numpy as np
import numpy.ma as ma
from numpy.typing import DTypeLike
import xarray as xr
from mapchete.types import NodataVal

# Scalar types for which to_masked_array() builds the mask with
# ma.masked_values() instead of ma.masked_equal(). Despite the name, the
# complex types are listed here as well.
# Type names taken from https://numpy.org/doc/stable/user/basics.types.html
# np.half is not listed separately: it is the same object as np.float16.
_NUMPY_FLOAT_DTYPES = [
    np.float16,
    np.single,
    np.double,
    np.longdouble,
    np.csingle,
    np.cdouble,
    np.clongdouble,
]


def to_masked_array(
    xarr: Union[xr.Dataset, xr.DataArray],
    copy: bool = False,
    out_dtype: Optional[DTypeLike] = None,
) -> ma.MaskedArray:
    """
    Convert an xr.DataArray or xr.Dataset to a ma.MaskedArray.

    A Dataset is turned into a DataArray with ``to_array()`` first. Values
    matching the ``_FillValue`` attribute are masked in the result.

    Parameters
    ----------
    xarr : xr.Dataset or xr.DataArray
        Input data; its ``attrs`` must contain a ``_FillValue`` entry.
    copy : bool
        Passed on to ``ma.masked_values`` / ``ma.masked_equal``.
    out_dtype : DTypeLike, optional
        If given, the data is cast to this dtype before masking.

    Returns
    -------
    ma.MaskedArray

    Raises
    ------
    ValueError
        If the ``_FillValue`` attribute is missing or None.
    """
    if isinstance(xarr, xr.Dataset):
        xarr = xarr.to_array()

    fill_value = xarr.attrs.get("_FillValue")
    if fill_value is None:
        raise ValueError(
            "Cannot create masked_array because DataArray fill value is None"
        )

    # Test against None explicitly: a non-structured np.dtype instance has
    # len() == 0 and is therefore falsy, so a plain truthiness check would
    # silently skip the cast for e.g. out_dtype=np.dtype("float32").
    if out_dtype is not None:
        xarr = xarr.astype(out_dtype, copy=False)

    if xarr.dtype in _NUMPY_FLOAT_DTYPES:
        # shrink=False asks for a full-shape mask even if nothing is masked
        return ma.masked_values(xarr, fill_value, copy=copy, shrink=False)

    out = ma.masked_equal(xarr, fill_value, copy=copy)
    # in case of a shrinked mask we have to expand it to the full array shape
    if not isinstance(out.mask, np.ndarray):
        out.mask = np.full(out.mask.shape, out.mask, dtype=bool)
    return out
def to_dataarray(
    masked_arr: ma.MaskedArray,
    nodataval: NodataVal = None,
    name: Optional[str] = None,
    band_names: Optional[List[str]] = None,
    band_axis_name: str = "bands",
    x_axis_name: str = "x",
    y_axis_name: str = "y",
    attrs: Optional[dict] = None,
) -> xr.DataArray:
    """
    Build an xr.DataArray from a 2D or 3D ma.MaskedArray.

    A 2D input gets the dimensions (x_axis_name, y_axis_name). A 3D input
    gets (band_axis_name, x_axis_name, y_axis_name), with the first axis
    labelled by ``band_names`` or by generated "<band_axis_name>-<i>" labels.

    Masked cells are filled with ``nodataval`` (falling back to the array's
    own ``fill_value`` when ``nodataval`` is None), and that value is stored
    in the ``_FillValue`` attribute of the result.

    Raises
    ------
    TypeError
        If the input is neither 2D nor 3D.
    """
    if nodataval is None:
        nodataval = masked_arr.fill_value
    if not attrs:
        attrs = dict()

    dimension_count = masked_arr.ndim
    if dimension_count == 2:
        # spatial axes only, no labelled coordinates
        dims = [x_axis_name, y_axis_name]
        coords = None
    elif dimension_count == 3:
        if not band_names:
            band_names = [
                f"{band_axis_name}-{i}" for i in range(masked_arr.shape[0])
            ]
        dims = [band_axis_name, x_axis_name, y_axis_name]
        coords = {band_axis_name: band_names}
    else:  # pragma: no cover
        raise TypeError("only a 2D or 3D ma.MaskedArray is allowed.")

    # Pass xarray a plain array with the masked cells already filled. Per the
    # original author's note, handing over the masked array itself makes
    # xarray substitute np.nan, which also seems to change the dtype.
    filled = masked_arr.filled(nodataval)
    merged_attrs = dict(attrs, _FillValue=nodataval)
    return xr.DataArray(
        data=filled,
        dims=dims,
        name=name,
        attrs=merged_attrs,
        coords=coords,
    )
def to_dataset(
    masked_arr: ma.MaskedArray,
    nodataval: NodataVal = None,
    slice_names: Optional[List[str]] = None,
    band_names: Optional[List[str]] = None,
    slices_attrs: Optional[List[Union[dict, None]]] = None,
    slice_axis_name: str = "time",
    band_axis_name: str = "bands",
    x_axis_name: str = "x",
    y_axis_name: str = "y",
    attrs: Optional[dict] = None,
) -> xr.Dataset:
    """
    Convert a 4D ma.MaskedArray to an xr.Dataset.

    The first axis of the input is iterated as slices and the second axis is
    treated as bands. Every slice becomes one data variable, built with
    ``to_dataarray()`` and keyed by its slice name.

    Missing ``slice_names`` / ``band_names`` are generated as
    "<axis name>-<index>". ``slice_names``, ``slices_attrs`` and the slices
    are paired with ``zip()``, so pairing stops at the shortest of the three.

    Raises
    ------
    NotImplementedError
        If the input is 3D; this case is not implemented.
    TypeError
        If the input is neither 3D nor 4D.
    """
    attrs = attrs or dict()
    nodataval = masked_arr.fill_value if nodataval is None else nodataval

    if masked_arr.ndim == 3:
        raise NotImplementedError(
            "converting a 3D ma.MaskedArray to an xr.Dataset is not implemented"
        )
    if masked_arr.ndim != 4:  # pragma: no cover
        raise TypeError("only a 3D or 4D ma.MaskedArray is allowed.")

    slices, bands = masked_arr.shape[:2]
    band_names = band_names or [f"{band_axis_name}-{i}" for i in range(bands)]
    slice_names = slice_names or [f"{slice_axis_name}-{i}" for i in range(slices)]
    slices_attrs = (
        [None for _ in range(slices)] if slices_attrs is None else slices_attrs
    )
    coords = {slice_axis_name: slice_names}

    return xr.Dataset(
        data_vars={
            # every slice becomes its own data variable
            slice_name: to_dataarray(
                slice_array,
                nodataval=nodataval,
                band_names=band_names,
                name=slice_name,
                attrs=slice_attrs,
                band_axis_name=band_axis_name,
                x_axis_name=x_axis_name,
                y_axis_name=y_axis_name,
            )
            for slice_name, slice_attrs, slice_array in zip(
                slice_names,
                slices_attrs,
                masked_arr,
            )
        },
        coords=coords,
        attrs=dict(attrs, _FillValue=nodataval),
    ).transpose(slice_axis_name, band_axis_name, x_axis_name, y_axis_name)
def to_bands_mask(arr: np.ndarray, bands: int = 1) -> np.ndarray:
    """
    Replicate a 2D mask along a new leading band axis.

    Parameters
    ----------
    arr : np.ndarray
        Two-dimensional input mask.
    bands : int
        Number of times the mask is repeated along the new first axis.

    Returns
    -------
    np.ndarray
        Array of shape (bands, *arr.shape).

    Raises
    ------
    TypeError
        If ``arr`` is not two-dimensional.
    """
    if arr.ndim == 2:
        # add a length-1 leading axis, then repeat along it
        with_band_axis = np.expand_dims(arr, axis=0)
        return np.repeat(with_band_axis, bands, axis=0)
    raise TypeError("input array has to have exactly 2 dimensions.")