import contextlib
import datetime
import itertools
import os
import tempfile
import warnings
from typing import Any, List, Tuple, Union
from zipfile import ZipFile
import ee # type: ignore
import joblib # type: ignore
import rasterio # type: ignore
import requests
import rioxarray # type: ignore
import xarray as xr
from requests.adapters import HTTPAdapter
from tqdm.auto import tqdm # type: ignore
from urllib3.util.retry import Retry
def Initialize(**kwargs: Any) -> None:
    """Initialize Earth Engine using the high-volume endpoint designed for automated requests.

    Parameters
    ----------
    kwargs : Any
        Additional keyword arguments passed to ee.Initialize().
    """
    # The high-volume endpoint is recommended by Google for programmatic,
    # concurrent request patterns like batch downloading.
    ee.Initialize(opt_url="https://earthengine-highvolume.googleapis.com", **kwargs)
def _set_nodata(file: str, nodata: Union[float, int]) -> None:
    """Write a nodata value into a raster file's metadata, modifying it in place.

    Parameters
    ----------
    file : str
        The path to the raster file to set.
    nodata : Union[float, int]
        The value to set as nodata.
    """
    # Open read-write so only the metadata is rewritten, not the pixel data.
    with rasterio.open(file, "r+") as raster:
        raster.nodata = nodata
def _flatten_list(a: List[Any]) -> List[Any]:
"""Flatten a nested list."""
return list(itertools.chain.from_iterable(a))
def _unpack_file(file: str, out_dir: str) -> List[str]:
"""Unpack a ZIP file to a directory.
Parameters
----------
file : str
The path to a ZIP file.
out_dir : str
The path to a directory to unpack files within.
Returns
-------
List[str]
Paths to the unpacked files.
"""
unzipped = []
with ZipFile(file, "r") as zipped:
unzipped += zipped.namelist()
zipped.extractall(out_dir)
return [os.path.join(out_dir, file) for file in unzipped]
def _download_url(url: str, out_dir: str, progress: bool, max_attempts: int) -> str:
    """Download a file from a URL to a specified directory.

    Parameters
    ----------
    url : str
        The URL address of the element to download.
    out_dir : str
        The directory path to save the temporary file to.
    progress : bool
        If true, a progress bar will be displayed to track download progress.
    max_attempts : int
        The maximum number of times to retry a connection.

    Returns
    -------
    str
        The path to the downloaded file.
    """
    # mkstemp + close avoids the leaked open handle that
    # NamedTemporaryFile(delete=False).name left behind.
    fd, filename = tempfile.mkstemp(dir=out_dir)
    os.close(fd)
    try:
        # Close the response (and its connection) deterministically, even when
        # streaming is interrupted partway through.
        with _create_retry_session(max_attempts).get(url, stream=True) as r:
            r.raise_for_status()
            file_size = int(r.headers.get("content-length", 0))
            with open(filename, "w+b") as dst, tqdm(
                total=file_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
                desc="Downloading",
                disable=not progress,
            ) as bar:
                for data in r.iter_content(chunk_size=1024):
                    size = dst.write(data)
                    bar.update(size)
    except Exception:
        # Delete the tempfile if the download failed at any stage, not only on
        # a bad HTTP status.
        os.remove(filename)
        raise
    return filename
def _create_retry_session(max_attempts: int) -> requests.Session:
    """Build a requests session that automatically retries failed connections.

    https://www.peterbe.com/plog/best-practice-with-retries-with-requests
    """
    retry_policy = Retry(
        total=max_attempts,
        read=max_attempts,
        connect=max_attempts,
        backoff_factor=0.1,
    )
    session = requests.Session()
    # Apply the same retry policy to both plain and TLS connections.
    for scheme in ("http://", "https://"):
        session.mount(scheme, HTTPAdapter(max_retries=retry_policy))
    return session
def _dataset_from_files(files: List[str], masked: bool, nodata: int) -> xr.Dataset:
    """Build an xarray.Dataset by merging DataArrays loaded from a list of raster files."""
    arrays = [_dataarray_from_file(f, masked, nodata) for f in files]
    try:
        # Prefer a lossless merge: conflicting values are only tolerated when
        # one side is null, in which case the non-null value wins.
        merged = xr.merge(arrays, compat="no_conflicts")
    except xr.core.merge.MergeError:
        # Genuine (non-null) conflicts: keep the first value and tell the user.
        merged = xr.merge(arrays, compat="override")
        warnings.warn(
            "Different non-null values were encountered for the same variable at the same time coordinate. The first value was taken."
        )
    return merged
