from __future__ import annotations
from contextlib import contextmanager
import logging
from collections import defaultdict
from datetime import datetime
import gc
from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence
from mapchete import Timer
import numpy as np
import numpy.ma as ma
from numpy.typing import DTypeLike
import xarray as xr
from mapchete.config import get_hash
from mapchete.geometry import to_shape
from mapchete.protocols import GridProtocol
from mapchete.types import NodataVals
from rasterio.enums import Resampling
from shapely.geometry import mapping
from shapely.ops import unary_union
from mapchete_eo.array.convert import to_dataarray, to_masked_array
from mapchete_eo.exceptions import (
AssetKeyError,
CorruptedProduct,
CorruptedSlice,
EmptySliceException,
EmptyStackException,
NoSourceProducts,
)
from mapchete_eo.protocols import EOProductProtocol
from mapchete_eo.sort import SortMethodConfig
from mapchete_eo.types import MergeMethod
logger = logging.getLogger(__name__)
[docs]
def products_to_np_array(
products: List[EOProductProtocol],
assets: Optional[List[str]] = None,
eo_bands: Optional[List[str]] = None,
grid: Optional[GridProtocol] = None,
resampling: Resampling = Resampling.nearest,
nodatavals: NodataVals = None,
merge_products_by: Optional[str] = None,
merge_method: MergeMethod = MergeMethod.first,
sort: Optional[SortMethodConfig] = None,
product_read_kwargs: dict = {},
raise_empty: bool = True,
out_dtype: Optional[DTypeLike] = None,
read_mask: Optional[np.ndarray] = None,
) -> ma.MaskedArray:
"""Read grid window of EOProducts and merge into a 4D xarray."""
return ma.stack(
[
to_masked_array(s, out_dtype=out_dtype)
for s in generate_slice_dataarrays(
products=products,
assets=assets,
eo_bands=eo_bands,
grid=grid,
resampling=resampling,
nodatavals=nodatavals,
merge_products_by=merge_products_by,
merge_method=merge_method,
sort=sort,
product_read_kwargs=product_read_kwargs,
raise_empty=raise_empty,
read_mask=read_mask,
)
]
)
[docs]
def products_to_xarray(
products: List[EOProductProtocol],
assets: Optional[List[str]] = None,
eo_bands: Optional[List[str]] = None,
grid: Optional[GridProtocol] = None,
resampling: Resampling = Resampling.nearest,
nodatavals: NodataVals = None,
slice_axis_name: str = "time",
band_axis_name: str = "bands",
x_axis_name: str = "x",
y_axis_name: str = "y",
merge_products_by: Optional[str] = None,
merge_method: MergeMethod = MergeMethod.first,
sort: Optional[SortMethodConfig] = None,
raise_empty: bool = True,
product_read_kwargs: dict = {},
read_mask: Optional[np.ndarray] = None,
) -> xr.Dataset:
"""Read grid window of EOProducts and merge into a 4D xarray."""
data_vars = [
s
for s in generate_slice_dataarrays(
products=products,
assets=assets,
eo_bands=eo_bands,
grid=grid,
resampling=resampling,
nodatavals=nodatavals,
merge_products_by=merge_products_by,
merge_method=merge_method,
sort=sort,
product_read_kwargs=product_read_kwargs,
raise_empty=raise_empty,
read_mask=read_mask,
)
]
if merge_products_by and merge_products_by not in ["date", "datetime"]:
coords = {merge_products_by: [s.name for s in data_vars]}
slice_axis_name = merge_products_by
else:
coords = {
slice_axis_name: list(
np.array(
[product.get_property("datetime") for product in products],
dtype=np.datetime64,
)
)
}
return xr.Dataset(
data_vars={s.name: s for s in data_vars},
coords=coords,
).transpose(slice_axis_name, band_axis_name, x_axis_name, y_axis_name)
[docs]
class Slice:
"""Combine multiple products into one slice."""
name: Any
products: Sequence[EOProductProtocol]
datetime: datetime
def __init__(
self,
name: Any,
products: Sequence[EOProductProtocol],
):
self.name = name
# a Slice can only be valid if it contains one or more products
if products:
self.products = products
else: # pragma: no cover
raise ValueError("at least one product must be provided.")
# calculate mean datetime
timestamps = [
product.get_property("datetime").timestamp()
for product in self.products
if product.get_property("datetime")
]
mean_timestamp = sum(timestamps) / len(timestamps)
self.datetime = datetime.fromtimestamp(mean_timestamp)
def __repr__(self) -> str:
return f"<Slice {self.name} ({len(self.products)} products)>"
@property
def __geom_interface__(self) -> Dict:
if self.products:
return mapping(
unary_union([to_shape(product) for product in self.products])
)
raise EmptySliceException
@property
def properties(self) -> Dict[str, Any]:
# generate combined properties
properties: Dict[str, Any] = {}
for key in self.products[0].item.properties.keys():
try:
properties[key] = self.get_property(key)
except ValueError:
properties[key] = None
return properties
[docs]
@contextmanager
def cached(self) -> Generator[Slice, None, None]:
"""Clear caches and run garbage collector when context manager is closed."""
yield self
with Timer() as tt:
self.clear_cached_data()
gc.collect()
logger.debug("Slice cache cleared and garbage collected in %s", tt)
[docs]
def clear_cached_data(self):
logger.debug("clear caches of all products in slice")
for product in self.products:
product.clear_cached_data()
[docs]
def get_property(self, property: str) -> Any:
"""
Return merged property over all products.
If property values are the same over all products, it will be returned. Otherwise a
ValueError is raised.
"""
# if set of value hashes has a length of 1, all values are the same
values = [get_hash(product.get_property(property)) for product in self.products]
if len(set(values)) == 1:
return self.products[0].get_property(property)
raise ValueError(
f"cannot get unique property {property} from products {self.products}"
)
[docs]
def read(
self,
merge_method: MergeMethod = MergeMethod.first,
product_read_kwargs: dict = {},
raise_empty: bool = True,
) -> ma.MaskedArray:
logger.debug("Slice: read from %s products", len(self.products))
return merge_products(
products=self.products,
merge_method=merge_method,
product_read_kwargs=product_read_kwargs,
raise_empty=raise_empty,
)
[docs]
def products_to_slices(
products: List[EOProductProtocol],
group_by_property: Optional[str] = None,
sort: Optional[SortMethodConfig] = None,
) -> List[Slice]:
"""Group products per given property into Slice objects and optionally sort slices."""
if group_by_property:
grouped = defaultdict(list)
for product in products:
grouped[product.get_property(group_by_property)].append(product)
slices = [Slice(key, products) for key, products in grouped.items()]
else:
slices = [Slice(product.id, [product]) for product in products]
# also check if slices is even a list, otherwise it will raise an error
if sort and slices:
sort_dict = sort.model_dump()
func = sort_dict.pop("func")
slices = func(slices, **sort_dict)
return slices
[docs]
def merge_products(
products: Sequence[EOProductProtocol],
merge_method: MergeMethod = MergeMethod.first,
product_read_kwargs: dict = {},
raise_empty: bool = True,
) -> ma.MaskedArray:
"""
Merge given products into one array.
"""
def read_remaining_valid_products(
products_iter: Iterator[EOProductProtocol], product_read_kwargs: dict
) -> Generator[ma.MaskedArray, None, None]:
"""Yields and reads remaining products from iterator while discarding corrupt products."""
try:
for product in products_iter:
try:
yield product.read_np_array(**product_read_kwargs)
except (AssetKeyError, CorruptedProduct) as exc:
logger.warning("skip product %s because of %s", product.id, exc)
except StopIteration:
return
if len(products) == 0: # pragma: no cover
raise NoSourceProducts("no products to merge")
# we need to deactivate raising the EmptyProductException
product_read_kwargs.update(raise_empty=False)
products_iter = iter(products)
# read first valid product
for product in products_iter:
try:
out = product.read_np_array(**product_read_kwargs)
break
except (AssetKeyError, CorruptedProduct) as exc:
logger.warning("skip product %s because of %s", product.id, exc)
else:
# we cannot do anything here, as all products are broken
raise CorruptedSlice("all products are broken here")
# fill in gaps sequentially, product by product
if merge_method == MergeMethod.first:
for new in read_remaining_valid_products(products_iter, product_read_kwargs):
masked = out.mask
# Update values at masked locations
out[masked] = new[masked]
# Update mask at masked locations (e.g., unmask them)
out.mask[masked] = new.mask[masked]
# if whole output array is filled, there is no point in reading more data
if not out.mask.any():
return out
# read all and average
elif merge_method == MergeMethod.average:
def _generate_arrays(
first_product_array: ma.MaskedArray,
remaining_product_arrays: Generator[ma.MaskedArray, None, None],
) -> Generator[ma.MaskedArray, None, None]:
"""Yield all available product arrays."""
yield first_product_array
yield from remaining_product_arrays
# explicitly specify dtype to avoid casting of integer arrays to floats
# during mean conversion:
# https://numpy.org/doc/stable/reference/generated/numpy.mean.html#numpy.mean
arrays = list(
_generate_arrays(
out,
read_remaining_valid_products(products_iter, product_read_kwargs),
)
)
# Filter out arrays that are entirely masked
valid_arrays = [a for a in arrays if not ma.getmaskarray(a).all()]
if valid_arrays:
out_dtype = out.dtype
out_fill_value = out.fill_value
stacked = ma.stack(valid_arrays, dtype=out_dtype)
out = stacked.mean(axis=0, dtype=out_dtype).astype(out_dtype, copy=False)
out.set_fill_value(out_fill_value)
else:
# All arrays were fully masked — return fully masked output
out = ma.masked_all(out.shape, dtype=out.dtype)
else: # pragma: no cover
raise NotImplementedError(f"unknown merge method: {merge_method}")
if raise_empty and out.mask.all():
raise EmptySliceException(
f"slice is empty after combining {len(products)} products"
)
return out
[docs]
def generate_slice_dataarrays(
products: List[EOProductProtocol],
assets: Optional[List[str]] = None,
eo_bands: Optional[List[str]] = None,
grid: Optional[GridProtocol] = None,
resampling: Resampling = Resampling.nearest,
nodatavals: NodataVals = None,
merge_products_by: Optional[str] = None,
merge_method: MergeMethod = MergeMethod.first,
sort: Optional[SortMethodConfig] = None,
product_read_kwargs: dict = {},
raise_empty: bool = True,
read_mask: Optional[np.ndarray] = None,
) -> Iterator[xr.DataArray]:
"""
Yield products or merged products into slices as DataArrays.
"""
if len(products) == 0:
raise NoSourceProducts("no products to read")
stack_empty = True
assets = assets or []
eo_bands = eo_bands or []
variables = assets or eo_bands
# group products into slices and sort slices if configured
slices = products_to_slices(
products, group_by_property=merge_products_by, sort=sort
)
logger.debug(
"reading %s products in %s groups...",
len(products),
len(slices),
)
if isinstance(nodatavals, list):
nodataval = nodatavals[0]
elif isinstance(nodatavals, float):
nodataval = nodatavals
else:
nodataval = nodatavals
for slice in slices:
try:
# if merge_products_by is none, each slice contains just one product
# so nothing will have to be merged anyways
with slice.cached():
yield to_dataarray(
merge_products(
products=slice.products,
merge_method=merge_method,
product_read_kwargs=dict(
product_read_kwargs,
assets=assets,
eo_bands=eo_bands,
grid=grid,
resampling=resampling,
nodatavals=nodatavals,
raise_empty=raise_empty,
read_mask=read_mask,
),
raise_empty=raise_empty,
),
nodataval=nodataval,
name=slice.name,
band_names=variables,
attrs=slice.properties,
)
# if at least one slice can be yielded, the stack is not empty
stack_empty = False
except (EmptySliceException, CorruptedSlice) as exception:
logger.warning(exception)
if stack_empty:
raise EmptyStackException("all slices are empty")