# Source code for wxee.collection

import tempfile
import warnings
from typing import List, Optional

import ee  # type: ignore
import xarray as xr
from joblib import Parallel, delayed  # type: ignore
from tqdm.auto import tqdm  # type: ignore

from wxee import constants
from wxee.accessors import wx_accessor
from wxee.time_series import TimeSeries
from wxee.utils import _dataset_from_files, _flatten_list, parallel_tqdm


@wx_accessor(ee.imagecollection.ImageCollection)
class ImageCollection:
    """Extension methods for :code:`ee.ImageCollection`.

    The class is registered through :code:`wx_accessor`; the methods below are
    reached through the :code:`wx` attribute of a collection (for example
    :code:`col.wx.to_xarray()`).
    """

    def __init__(self, obj: ee.imagecollection.ImageCollection):
        """
        Parameters
        ----------
        obj : ee.ImageCollection
            The Image Collection instance extended by this class.
        """
        self._obj = obj

    def _to_image_list(self) -> List[ee.Image]:
        """Convert an image collection to a Python list of images."""
        # Build the size and list objects once instead of once per element.
        size = self._obj.size()
        img_list = self._obj.toList(size)

        # `getInfo` resolves the collection size client-side so that it can
        # drive a Python range.
        return [ee.Image(img_list.get(i)) for i in range(size.getInfo())]

    def get_image(self, index: int) -> ee.Image:
        """Return the image at the specified index in the collection. A negative
        index counts backwards from the end of the collection.

        Parameters
        ----------
        index : int
            The index of the image in the collection.

        Returns
        -------
        ee.Image
            The image at the given index.
        """
        return ee.Image(self._obj.toList(self._obj.size()).get(index))

    def last(self) -> ee.Image:
        """Return the last image in the collection.

        Returns
        -------
        ee.Image
            The last image in the collection.
        """
        # NOTE: the index passed here is a server-side number (size - 1), not a
        # Python int; it is forwarded unchanged to the list lookup.
        return self.get_image(self._obj.size().subtract(1))

    def to_xarray(
        self,
        path: Optional[str] = None,
        region: Optional[ee.Geometry] = None,
        scale: Optional[int] = None,
        crs: str = "EPSG:4326",
        masked: bool = True,
        nodata: int = -32_768,
        num_cores: int = -1,
        progress: bool = True,
        max_attempts: int = 10,
    ) -> xr.Dataset:
        """Convert an image collection to an xarray.Dataset. The :code:`system:time_start` property of each image in the
        collection is used to arrange the time dimension, and each image variable is loaded as a separate array in the
        dataset.

        Parameters
        ----------
        path : str, optional
            Deprecated. If given, a deprecation warning is issued and the
            dataset is also written to this path with
            :code:`xarray.Dataset.to_netcdf`.
        region : ee.Geometry, optional
            The region to download the images within. If none is provided, the :code:`geometry` of the image collection
            will be used. If geometry varies between images in the collection, the region will encompass all images
            which may lead to very large arrays and download limits.
        scale : int, optional
            The scale to download the array at in the CRS units. If none is provided, the :code:`projection.nominalScale`
            of the images will be used.
        crs : str, default "EPSG:4326"
            The coordinate reference system to download the array in.
        masked : bool, default True
            If true, nodata pixels in the array will be masked by replacing them with numpy.nan. This will silently
            cast integer datatypes to float.
        nodata : int, default -32,768
            The value to set as nodata in the array. Any masked pixels will be filled with this value.
        num_cores : int, default -1
            The number of CPU cores to use for parallel operations. Defaults to -1 which will use all available cores.
        progress : bool, default True
            If true, a progress bar will be displayed to track download progress.
        max_attempts: int, default 10
            Download requests to Earth Engine may intermittently fail. Failed attempts will be retried up to
            max_attempts. Must be between 1 and 99.

        Returns
        -------
        xarray.Dataset
            A dataset containing all images in the collection with an assigned time dimension and variables set from
            each image.

        Raises
        ------
        DownloadError
            Raised if the image cannot be successfully downloaded after the maximum number of attempts.

        Examples
        --------
        >>> col = ee.ImageCollection("IDAHO_EPSCOR/GRIDMET").filterDate("2020-09-08", "2020-09-15")
        >>> col.wx.to_xarray(scale=40000, crs="EPSG:5070", nodata=-9999)
        """
        # The GeoTIFFs only live in a temporary directory, so the dataset is
        # built from them before the directory is removed.
        with tempfile.TemporaryDirectory(prefix=constants.TMP_PREFIX) as tmp:
            files = self._obj.wx.to_tif(
                out_dir=tmp,
                region=region,
                scale=scale,
                crs=crs,
                file_per_band=True,
                masked=masked,
                nodata=nodata,
                num_cores=num_cores,
                progress=progress,
                max_attempts=max_attempts,
            )

            ds = _dataset_from_files(files, masked, nodata)

        if path:
            msg = (
                "The path argument is deprecated and will be removed in a future "
                "release. Use the `xarray.Dataset.to_netcdf` method instead."
            )
            # stacklevel=2 attributes the warning to the caller's line rather
            # than to this module, so it is not hidden by the default filters.
            warnings.warn(category=DeprecationWarning, message=msg, stacklevel=2)
            ds.to_netcdf(path, mode="w")

        return ds

    def to_tif(
        self,
        out_dir: str = ".",
        prefix: Optional[str] = None,
        region: Optional[ee.Geometry] = None,
        scale: Optional[int] = None,
        crs: str = "EPSG:4326",
        file_per_band: bool = False,
        masked: bool = True,
        nodata: int = -32_768,
        num_cores: int = -1,
        progress: bool = True,
        max_attempts: int = 10,
    ) -> List[str]:
        """Download all images in the collection to geoTIFF. Image file names will be the :code:`system:id` of each image
        after replacing invalid characters with underscores, with an optional user-defined prefix.

        Parameters
        ----------
        out_dir : str, default "."
            The directory to save the images to.
        prefix : str, optional
            A description to prefix to all image file names. If none is provided, no prefix will be added.
        region : ee.Geometry, optional
            The region to download the image within. If none is provided, the :code:`geometry` of each image will be
            used.
        scale : int, optional
            The scale to download each image at in the CRS units. If none is provided, the
            :code:`projection.nominalScale` of each image will be used.
        crs : str, default "EPSG:4326"
            The coordinate reference system to download each image in.
        file_per_band : bool, default False
            If true, one file will be downloaded per band per image. If false, one multiband file will be downloaded
            per image instead.
        masked : bool, default True
            If true, the nodata value of each image will be set in the image metadata.
        nodata : int, default -32,768
            The value to set as nodata in each image. Any masked pixels in the images will be filled with this value.
        num_cores : int, default -1
            The number of CPU cores to use for parallel operations. Defaults to -1 which will use all available cores.
        progress : bool, default True
            If true, a progress bar will be displayed to track download progress.
        max_attempts: int, default 10
            Download requests to Earth Engine may intermittently fail. Failed attempts will be retried up to
            max_attempts. Must be between 1 and 99.

        Returns
        -------
        list[str]
            Paths to downloaded images.

        Raises
        ------
        DownloadError
            Raised if the image cannot be successfully downloaded after the maximum number of attempts.

        Example
        -------
        >>> col = ee.ImageCollection("IDAHO_EPSCOR/GRIDMET").filterDate("2020-09-08", "2020-09-15")
        >>> col.wx.to_tif(scale=40000, crs="EPSG:5070", nodata=-9999)
        """
        # Work on a local collection so that applying a prefix never mutates
        # the wrapped collection; otherwise repeated calls on the same accessor
        # would stack prefixes.
        collection = self._obj
        if prefix:
            collection = collection.map(lambda img: img.wx._prefix_id(prefix))

        imgs = type(self)(collection)._to_image_list()
        n = len(imgs)

        # One thread pool is shared by both phases: first request a download
        # URL per image, then fetch each URL to disk.
        with Parallel(n_jobs=num_cores, backend="threading") as p:
            with parallel_tqdm(
                tqdm(desc="Requesting data", total=n, disable=not progress)
            ):
                urls = p(
                    delayed(img.wx._get_url)(
                        region, scale, crs, file_per_band, nodata, max_attempts
                    )
                    for img in imgs
                )

            with parallel_tqdm(
                tqdm(desc="Downloading data", total=n, disable=not progress)
            ):
                img_urls = zip(imgs, urls)
                tifs = p(
                    delayed(img.wx._url_to_tif)(
                        url, out_dir, file_per_band, masked, nodata, False, max_attempts
                    )
                    for img, url in img_urls
                )

        # Each download result is flattened into a single list of file paths.
        return _flatten_list(tifs)

    def to_time_series(self) -> TimeSeries:
        """Convert to a :code:`wxee.TimeSeries` collection with associated methods.

        Returns
        -------
        wxee.TimeSeries
            The collection as a TimeSeries object.

        Examples
        --------
        >>> col = ee.ImageCollection("IDAHO_EPSCOR/GRIDMET")
        >>> ts = col.wx.to_time_series()
        """
        return TimeSeries(self._obj)