def _dataarray_from_file(file: str, masked: bool, nodata: int) -> xr.DataArray:
    """Create an xarray.DataArray from a single file by parsing datetimes and variables from the file name.

    The file name must follow the format "{dimension}.{coordinate}.{variable}.{extension}".
    """
    dim, coord, var = _parse_filename(file)
    with rioxarray.open_rasterio(file) as raster:
        # Read eagerly instead of lazily so the source tempfile can be deleted
        # after the function returns. See https://github.com/corteva/rioxarray/issues/485
        # and https://github.com/aazuspan/wxee/issues/70.
        raster.load()
        arr = (
            raster.expand_dims({dim: [coord]})
            .rename(var)
            .squeeze("band")
            .drop_vars("band")
        )
        if masked:
            # Masking nodata promotes integer datasets to float.
            arr = arr.where(arr != nodata)
        return arr
def _parse_filename(file: str) -> Tuple[str, Union[str, int, datetime.datetime], str]:
    """Parse the dimension, coordinate, and variable from a filename following the format
    {id}.{dimension}.{coordinate}.{variable}.{extension}. Return as a tuple.
    """
    coord: Union[str, int, datetime.datetime]
    # Fields 1-3 of the dot-separated basename carry dimension, coordinate, variable.
    dim, raw_coord, variable = os.path.basename(file).split(".")[1:4]
    coord = _parse_time(raw_coord) if dim == "time" else int(raw_coord)
    return (dim, coord, variable)
def _parse_time(time: str) -> Union[datetime.datetime, str]:
"""Parse a time string as it is exported from Earth Engine and return as a datetime.
If the time cannot be parsed, it is returned as a string.
"""
try:
return datetime.datetime.strptime(time, "%Y%m%dT%H%M%S")
except ValueError:
warnings.warn(
f"The time coordinate '{time}' could not be parsed into a valid datetime. Setting as raw value instead."
)
return time
def _millis_to_datetime(millis: str) -> datetime.datetime:
"""Convert a timestamp in UTC milliseconds (e.g. from Earth Engine) to a datetime object."""
return datetime.datetime.utcfromtimestamp(int(millis) / 1000.0)
def _replace_if_null(val: Union[ee.String, ee.Number], replacement: Any) -> Any:
    """Return the original Earth Engine object if it is non-null, otherwise the replacement.

    The null check is evaluated server-side via ee.Algorithms.If, not in Python.
    """
    return ee.Algorithms.If(val, val, replacement)
def _format_date(d: ee.Date) -> ee.String:
    """Render an Earth Engine date as a string using the module's consistent pattern.

    The pattern matches the one parsed by _parse_time ("%Y%m%dT%H%M%S").
    """
    return ee.Date(d).format("yyyyMMdd'T'HHmmss")
@contextlib.contextmanager
def parallel_tqdm(tqdm_object: tqdm) -> tqdm:
    """Context manager to patch joblib to report into tqdm progress bar given as argument

    While the context is active, joblib's batch-completion callback is replaced
    globally (module-wide), so every joblib.Parallel run inside the context
    updates the given progress bar. The original callback is restored and the
    bar is closed on exit, even if the body raises.

    Reference
    ---------
    https://stackoverflow.com/questions/24983493/tracking-progress-of-joblib-parallel-execution

    Example
    -------
    >>> with Parallel(n_jobs=-1) as p:
    >>>     with parallel_tqdm(tqdm(desc="Progress", total=10)):
    >>>         urls = p(delayed(f)(x) for x in range(10))
    """

    # Subclass the joblib callback so each completed batch advances the bar
    # by the number of tasks in that batch before delegating to joblib.
    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)

        def __call__(self, *args: Any, **kwargs: Any) -> None:
            # batch_size is set by joblib on the callback instance.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    # Save the original callback BEFORE patching so it can be restored exactly.
    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Restore joblib's callback first, then close the bar, regardless of errors.
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()
def _normalize(x: ee.Number, minx: ee.Number, maxx: ee.Number) -> ee.Number:
    """Linearly rescale x relative to the range [minx, maxx], server-side: (x - minx) / (maxx - minx)."""
    offset = ee.Number(x).subtract(minx)
    span = ee.Number(maxx).subtract(minx)
    return offset.divide(span)