from __future__ import annotations
import warnings
import logging
from functools import cached_property
from typing import Any, Callable, List, Optional, Sequence, Type, Union, Dict, Generator
import croniter
from mapchete import Bounds
import numpy as np
import numpy.ma as ma
import xarray as xr
from dateutil.tz import tzutc
from mapchete.config.parse import guess_geometry
from mapchete.formats import base
from mapchete.geometry import reproject_geometry
from mapchete.io.vector import IndexedFeatures
from mapchete.path import MPath
from mapchete.tile import BufferedTile
from mapchete.types import MPathLike, NodataVal, NodataVals
from pydantic import BaseModel, model_validator
from pystac import Item
from rasterio.enums import Resampling
from rasterio.features import geometry_mask
from shapely.geometry import mapping
from shapely.geometry.base import BaseGeometry
from mapchete_eo.exceptions import CorruptedProductMetadata, PreprocessingNotFinished
from mapchete_eo.io import (
products_to_np_array,
products_to_xarray,
read_levelled_cube_to_np_array,
read_levelled_cube_to_xarray,
)
from mapchete_eo.source import Source
from mapchete_eo.product import EOProduct
from mapchete_eo.protocols import EOProductProtocol
from mapchete_eo.settings import mapchete_eo_settings
from mapchete_eo.sort import SortMethodConfig, TargetDateSort
from mapchete_eo.time import to_datetime
from mapchete_eo.types import DateTimeLike, MergeMethod, TimeRange
logger = logging.getLogger(__name__)
class BaseDriverConfig(BaseModel):
    """
    Configuration for mapchete-eo drivers.

    Parsed by pydantic from the mapchete process configuration. ``source``
    is always normalized to a list, and the deprecated ``cat_baseurl``
    field is translated into a static-catalog source entry.
    """

    format: str
    source: Sequence[Source]
    time: Optional[Union[TimeRange, List[TimeRange]]] = None
    # deprecated; see deprecate_cat_baseurl() below
    cat_baseurl: Optional[str] = None
    cache: Optional[Any] = None
    footprint_buffer: float = 0
    # NOTE(review): type[BaseGeometry] accepts the geometry *class*, not an
    # instance; if geometry objects are expected here the annotation should
    # probably be BaseGeometry — confirm against callers before changing,
    # since pydantic validates against this annotation at runtime.
    area: Optional[Union[MPathLike, dict, type[BaseGeometry]]] = None
    preprocessing_tasks: bool = False
    search_kwargs: Optional[Dict[str, Any]] = None

    @model_validator(mode="before")
    @classmethod
    def to_list(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Expands source to list."""
        for field in ["source"]:
            value = values.get(field)
            if value is not None and not isinstance(value, list):
                values[field] = [value]
        return values

    @model_validator(mode="before")
    @classmethod
    def deprecate_cat_baseurl(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Translate deprecated 'cat_baseurl' into a static-catalog source.

        Raises:
            ValueError: if 'cat_baseurl' is given alongside explicit sources.
        """
        cat_baseurl = values.get("cat_baseurl")
        if cat_baseurl:  # pragma: no cover
            warnings.warn(
                "'cat_baseurl' will be deprecated soon. Please use 'catalog_type=static' in the source.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            if values.get("source", []):
                raise ValueError(
                    "deprecated cat_baseurl field found alongside sources."
                )
            values["source"] = [dict(collection=cat_baseurl, catalog_type="static")]
        return values
class EODataCube(base.InputTile):
    """Target Tile representation of input data."""

    # fallback values applied by default_read_values() when the caller does
    # not pass the respective argument explicitly
    default_read_merge_method: MergeMethod = MergeMethod.first
    default_read_merge_products_by: Optional[str] = None
    default_read_nodataval: NodataVal = None
    default_read_resampling: Resampling = Resampling.nearest
    tile: BufferedTile
    eo_bands: dict
    time: Optional[List[TimeRange]]
    area: BaseGeometry
    # extra buffer (in pixels) applied to the area in get_read_mask()
    area_pixelbuffer: int = 0

    def __init__(
        self,
        tile: BufferedTile,
        products: Optional[List[EOProductProtocol]],
        eo_bands: dict,
        time: Optional[List[TimeRange]] = None,
        input_key: Optional[str] = None,
        area: Optional[BaseGeometry] = None,
        **kwargs,
    ) -> None:
        """Initialize.

        Args:
            tile: process tile the data is read into.
            products: EO products, or None if they will be provided later
                via preprocessing task results (task graph processing).
            eo_bands: mapping describing the available EO bands.
            time: time ranges covered by this input.
            input_key: mapchete input key of this datacube.
            area: reading area; defaults to the tile bounding box.
        """
        self.tile = tile
        self._products = products
        self.eo_bands = eo_bands
        self.time = time
        self.input_key = input_key
        self.area = tile.bbox if area is None else area

    def _indexed_products(self, items) -> IndexedFeatures[EOProductProtocol]:
        """Wrap items in IndexedFeatures, skipping corrupted products."""
        return IndexedFeatures(
            [
                item
                for item in items
                if not isinstance(item, CorruptedProductMetadata)
            ],
            crs=self.tile.crs,
            # by not using rtree, we avoid an edge case where products outside of process CRS bounds
            # cause rtree to fail when indexing the products.
            index=None,
        )

    @cached_property
    def products(self) -> IndexedFeatures[EOProductProtocol]:
        """
        Indexed products.
        """
        # during task graph processing, the products have to be fetched as preprocessing task results
        if self._products is None:  # pragma: no cover
            return self._indexed_products(self.preprocessing_tasks_results.values())
        # just return the products as is
        return self._indexed_products(self._products)

    def read(
        self,
        assets: Optional[List[str]] = None,
        eo_bands: Optional[List[str]] = None,
        start_time: Optional[DateTimeLike] = None,
        end_time: Optional[DateTimeLike] = None,
        timestamps: Optional[List[DateTimeLike]] = None,
        time_pattern: Optional[str] = None,
        resampling: Optional[Union[Resampling, str]] = None,
        merge_products_by: Optional[str] = None,
        merge_method: Optional[MergeMethod] = None,
        sort: Optional[SortMethodConfig] = None,
        nodatavals: NodataVals = None,
        raise_empty: bool = True,
        **kwargs,
    ) -> xr.Dataset:
        """
        Read input data into an xarray.Dataset.

        Args:
            assets: List of asset names.
            eo_bands: List of EO bands.
            start_time: Start time.
            end_time: End time.
            timestamps: List of timestamps.
            time_pattern: Time pattern (cron-style) to filter products by.
            resampling: Resampling method.
            merge_products_by: Property to merge by.
            merge_method: Merge method.
            sort: Sorting configuration.
            nodatavals: Nodata values.
            raise_empty: Raise error if no data found.

        Returns:
            xr.Dataset: Output dataset.
        """
        return products_to_xarray(
            products=self.filter_products(
                start_time=start_time,
                end_time=end_time,
                timestamps=timestamps,
                time_pattern=time_pattern,
            ),
            eo_bands=eo_bands,
            assets=assets,
            grid=self.tile,
            raise_empty=raise_empty,
            product_read_kwargs=kwargs,
            sort=sort,
            **self.default_read_values(
                merge_products_by=merge_products_by,
                merge_method=merge_method,
                resampling=resampling,
                nodatavals=nodatavals,
            ),
        )

    def read_np_array(
        self,
        assets: Optional[List[str]] = None,
        eo_bands: Optional[List[str]] = None,
        start_time: Optional[DateTimeLike] = None,
        end_time: Optional[DateTimeLike] = None,
        timestamps: Optional[List[DateTimeLike]] = None,
        time_pattern: Optional[str] = None,
        resampling: Optional[Union[Resampling, str]] = None,
        merge_products_by: Optional[str] = None,
        merge_method: Optional[MergeMethod] = None,
        sort: Optional[SortMethodConfig] = None,
        nodatavals: NodataVals = None,
        raise_empty: bool = True,
        **kwargs,
    ) -> ma.MaskedArray:
        """
        Read input data as a MaskedArray.

        Args mirror read(); see there for details.

        Returns:
            ma.MaskedArray: Output data array.
        """
        return products_to_np_array(
            products=self.filter_products(
                start_time=start_time,
                end_time=end_time,
                timestamps=timestamps,
                time_pattern=time_pattern,
            ),
            eo_bands=eo_bands,
            assets=assets,
            grid=self.tile,
            product_read_kwargs=kwargs,
            raise_empty=raise_empty,
            sort=sort,
            **self.default_read_values(
                merge_products_by=merge_products_by,
                merge_method=merge_method,
                resampling=resampling,
                nodatavals=nodatavals,
            ),
        )

    def read_levelled(
        self,
        target_height: int,
        assets: Optional[List[str]] = None,
        eo_bands: Optional[List[str]] = None,
        start_time: Optional[DateTimeLike] = None,
        end_time: Optional[DateTimeLike] = None,
        timestamps: Optional[List[DateTimeLike]] = None,
        time_pattern: Optional[str] = None,
        resampling: Optional[Union[Resampling, str]] = None,
        nodatavals: NodataVals = None,
        merge_products_by: Optional[str] = None,
        merge_method: Optional[MergeMethod] = None,
        sort: Optional[SortMethodConfig] = None,
        raise_empty: bool = True,
        slice_axis_name: str = "layers",
        band_axis_name: str = "bands",
        x_axis_name: str = "x",
        y_axis_name: str = "y",
        **kwargs,
    ) -> xr.Dataset:
        """
        Read levelled data (cubes with depth) into an xarray.Dataset.

        Args:
            target_height: Target stack height.
            sort: Sorting configuration; defaults to TargetDateSort().
            slice_axis_name, band_axis_name, x_axis_name, y_axis_name:
                Axis names of the output dataset.

        Remaining args mirror read(); see there for details.

        Returns:
            xr.Dataset: Output dataset.
        """
        # avoid a shared mutable default argument by creating the default
        # sort configuration per call
        sort = TargetDateSort() if sort is None else sort
        return read_levelled_cube_to_xarray(
            products=self.filter_products(
                start_time=start_time,
                end_time=end_time,
                timestamps=timestamps,
                time_pattern=time_pattern,
            ),
            target_height=target_height,
            assets=assets,
            eo_bands=eo_bands,
            grid=self.tile,
            raise_empty=raise_empty,
            product_read_kwargs=kwargs,
            slice_axis_name=slice_axis_name,
            band_axis_name=band_axis_name,
            x_axis_name=x_axis_name,
            y_axis_name=y_axis_name,
            sort=sort,
            **self.default_read_values(
                merge_products_by=merge_products_by,
                merge_method=merge_method,
                resampling=resampling,
                nodatavals=nodatavals,
            ),
        )

    def read_levelled_np_array(
        self,
        target_height: int,
        assets: Optional[List[str]] = None,
        eo_bands: Optional[List[str]] = None,
        start_time: Optional[DateTimeLike] = None,
        end_time: Optional[DateTimeLike] = None,
        timestamps: Optional[List[DateTimeLike]] = None,
        time_pattern: Optional[str] = None,
        resampling: Optional[Union[Resampling, str]] = None,
        nodatavals: NodataVals = None,
        merge_products_by: Optional[str] = None,
        merge_method: Optional[MergeMethod] = None,
        sort: Optional[SortMethodConfig] = None,
        raise_empty: bool = True,
        **kwargs,
    ) -> ma.MaskedArray:
        """
        Read levelled data (cubes with depth) as a MaskedArray.

        Args:
            target_height: Target stack height.
            assets: List of asset names.
            eo_bands: List of EO bands.
            start_time: Start time.
            end_time: End time.
            timestamps: List of timestamps.
            time_pattern: Time pattern.
            resampling: Resampling method.
            nodatavals: Nodata values.
            merge_products_by: Property to merge by.
            merge_method: Merge method.
            sort: Sorting configuration; defaults to TargetDateSort().
            raise_empty: Raise error if no data found.

        Returns:
            ma.MaskedArray: Output data array.
        """
        # avoid a shared mutable default argument by creating the default
        # sort configuration per call
        sort = TargetDateSort() if sort is None else sort
        return read_levelled_cube_to_np_array(
            products=self.filter_products(
                start_time=start_time,
                end_time=end_time,
                timestamps=timestamps,
                time_pattern=time_pattern,
            ),
            target_height=target_height,
            assets=assets,
            eo_bands=eo_bands,
            grid=self.tile,
            raise_empty=raise_empty,
            product_read_kwargs=kwargs,
            sort=sort,
            **self.default_read_values(
                merge_products_by=merge_products_by,
                merge_method=merge_method,
                resampling=resampling,
                nodatavals=nodatavals,
            ),
        )

    def read_masks(
        self,
        start_time: Optional[DateTimeLike] = None,
        end_time: Optional[DateTimeLike] = None,
        timestamps: Optional[List[DateTimeLike]] = None,
        time_pattern: Optional[str] = None,
        nodatavals: NodataVals = None,
        **kwargs,
    ):
        """
        Read product masks.

        Args:
            start_time: Start time.
            end_time: End time.
            timestamps: List of timestamps.
            time_pattern: Time pattern.
            nodatavals: Nodata values.

        Returns:
            ma.MaskedArray: Mask data.
        """
        # imported lazily to avoid a hard dependency on the Sentinel-2
        # platform module at import time
        from mapchete_eo.platforms.sentinel2.masks import read_masks

        return read_masks(
            products=self.filter_products(
                start_time=start_time,
                end_time=end_time,
                timestamps=timestamps,
                time_pattern=time_pattern,
            ),
            grid=self.tile,
            nodatavals=nodatavals,
            product_read_kwargs=kwargs,
        )

    def filter_products(
        self,
        start_time: Optional[DateTimeLike] = None,
        end_time: Optional[DateTimeLike] = None,
        timestamps: Optional[List[DateTimeLike]] = None,
        time_pattern: Optional[str] = None,
    ):
        """
        Return a filtered list of input products.

        Raises:
            NotImplementedError: if start_time, end_time or timestamps are
                given (time subsets are not yet implemented).
        """
        if any([start_time, end_time, timestamps]):  # pragma: no cover
            raise NotImplementedError("time subsets are not yet implemented")
        if time_pattern:
            # filter products by time pattern: keep only products whose item
            # datetime falls exactly on one of the cron occurrences between
            # start and end time (normalized to UTC)
            # NOTE(review): self.start_time / self.end_time are not set in
            # this class; presumably provided by base.InputTile or a
            # subclass — confirm.
            matching_datetimes = [
                t.replace(tzinfo=tzutc())
                for t in croniter.croniter_range(
                    to_datetime(self.start_time),
                    to_datetime(self.end_time),
                    time_pattern,
                )
            ]
            return [
                product
                for product in self.products
                if product.item.datetime in matching_datetimes
            ]
        return self.products

    def is_empty(self) -> bool:  # pragma: no cover
        """
        Check if there is data within this tile.

        Returns
        -------
        is empty : bool
        """
        return len(self.products) == 0

    def default_read_values(
        self,
        resampling: Optional[Union[Resampling, str]] = None,
        nodatavals: NodataVals = None,
        merge_products_by: Optional[str] = None,
        merge_method: Optional[MergeMethod] = None,
    ) -> dict:
        """Provide proper read values depending on user input and defaults."""
        if nodatavals is None:
            nodatavals = self.default_read_nodataval
        merge_products_by = merge_products_by or self.default_read_merge_products_by
        merge_method = merge_method or self.default_read_merge_method
        return dict(
            resampling=(
                self.default_read_resampling
                if resampling is None
                else (
                    resampling
                    if isinstance(resampling, Resampling)
                    # allow passing resampling by name, e.g. "bilinear"
                    else Resampling[resampling]
                )
            ),
            nodatavals=nodatavals,
            merge_products_by=merge_products_by,
            merge_method=merge_method,
            read_mask=self.get_read_mask(),
        )

    def get_read_mask(self) -> np.ndarray:
        """
        Determine read mask according to input area.

        This will generate a numpy array where pixel overlapping the input area
        are set True and thus will get filled by the read function. Pixel outside
        of the area are not considered for reading.

        On staged reading, i.e. first checking the product masks to assess valid
        pixels, this will avoid reading product bands in cases the product only covers
        pixels outside of the intended reading area.
        """
        # buffer is given in pixels, so scale it by the tile pixel size
        area = self.area.buffer(self.area_pixelbuffer * self.tile.pixel_x_size)
        if area.is_empty:
            # nothing to read at all
            return np.zeros((self.tile.shape), dtype=bool)
        return geometry_mask(
            geometries=[mapping(area)],
            out_shape=self.tile.shape,
            transform=self.tile.transform,
            invert=True,
        )