from __future__ import annotations
import logging
from typing import Any, List, Literal, Optional, Set
import numpy as np
import numpy.ma as ma
from pystac import Item
import xarray as xr
from mapchete import Timer
from mapchete.io.raster import ReferencedRaster
from mapchete.path import MPath, MPathLike
from mapchete.protocols import GridProtocol
from mapchete.types import Bounds, NodataVals
from numpy.typing import DTypeLike
from rasterio.enums import Resampling
from shapely.geometry import shape
from mapchete_eo.array.convert import to_dataarray
from mapchete_eo.io import get_item_property, item_to_np_array
from mapchete_eo.protocols import EOProductProtocol
from mapchete_eo.settings import mapchete_eo_settings
from mapchete_eo.types import BandLocation
logger = logging.getLogger(__name__)
[docs]
class EOProduct(EOProductProtocol):
"""
Wrapper class around a STAC Item which provides data reading capabilities.
"""
id: str
default_dtype: DTypeLike = np.uint16
_item: Optional[Item] = None
def __init__(self, item: Item):
self.item_dict = item.to_dict()
self.__geo_interface__ = self.item.geometry
self.bounds = Bounds.from_inp(shape(self))
self.crs = mapchete_eo_settings.default_catalog_crs
self._item = None
self.id = item.id
def __repr__(self):
return f"<EOProduct product_id={self.item.id}>"
[docs]
def clear_cached_data(self):
pass
@property
def item(self) -> Item:
if not self._item:
self._item = Item.from_dict(self.item_dict)
return self._item
[docs]
@classmethod
def from_stac_item(self, item: Item, **kwargs) -> EOProduct:
return EOProduct(item)
[docs]
def get_mask(self) -> ReferencedRaster: ...
[docs]
def read(
self,
assets: Optional[List[str]] = None,
eo_bands: Optional[List[str]] = None,
grid: Optional[GridProtocol] = None,
resampling: Resampling = Resampling.nearest,
nodatavals: NodataVals = None,
x_axis_name: str = "x",
y_axis_name: str = "y",
raise_empty: bool = True,
**kwargs,
) -> xr.Dataset:
"""
Read bands and assets into an xarray.Dataset.
Args:
assets: List of asset names.
eo_bands: List of EO band names.
grid: Target grid protocol.
resampling: Resampling algorithm.
nodatavals: Custom nodata values.
x_axis_name: Name of X axis in output.
y_axis_name: Name of Y axis in output.
raise_empty: Raise exception if no data is found.
Returns:
xr.Dataset: Dataset with assets as data variables.
"""
# developer info: all fancy stuff for special platforms like Sentinel-2
# should be implemented in the respective read_np_array() methods which get
# called by this method. No need to apply masks etc. here too.
if isinstance(nodatavals, list):
nodataval = nodatavals[0]
elif isinstance(nodatavals, float):
nodataval = nodatavals
else:
nodataval = nodatavals
assets = assets or []
eo_bands = eo_bands or []
data_var_names = assets or eo_bands
return xr.Dataset(
data_vars={
data_var_name: to_dataarray(
asset_arr,
x_axis_name=x_axis_name,
y_axis_name=y_axis_name,
name=data_var_name,
attrs=dict(item_id=self.item.id),
)
for asset_arr, data_var_name in zip(
self.read_np_array(
assets=assets,
eo_bands=eo_bands,
grid=grid,
resampling=resampling,
nodatavals=nodatavals,
raise_empty=raise_empty,
**kwargs,
),
data_var_names,
)
},
coords={},
attrs=dict(self.item.properties, id=self.item.id, _FillValue=nodataval),
)
[docs]
def read_np_array(
self,
assets: Optional[List[str]] = None,
eo_bands: Optional[List[str]] = None,
grid: Optional[GridProtocol] = None,
resampling: Resampling = Resampling.nearest,
nodatavals: NodataVals = None,
raise_empty: bool = True,
apply_offset: bool = True,
**kwargs,
) -> ma.MaskedArray:
"""
Read assets or EO bands into a MaskedArray.
Args:
assets: List of asset names.
eo_bands: List of EO band names.
grid: Target grid.
resampling: Resampling method.
nodatavals: Nodata values.
raise_empty: Raise if empty.
apply_offset: Apply offset/scale metadata if present.
Returns:
ma.MaskedArray: Output array.
"""
assets = assets or []
eo_bands = eo_bands or []
bands = assets or eo_bands
logger.debug("%s: reading assets %s over %s", self, bands, grid)
with Timer() as t:
out = item_to_np_array(
self.item,
self.assets_eo_bands_to_band_locations(assets, eo_bands),
grid=grid,
resampling=resampling,
nodatavals=nodatavals,
raise_empty=raise_empty,
apply_offset=apply_offset,
)
logger.debug("%s: read in %s", self, t)
return out
[docs]
def empty_array(
self,
count: int,
grid: GridProtocol,
fill_value: int = 0,
dtype: Optional[DTypeLike] = None,
) -> ma.MaskedArray:
shape = (count, *grid.shape)
dtype = dtype or self.default_dtype
return ma.MaskedArray(
data=np.full(shape, fill_value=fill_value, dtype=dtype),
mask=np.ones(shape, dtype=bool),
fill_value=fill_value,
)
[docs]
def get_property(self, property: str) -> Any:
return get_item_property(self.item, property)
[docs]
def eo_bands_to_band_location(self, eo_bands: List[str]) -> List[BandLocation]:
return eo_bands_to_band_locations(self.item, eo_bands)
[docs]
def assets_eo_bands_to_band_locations(
self,
assets: Optional[List[str]] = None,
eo_bands: Optional[List[str]] = None,
) -> List[BandLocation]:
assets = assets or []
eo_bands = eo_bands or []
if assets and eo_bands:
raise ValueError("assets and eo_bands cannot be provided at the same time")
if assets:
return [BandLocation(asset_name=asset) for asset in assets]
elif eo_bands:
return self.eo_bands_to_band_location(eo_bands)
else:
raise ValueError("assets or eo_bands have to be provided")
[docs]
def eo_bands_to_band_locations(
item: Item,
eo_bands: List[str],
role: Literal["data", "reflectance", "visual"] = "data",
) -> List[BandLocation]:
"""
Map EO band names to asset locations.
Args:
item: STAC Item.
eo_bands: List of common band names.
role: Functional role of the assets.
Returns:
List[BandLocation]: List of location objects.
"""
return [find_eo_band(item, eo_band, role=role) for eo_band in eo_bands]
[docs]
def find_eo_band(
item: Item,
eo_band_name: str,
role: Literal["data", "reflectance", "visual"] = "data",
) -> BandLocation:
"""
Tries to find the location of the most appropriate band using the EO band name.
This function looks into all assets and all eo bands for the given name and role.
"""
results = []
for asset_name, asset in item.assets.items():
# search in eo:bands and alternatively in bands for eo:common_name
for band_index, band_info in enumerate(
asset.extra_fields.get("eo:bands", asset.extra_fields.get("bands", [])), 1
):
if (
# if name matches eo band name
(
eo_band_name == band_info.get("name")
or eo_band_name == band_info.get("eo:common_name")
)
# if role is given, make sure it matches with desired role
and (asset.roles is None or role in asset.roles)
):
results.append(
BandLocation.from_asset(
name=asset_name,
band_index=band_index,
asset=asset,
)
)
if len(results) == 0:
raise KeyError(f"EO band {eo_band_name} not found in item assets")
elif len(results) == 1:
return results[0]
# if results are ambiguous, further filter them
else:
# only use locations which seem to have the original resolution
for matches in [_asset_name_equals_eo_name, _is_original_sampling]:
filtered_results = [
band_location for band_location in results if matches(band_location)
]
if len(filtered_results) == 1:
return filtered_results[0]
else: # pragma: no cover
raise ValueError(
f"EO band '{eo_band_name}' found in multiple assets: {', '.join(map(str, results))}"
)
def _asset_name_equals_eo_name(band_location: BandLocation) -> bool:
return band_location.asset_name == band_location.eo_band_name
def _is_original_sampling(band_location: BandLocation) -> bool:
return band_location.roles == [] or "sampling:original" in band_location.roles
[docs]
def add_to_blacklist(path: MPathLike, blacklist: Optional[MPath] = None) -> None:
blacklist = blacklist or mapchete_eo_settings.blacklist
if blacklist is None:
return
blacklist = MPath.from_inp(blacklist)
path = MPath.from_inp(path)
# make sure paths stay unique
if str(path) not in blacklist_products(blacklist):
logger.debug("add path %s to blacklist", str(path))
try:
with blacklist.open("a") as dst:
dst.write(f"{path}\n")
except FileNotFoundError:
with blacklist.open("w") as dst:
dst.write(f"{path}\n")
[docs]
def blacklist_products(blacklist: Optional[MPathLike] = None) -> Set[str]:
blacklist = blacklist or mapchete_eo_settings.blacklist
if blacklist is None:
raise ValueError("no blacklist is defined")
blacklist = MPath.from_inp(blacklist)
try:
return set(blacklist.read_text().splitlines())
except FileNotFoundError:
logger.debug("%s does not exist, returning empty set", str(blacklist))
return set()