"""
Author: Michal Cuadrat-Grzybowski
Date: April 2025
License: Apache License 2.0
Institution: Delft University of Technology

Description:
This module provides utility functions for generating spatial masks for hydrological basins
based on latitude and longitude grids or xarray Datasets. Basin geometries are loaded from
a preprocessed `.npy` file containing shapely polygon definitions.

Functions:
- load_basins_from_npy: Loads the dictionary of basin geometries from a `.npy` file.
- create_basin_mask_grid: Creates a NumPy mask array for a given basin name.
- basin_bounding_box: Computes a basin's bounding box, extended by a margin of N degrees.
"""
# required packages
import numpy as np
from shapely import MultiPolygon, wkt
import os
from typing import Union, Tuple, Optional, Callable
from shapely.geometry import Point, shape

# Locate this module on disk so bundled data files (e.g. the basin .npy file)
# can be found regardless of the current working directory.
absolute_path = os.path.abspath(__file__)
directory_path = os.path.split(absolute_path)[0]


def load_basins_from_npy(npy_path: Optional[Union[str, os.PathLike]] = None) -> dict:
    """
    Loads a dictionary of basins from an .npy file.

    :param npy_path: str or os.PathLike, optional
        The file path of the .npy file to load. Default is None, which loads
        ``all_basins.npy`` from the directory containing this module.
    :return: dict
        Dictionary where the key is the basin name (str) and the value is a MultiPolygon geometry.
    """
    if npy_path is None:
        # Build the path with the platform's separator; a hard-coded backslash
        # only yields a valid path on Windows.
        npy_path = os.path.join(directory_path, "all_basins.npy")
    # The file holds a pickled object array (hence allow_pickle=True and .item()).
    # Unpickling can execute arbitrary code, so only load trusted files.
    return np.load(npy_path, allow_pickle=True).item()


def create_basin_mask_grid(
    basin_name: str,
    latitudes: np.ndarray,
    longitudes: np.ndarray,
    basins_path: str = None
) -> Tuple[np.ndarray, object]:
    """
    Provides a basin mask for a grid of points using shapely's vectorized point-in-polygon test.

    :param basin_name: str
        The name of the basin (its key in the loaded basin dictionary).
    :param latitudes: np.ndarray
        Array of latitude values (typically 2D); any array-like accepted by ``np.asarray`` works.
    :param longitudes: np.ndarray
        Array of longitude values with the same shape as ``latitudes``.
    :param basins_path: str
        The file path to the saved .npy file containing basin geometries. Default is None,
        which uses the default file resolved by ``load_basins_from_npy``.
    :return: Tuple[np.ndarray, MultiPolygon]
        Basin mask with the same shape as ``latitudes`` (1 for points strictly inside the
        basin, np.nan for all others, including points on the boundary) and the basin polygon.
    :raises ValueError:
        If ``basin_name`` is not present in the basin dictionary, or if the latitude and
        longitude arrays have different shapes.
    """
    # contains_xy requires shapely >= 2.0, which the module's top-level
    # ``from shapely import MultiPolygon`` already requires.
    from shapely import contains_xy

    all_created_basins = load_basins_from_npy(basins_path)
    basin_polygon = all_created_basins.get(basin_name)

    if basin_polygon is None:
        raise ValueError(f"Basin '{basin_name}' not found in the dataset.")

    latitudes = np.asarray(latitudes)
    longitudes = np.asarray(longitudes)

    # Ensure longitudes and latitudes have the same shape
    if longitudes.shape != latitudes.shape:
        raise ValueError("Longitude and latitude arrays must have the same shape.")

    # Vectorized equivalent of basin_polygon.contains(Point(x, y)) for every grid node.
    # Forcing a bool array keeps the indexing below valid even for empty grids.
    inside = np.asarray(contains_xy(basin_polygon, longitudes, latitudes), dtype=bool)

    # Create mask with 1 inside the basin and np.nan outside
    mask = np.full(latitudes.shape, np.nan)
    mask[inside] = 1

    return mask, basin_polygon


def basin_bounding_box(basin_shape: MultiPolygon, N: float):
    """
    Returns the axis-aligned extent of a basin, padded on every side by N degrees.

    :param basin_shape: MultiPolygon
        Basin geometry; only its ``bounds`` attribute, unpacked as
        (min_lon, min_lat, max_lon, max_lat), is read.
    :param N: float
        Margin in degrees, subtracted from both minima and added to both maxima.

    :return: tuple[float, float, float, float]
        ``(lon_in_min, lon_in_max, lat_in_min, lat_in_max)``. The padded values are not
        clipped to valid latitude/longitude ranges.
    """
    west, south, east, north = basin_shape.bounds
    # Regroup from (min_lon, min_lat, max_lon, max_lat) into a lon pair followed by a lat pair.
    return (west - N, east + N, south - N, north + N)
