"""
Author: Michal Cuadrat-Grzybowski
Date: April 2025
License: Apache License 2.0
Institution: Delft University of Technology

Description:
This module provides utility functions for generating heatmaps and extracting time-series from the TUD-5d GRACE product.
"""

# required packages
import matplotlib.pyplot as plt
import cartopy
import cmweather
import matplotlib.colors as colors
from cartopy import crs as ccrs, feature as cfeature
import os
from calendar import monthrange
import pandas as pd
import numpy as np
import re
import basin_mask_utilities as basin
from shapely.geometry import shape, Point, MultiPolygon
import datetime
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter
from netCDF4 import Dataset
import xarray as xr
import matplotlib.dates as mdates
from typing import Union, Tuple, Optional, Callable


# Directory containing this module (symlinks resolved).
utility_dir_ = os.path.split(os.path.realpath(__file__))[0]

print('--- Imports done ---')

# Named rectangular regions of the world.
# Each entry is [lon_min, lon_max, lat_min, lat_max] in degrees.
region_coordinates = dict(
    australia=[110, 155.5, -45, -9.5],
    east_australia=[135, 155, -50, 0],
    south_america=[-80, -35, -56, 13],
    north_australia=[125, 145, -20, -10.5],
    sub_tropical_gyre=[-76, -27.5, -60.5, -25],  # approximate bounds of the region
    zapiola=[-55, -27.5, -55, -35],  # approximate bounds of the Zapiola anticyclonic region
    north_america=[-169, -50, 7, 85],
    greenland=[-75, -12, 59, 85],
    antarctica=[-180, 180, -90, -60],  # whole continent
    africa=[-17.54, 51.27, -34.83, 37.21],
    world=[-180, 180, -90, 90],
    GBM=[68, 100, 6, 40],  # presumably the Ganges-Brahmaputra-Meghna area; confirm
    thailand=[90, 115, 5, 20.5],
    cambodia_vietnam=[90, 115, 5, 20.5],  # NOTE(review): identical to 'thailand' - confirm intended bounds
    tohoku=[120, 155, 20, 50],
    sumatra=[85, 110, -10, 20],
)


def plot_cartopy_map(ax: plt.subplot, high_quality=True, scale='low', add_greenland=True) -> Tuple[object, object]:
    """
    Draw coastlines, rivers and labelled grid-lines on a cartopy map axis.

    :param ax: cartopy GeoAxes (a matplotlib subplot created with a non-None ``projection``)
        Axis on which the map features are drawn.
    :param high_quality: bool
        If truthy, draw GSHHS level-1 coastlines at resolution ``scale`` plus Natural Earth rivers
        (and optionally the Natural Earth coastline). Otherwise draw only the coarse 110m Natural Earth
        coastline, river centre-lines and lakes. (default: True)
    :param scale: str
        GSHHS resolution passed to ``cfeature.GSHHSFeature`` (default: 'low'). Ignored when
        ``high_quality`` is falsy.
    :param add_greenland: bool (optional)
        If truthy (and ``high_quality`` is truthy), also draw the Natural Earth coastline on top of the
        GSHHS one. Presumably this completes coastlines (e.g. Greenland) missing from GSHHS level 1;
        verify. (default: True)
    :return: tuple (gridliner, coastline feature)
        The grid-line object returned by ``ax.gridlines`` (useful for restyling labels afterwards) and
        the coastline feature that was added to the axis.
    """
    if high_quality:
        # GSHHS level-1 shorelines (original note: without Antarctica).
        coast_1 = cfeature.GSHHSFeature(scale=scale, levels=[1])
        ax.add_feature(coast_1)
        if add_greenland:
            ax.add_feature(cfeature.COASTLINE, alpha=1, linewidth=1, edgecolor='k', facecolor='none')
        ax.add_feature(cfeature.RIVERS, alpha=1, linewidth=1, edgecolor='k', facecolor='none')
    else:
        # Plain ``else`` (rather than ``elif high_quality is False``): any falsy value must still define
        # ``coast_1``, otherwise the return statement raises UnboundLocalError.
        coast_1 = cfeature.NaturalEarthFeature(category='physical', scale='110m', name='coastline', facecolor='none')
        ax.add_feature(coast_1, alpha=1, linewidth=1)
        ax.add_feature(cfeature.NaturalEarthFeature(
            category='physical', name='rivers_lake_centerlines',
            scale='110m', facecolor='none', edgecolor='k'), alpha=1, linewidth=1, edgecolor='k', facecolor='none')
        ax.add_feature(cfeature.LAKES, alpha=1, linewidth=1, edgecolor='k')

    # Labelled grid-lines; latitude labels only on the left-hand side.
    gl = ax.gridlines(draw_labels=True, alpha=0.3)
    gl.right_labels = False
    gl.xlabel_style = {'size': 11}
    gl.ylabel_style = {'size': 11}

    return gl, coast_1


def set_zero_to_white(cmap_name, vmin, vmax):
    """
    Build a copy of a colormap in which the colour used for the value 0 is white.

    Use the returned colormap together with the returned normalisation, e.g.
    ``ax.pcolormesh(x, y, z, cmap=custom_cmap, norm=norm)``.

    Parameters:
    cmap_name (str): Name of a registered matplotlib colormap (e.g., 'viridis', 'seismic').
    vmin (float): Lower bound of the colour scale.
    vmax (float): Upper bound of the colour scale.

    Returns:
    custom_cmap: ListedColormap with 2 * N entries (N = size of the base colormap). The entry that
        the value 0 maps to (under ``norm``) is white. If 0 lies outside [vmin, vmax], no entry
        corresponds to 0 and the colours are left unchanged.
    norm: Linear normalisation between vmin and vmax.
    """
    cmap = plt.get_cmap(cmap_name)
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    # Resample the base colormap onto twice as many evenly spaced entries.
    cmaplist = [cmap(norm(i)) for i in np.linspace(vmin, vmax, cmap.N * 2)]
    n_colors = len(cmaplist)

    # Position of 0 on the normalised [0, 1] scale.
    zero_value = float(norm(0))
    if 0.0 <= zero_value <= 1.0:
        # A colormap looks up a float x in [0, 1] at index int(x * N), with x == 1 mapped to the last
        # entry. The index must be scaled by the length of the NEW list (2 * cmap.N), not cmap.N.
        zero_index = min(int(zero_value * n_colors), n_colors - 1)
        cmaplist[zero_index] = (1.0, 1.0, 1.0, 1.0)

    custom_cmap = colors.ListedColormap(cmaplist)

    return custom_cmap, norm


def create_basin_mask_for_xrarray(ds: xr.Dataset, basin_name: str, basins_path: str = None) -> (
        Tuple[xr.DataArray, Tuple[float, float, float, float]]):
    """
    Create a boolean basin mask on the lat/lon grid of an xarray dataset.

    The basin geometries are loaded with ``basin.load_basins_from_npy(basins_path)`` and the requested
    basin is looked up by name.

    :param ds: xr.Dataset
        The dataset providing the 1-D ``lat`` and ``lon`` coordinates.
    :param basin_name: str
        The name (key) of the basin to load.
    :param basins_path: str, optional
        The file path to the saved .npy file containing basin geometries. It is passed through
        unchanged; how None is handled depends on ``basin.load_basins_from_npy``.
    :return: Tuple (xr.DataArray, tuple)
        - Boolean mask with dims ("lat", "lon"): True where the grid point lies inside the basin
          (shapely ``contains`` semantics, so points exactly on the boundary are False).
        - Bounding box of the basin as (lon_min, lon_max, lat_min, lat_max).
    :raises ValueError: if ``basin_name`` is not among the loaded basins.
    """
    from shapely.prepared import prep

    basins = basin.load_basins_from_npy(basins_path)
    basin_geometry = basins.get(basin_name)
    if basin_geometry is None:
        raise ValueError(f"Basin '{basin_name}' not found in the dataset.")

    lon, lat = np.meshgrid(ds["lon"].values, ds["lat"].values)
    # A prepared geometry answers many point-in-polygon queries much faster than the raw geometry,
    # with identical results.
    prepared_basin = prep(basin_geometry)
    mask_values = np.fromiter(
        (prepared_basin.contains(Point(x, y)) for x, y in zip(lon.ravel(), lat.ravel())),
        dtype=bool,
        count=lon.size,
    ).reshape(lon.shape)

    # shapely bounds are (minx, miny, maxx, maxy); reorder to (lon_min, lon_max, lat_min, lat_max).
    lon_min, lat_min, lon_max, lat_max = basin_geometry.bounds
    bbox = (lon_min, lon_max, lat_min, lat_max)

    return xr.DataArray(mask_values, coords={"lat": ds["lat"], "lon": ds["lon"]}, dims=["lat", "lon"]), bbox


def plot_heatmaps_nc_with_map(nc_file, labels, vmins, vmaxs, region=None, masks=None, cmaps=None,
                               projection_type="Robinson", save_options=(False, ''), show_plot=True, fig_size=(20, 8),
                               date_string=None, bool_rms=False, grid_files=None, var_str='hf_lgd', **optional):
    """
    Plot a heatmap of a gridded variable from an xarray dataset on top of a cartopy map.

    The panel shows either the field at one date (``bool_rms`` falsy) or its RMS over the whole time
    axis (``bool_rms`` truthy). A single panel is drawn, so only the first entry of ``labels``,
    ``vmins``, ``vmaxs`` and ``cmaps`` is used.

    :param nc_file: xr.Dataset
        Dataset with ``time``, ``lat`` and ``lon`` coordinates and the data variable ``var_str``.
        The RMS is taken over axis 0, which assumes time is the first dimension (TODO confirm).
    :param labels: list of str
        List of labels for the colour bars.
    :param vmins: list of float
        List of minimum values for the colour scales.
    :param vmaxs: list of float
        List of maximum values for the colour scales.
    :param region: list or tuple
        Rectangular domain as follows: (lon_min, lon_max, lat_min, lat_max). Default: entire globe.
    :param masks: str or None
        Basin identifier passed to ``basin.create_basin_mask_grid``. The field is multiplied by the
        returned mask and the basin outline is drawn.
    :param cmaps: list of str or None
        List of colour maps to be used for the heatmaps. None uses matplotlib's default colormap.
    :param projection_type: str or None
        One of "Robinson", "PlateCarree", "Mercator" (None means "PlateCarree"). The projection is only
        used when a new axis is created. Default: "Robinson".
    :param save_options: tuple
        (bool, str): if the boolean is True, the figure is saved to the given path. Default: (False, '').
    :param show_plot: boolean
        When no ``ax`` is supplied: show the figure if truthy, otherwise close it. Default: True.
    :param fig_size: tuple
        Size of a newly created figure. Default: (20, 8).
    :param date_string: str or date-like
        Date in "day/month/year" format (or an object accepted by ``Dataset.sel``). The nearest time step
        within one day is selected; xarray raises an error if none exists. Ignored when ``bool_rms`` is truthy.
    :param bool_rms: bool
        If True, plot the RMS over all time steps instead of a single date.
    :param grid_files: unused, kept for backward compatibility.
    :param var_str: str
        Name of the variable to plot. Default: 'hf_lgd' (which also selects the LGD title).
    :param optional:
        - 'fig': existing figure to attach the colour bar to (otherwise a new one is created).
        - 'ax': existing cartopy axis to draw on. When given, the figure is neither shown nor closed here.
        - 'date_start', 'date_end': datetimes printed in the RMS title.
        - 'reverse': sequence whose first element, if True, flips the grid values along latitude.
        - 'add_event': list of (label, positions, lon, lat, colours) tuples drawn as markers.
    :raises KeyError: if ``projection_type`` is not one of the supported keys.
    """
    if region is None:
        region = [-180.0, 180.0, -90.0, 90.0]
    projections = {
        "Robinson": ccrs.Robinson(),
        "PlateCarree": ccrs.PlateCarree(),
        "Mercator": ccrs.Mercator(),
        None: ccrs.PlateCarree()
    }
    projection = projections[projection_type]
    markers = ['o', 'x', '.', '*', 'v', 's', 'p', 'P']

    # Reuse the caller's figure (the colour bar is attached to it) or create a new one.
    fig = optional['fig'] if 'fig' in optional else plt.figure(figsize=fig_size)

    # A single panel is drawn; element_index (1-based) selects the entry of labels/vmins/vmaxs/cmaps.
    element_index = 1
    panel = element_index - 1
    index_2 = 0  # running index into each event's colour list

    lat = nc_file['lat'].values
    lon = nc_file['lon'].values

    if bool_rms:
        # Only the RMS needs the full time series in memory.
        data = nc_file[var_str].values
        grid_values = np.sqrt(np.nanmean(data ** 2, axis=0))  # total RMS in time
        if "date_start" in optional:
            add_string1 = optional['date_start'].strftime("%d/%m/%Y")
            add_string2 = optional['date_end'].strftime("%d/%m/%Y")
        else:
            add_string1, add_string2 = '', ''
        title_text = f"Total RMS over {add_string1} - {add_string2}"
    else:
        # Accept either a "day/month/year" string or an already date-like object.
        if isinstance(date_string, str):
            given_date = np.datetime64(datetime.datetime.strptime(date_string, '%d/%m/%Y'))
        else:
            given_date = date_string
        # Nearest time step, but no further than one day away.
        selection = nc_file.sel(time=given_date, method="nearest", tolerance=np.timedelta64(1, 'D'))
        grid_values = selection[var_str].values
        given_date = selection['time'].values.astype('M8[ms]').astype(datetime.datetime).strftime("%d/%m/%Y")
        if var_str == 'hf_lgd':
            title_text = f"High-frequency Geo-fit LGD v0 - {given_date}"
        else:
            title_text = f"TUD-5d (v0) - {given_date}"

    # Draw on the caller's axis or create a new map axis.
    ax = optional['ax'] if 'ax' in optional else plt.subplot(1, 1, element_index, projection=projection)

    # Crop the grid to the requested region.
    lon_min, lon_max, lat_min, lat_max = region
    lat_mask = (lat >= lat_min) & (lat <= lat_max)
    lon_mask = (lon >= lon_min) & (lon <= lon_max)
    grid_values = grid_values[np.ix_(lat_mask, lon_mask)]
    lat = lat[lat_mask]
    lon = lon[lon_mask]

    # Optionally flip the values (not ``lat``) along the latitude axis.
    if 'reverse' in optional and optional['reverse'][0] is True:
        grid_values = np.flip(grid_values, axis=0)

    if masks is not None:
        lons, lats = np.meshgrid(lon, lat)
        basin_mask, basin_polygon = basin.create_basin_mask_grid(masks, lats, lons)
        if basin_mask.shape != grid_values.shape:
            basin_mask = basin_mask[:grid_values.shape[0], :grid_values.shape[1]]
        grid_values = basin_mask * grid_values
        # Outline the basin: each part of a MultiPolygon, or a single Polygon itself.
        if isinstance(basin_polygon, MultiPolygon):
            outlines = list(basin_polygon.geoms)
        elif hasattr(basin_polygon, 'exterior'):
            outlines = [basin_polygon]
        else:
            outlines = []
        for polygon in outlines:
            x, y = polygon.exterior.xy
            ax.plot(x, y, color='black', linewidth=2, transform=ccrs.PlateCarree())

    # Overlay additional events (markers), if any.
    events = optional.get('add_event')
    if events is not None:
        index_ = 0  # marker index, advanced for every drawn event
        for label_event, pos_event, lon_event, lat_event, colors_ in events:
            if pos_event == panel or panel in pos_event:
                ax.plot(lon_event, lat_event, label=label_event,
                        color=colors_[index_2], marker=markers[index_], markersize=6,
                        transform=ccrs.PlateCarree())
                index_2 += 1
                if panel == max(pos_event) and 'ax' not in optional:
                    ax.legend(fontsize=11)
                index_ += 1

    # Heatmap
    vmin = vmins[panel]
    vmax = vmaxs[panel]
    cmap = None if cmaps is None else cmaps[panel]
    mesh = ax.pcolormesh(lon, lat, grid_values, cmap=cmap, vmin=vmin, vmax=vmax,
                         transform=ccrs.PlateCarree())

    plot_cartopy_map(ax, high_quality=True, scale='low')
    ax.set_extent(region, crs=ccrs.PlateCarree())
    ax.set_title(title_text, fontsize=20)

    # Colour bar with arrows when the data exceed the colour limits.
    data_min = np.nanmin(grid_values)
    data_max = np.nanmax(grid_values)
    extend = 'neither'
    if data_min < vmin and data_max > vmax:
        extend = 'both'
    elif data_min < vmin:
        extend = 'min'
    elif data_max > vmax:
        extend = 'max'
    cbar = fig.colorbar(mesh, ax=ax, orientation='vertical', extend=extend)
    cbar.set_label(labels[panel], fontsize=14)
    cbar.ax.tick_params(labelsize=11)

    plt.tight_layout()

    if save_options[0]:
        plt.savefig(save_options[1])

    # Only manage the figure's lifetime when this function created the axis.
    if 'ax' not in optional:
        if show_plot:
            plt.show()
        else:
            plt.close()


def extract_time_series(
        date_start: str,
        date_end: str,
        var_name: str,
        selection: Union[Tuple[float, float], Tuple[float, float, float, float], str],
        mask_plot: str = None,
        plot: bool = False,
        date_string: str = None,
        projection_type: str = "PlateCarree",
        region: tuple = None,
        method: str = 'nearest',
        file_path: str = None,
        bool_rms: bool = False,
        **optional) -> pd.DataFrame:
    """
    Extracts a time series from a NetCDF file for either a single grid cell,
    a specified rectangular region, or a masked area, using area-weighted averaging.

    Parameters:
    -----------
    date_start: str
        String date specifying the beginning of the time-series. Format: "day/month/year".
    date_end: str
        String date specifying the end of the time-series. Format: "day/month/year".
    var_name : str
        Name of the variable to extract.
    selection : tuple/list or str
        - (lon, lat): Extracts the nearest grid cell.
        - (lon_min, lon_max, lat_min, lat_max): Area-weighted statistics over the box. The box is
          selected with label slices, which assumes ascending lat/lon coordinates (TODO confirm).
        - "basin_name": Area-weighted statistics over a named river basin (requires ``mask_plot``).
    mask_plot : str
        String name of river basin mask (only for heatmap plotting). Default is None.
    plot : bool, optional
        Whether to plot the extracted time series. Something is drawn only when ``date_string`` is
        given or ``bool_rms`` is True.
    date_string : str, optional
        Datetime in string format ("day/month/year") to be highlighted in the time-series.
    projection_type : str, optional
        Type of projection to be used in heatmap. Default is "PlateCarree".
    region : tuple, optional
        Map extent (lon_min, lon_max, lat_min, lat_max) for plotting. Default is derived from the selection.
    method : str, optional
        NOTE(review): currently not used. The single-cell case always selects the nearest grid cell.
        Kept for API compatibility.
    file_path : str, optional
        Path to a NetCDF file that replaces the default product file.
    bool_rms: bool, optional
        If True, compute and plot the total RMS over time instead of using a single date.
    **optional
        interpolate : any
            If present, the dataset is linearly interpolated onto a daily time axis first.
        outlier : float
            Values whose absolute value exceeds this threshold are set to NaN.
        outlier_time : any
            If present (region/basin only), drop values outside
            median +- 1.5 * (median spatial max - median spatial min).
        unweighted : bool
            If True (region/basin only), use the plain spatial mean/std instead of area weighting.
        vmin, vmax, cmap :
            Colour scale of the heatmap panel.

    Returns:
    --------
    pd.DataFrame
        - Single cell: one row per time step with ``var_name`` (plus the selected coordinates).
        - Region/basin: columns "time", ``var_name`` (mean), ``var_name + '_std'``, "min", "max".

    Raises:
    -------
    KeyError
        If ``projection_type`` is not supported or ``var_name`` is not in the dataset.
    ValueError
        If ``selection`` has an invalid format, or a basin name is given without ``mask_plot``.
    """
    import warnings

    # possible projections
    projections = {
        "Robinson": ccrs.Robinson(),
        "PlateCarree": ccrs.PlateCarree(),
        "Mercator": ccrs.Mercator(),
        "Orthographic": ccrs.Orthographic(),
        "LambertConformal": ccrs.LambertConformal(),
        "AlbersEqualArea": ccrs.AlbersEqualArea(),
        "Mollweide": ccrs.Mollweide(),
        "AzimuthalEquidistant": ccrs.AzimuthalEquidistant(),
        "Gnomonic": ccrs.Gnomonic(),
        "Europe": ccrs.EuroPP(),
        "AlbersEqualAreaGreenland": ccrs.AlbersEqualArea(central_longitude=-45, central_latitude=72),
        "LambertConformalGreenland": ccrs.LambertConformal(central_longitude=-45, central_latitude=72),
        "PolarStereographicSouth": ccrs.Stereographic(central_latitude=-90),  # Ideal for Antarctica
        None: ccrs.PlateCarree()  # Default projection
    }
    projection = projections[projection_type]

    # Default product file: looked up two directory levels above the current working directory.
    # os.path.join avoids the invalid "\T" escape and works on every OS.
    base_dir = os.path.dirname(os.path.dirname(os.getcwd()))
    path_to_tot = os.path.join(base_dir, "TUD-L2B-5dayEWH_2002_2016.nc")

    # Open the dataset (``file_path`` overrides the default file). It is closed on exit; the returned
    # DataFrame holds materialised values only.
    ds_file = xr.open_dataset(file_path if file_path is not None else path_to_tot)
    try:
        ds = ds_file
        if var_name not in ds:
            raise KeyError(f"Variable '{var_name}' not found in dataset. "
                           f"Available variables: {list(ds.data_vars)}")

        if date_string is not None:
            # Represent the requested day by its noon time stamp.
            date_string = pd.to_datetime(date_string, format='%d/%m/%Y').replace(hour=12, minute=0, second=0)
            if date_string not in ds["time"].to_numpy():
                warnings.warn(
                    f"Date not included in time-series. Available dates: {ds['time'].to_numpy()}. "
                    f"Taking nearest date.",
                    UserWarning
                )

        # (linearly) interpolate to 1-day resolution if requested
        if "interpolate" in optional:
            new_time = pd.date_range(start=ds['time'].min().values, end=ds['time'].max().values, freq='1D')
            ds = ds.interp(time=new_time)

        # Restrict to the requested time window.
        date_start = pd.to_datetime(date_start, format='%d/%m/%Y')
        date_end = pd.to_datetime(date_end, format='%d/%m/%Y')
        ds = ds.where((ds.time >= date_start) & (ds.time <= date_end), drop=True)

        # Optional outlier removal: values with |x| > threshold become NaN.
        if 'outlier' in optional:
            var_data = ds[var_name]
            ds[var_name] = var_data.where(np.abs(var_data) <= optional['outlier'], np.nan)

        # Approximate grid-cell areas on a spherical Earth (km^2): R^2 * dlat * dlon * cos(lat).
        R = 6371  # Earth's radius in km
        lat_rad = np.radians(ds['lat'].values)
        lon_rad = np.radians(ds['lon'].values)
        dlat = np.abs(np.gradient(lat_rad))
        dlon = np.abs(np.gradient(lon_rad))
        areas = (R ** 2) * np.outer(dlat * np.cos(lat_rad), dlon)
        ds["grid_area"] = (("lat", "lon"), areas)

        add_feature = None
        is_sequence = isinstance(selection, (tuple, list))

        # Case 1: Single grid cell (nearest lat/lon)
        if is_sequence and len(selection) == 2:
            lon, lat = selection
            data = ds[var_name].sel(lat=float(lat), lon=float(lon), method="nearest")
            if region is None:
                region = [lon - 15, lon + 15, lat - 15, lat + 15]
            df = data.to_dataframe().reset_index()
            add_feature = [("", [0], lon, lat, ["tab:orange"])]

        # Case 2: Regional area-weighted statistics
        elif is_sequence and len(selection) == 4:
            lon_min, lon_max, lat_min, lat_max = selection
            # Closed outline of the box, drawn on the map.
            lat, lon = [lat_max, lat_max, lat_min, lat_min, lat_max], [lon_min, lon_max, lon_max, lon_min, lon_min]
            var_region = ds[var_name].sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
            area_region = ds["grid_area"].sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
            add_feature = [("", [0], lon, lat, ["tab:orange"])]

        # Case 3: Mask-based area-weighted statistics
        elif isinstance(selection, str):
            if mask_plot is None:
                raise ValueError("Mask function must be provided when using a named mask.")
            mask, basin_box = create_basin_mask_for_xrarray(ds, selection)
            lon_min, lon_max, lat_min, lat_max = basin_box
            var_region = ds[var_name].where(mask)
            area_region = ds["grid_area"].where(mask)

        else:
            raise ValueError("Invalid selection input format. "
                             "Use (lon, lat), (lon_min, lon_max, lat_min, lat_max), or a string mask name.")

        is_region = isinstance(selection, str) or len(selection) == 4

        # Statistical computations for a region/basin
        if is_region:
            if "outlier_time" in optional:
                var_min_over_time = var_region.min(dim=["lat", "lon"])
                var_max_over_time = var_region.max(dim=["lat", "lon"])

                # Historical (median) spatial min/max/mean
                var_min_hist = var_min_over_time.median(dim="time")
                var_max_hist = var_max_over_time.median(dim="time")
                var_median_hist = var_region.mean(dim=["lat", "lon"]).median(dim="time")
                var_iqr_hist = var_max_hist - var_min_hist

                lower_bound = var_median_hist - 1.5 * var_iqr_hist
                upper_bound = var_median_hist + 1.5 * var_iqr_hist

                mask_outliers = (var_region >= lower_bound) & (var_region <= upper_bound)
                var_region_filtered = var_region.where(mask_outliers)
                area_region_filtered = area_region.where(mask_outliers)
            else:
                var_region_filtered = var_region
                area_region_filtered = area_region

            # Only cells with valid data may count in the normalising area. Otherwise NaN cells
            # (removed outliers, missing values) bias the weighted mean/std towards zero.
            area_region_filtered = area_region_filtered.where(var_region_filtered.notnull())
            total_area = area_region_filtered.sum(dim=["lat", "lon"])

            weighted_mean = (var_region_filtered * area_region_filtered).sum(dim=["lat", "lon"]) / total_area
            weighted_std = np.sqrt(
                ((var_region_filtered - weighted_mean) ** 2 * area_region_filtered).sum(dim=["lat", "lon"])
                / total_area)

            var_min_over_time = var_region_filtered.min(dim=["lat", "lon"])
            var_max_over_time = var_region_filtered.max(dim=["lat", "lon"])

            # unweighted mean and std
            if optional.get('unweighted') is True:
                weighted_mean = var_region_filtered.mean(dim=["lat", "lon"], skipna=True)
                weighted_std = var_region_filtered.std(dim=["lat", "lon"], skipna=True)

            df = pd.DataFrame({"time": ds["time"].to_numpy(), var_name: weighted_mean.to_numpy(),
                               var_name + '_std': weighted_std.to_numpy(),
                               'min': var_min_over_time.to_numpy(), "max": var_max_over_time.to_numpy()})

            # for plotting
            if region is None:
                region = [lon_min - 5, lon_max + 5,
                          lat_min - 5, lat_max + 5]

        if plot:
            fig = plt.figure(figsize=(15, 5))
            if var_name == 'hf_lgd':
                labels = [r'[nm/s$^2$]']
            else:
                labels = [r'[cm]']

            if date_string is not None or bool_rms is True:
                # colour scale for the heatmap
                vmins = [optional['vmin']] if 'vmin' in optional else [-27.5]
                vmaxs = [optional['vmax']] if 'vmax' in optional else [27.5]
                cmaps = [optional['cmap']] if 'cmap' in optional else ['seismic']
                # different defaults for the RMS over time
                if bool_rms is True:
                    if 'vmin' not in optional or 'vmax' not in optional:
                        vmins = [0]
                        vmaxs = [20]
                    if 'cmap' not in optional:
                        cmaps = ['HomeyerRainbow']

                ax1 = plt.subplot(121, projection=projection)  # map
                ax2 = plt.subplot(122)  # time-series

                ####################################################################################################
                df["time"] = pd.to_datetime(df["time"], format="%d/%m/%Y")
                line, = ax2.plot(df["time"], df[var_name], marker="o", linestyle="-")

                # Show markers only when zoomed in below this range (days).
                zoom_threshold = 3 * 365

                def update_markers(event):
                    """Callback: toggle the line markers according to the visible time range."""
                    xlim = ax2.get_xlim()
                    zoom_range = xlim[1] - xlim[0]

                    if zoom_range < zoom_threshold:
                        line.set_marker("o")
                    else:
                        line.set_marker("")
                    fig.canvas.draw_idle()

                ax2.callbacks.connect("xlim_changed", update_markers)

                # for regions, also plot +-1 sigma, min. and max.
                if is_region:
                    ax2.plot(df["time"], df['max'], linestyle="-", label='Max.', color='r')
                    ax2.fill_between(df["time"], df[var_name] - df[var_name + "_std"],
                                     df[var_name] + df[var_name + "_std"], color="gray", alpha=0.3, label="±1σ")
                    ax2.plot(df["time"], df['min'], linestyle="-", label='Min.', color='tab:purple')

                if var_name == 'hf_lgd':
                    # Reverse the y-axis (negative LGDs showing positive mass anomalies)
                    ax2.invert_yaxis()

                ax2.xaxis.set_major_locator(mdates.AutoDateLocator())
                ax2.xaxis.set_major_formatter(mdates.DateFormatter("%d-%m-%Y"))
                ax2.tick_params(axis='x', rotation=15)

                # mark the requested date (already converted to a Timestamp above)
                if date_string is not None:
                    t_sel = ds.sel(time=date_string, method="nearest", tolerance=np.timedelta64(1, 'D'))
                    given_date = t_sel['time'].values.astype('M8[ms]').astype(datetime.datetime)
                    d_string = given_date.strftime("%d/%m/%Y")
                    ax2.axvline(given_date,
                                color='tab:orange', linestyle='--', label=f"{d_string}")

                # x-limits: data span (+-2 days) for a custom file, requested window otherwise
                if file_path is not None:
                    time_start = pd.Timestamp(df["time"].iloc[0]) - datetime.timedelta(days=2)
                    time_end = pd.Timestamp(df["time"].iloc[-1]) + datetime.timedelta(days=2)
                else:
                    time_start = date_start
                    time_end = date_end
                ax2.set_xlim(time_start, time_end)

                ax2.grid(True, which="both", linewidth=0.5)
                if var_name != 'hf_lgd':
                    ax2.set_ylabel("EWH [cm]", fontsize=15)

                title_str = f"Time-series for longitude, latitude: {selection} deg."
                if isinstance(selection, str):
                    title_str = f"Time-series for river basin: {selection}."
                ax2.set_title(title_str, fontsize=15)
                if date_string is not None or len(selection) > 2:
                    ax2.legend(fontsize=12)
                ####################################################################################################
                # heatmap on ax1
                plot_heatmaps_nc_with_map(ds, np.array(labels),
                                          np.array(vmins), np.array(vmaxs), region=region,
                                          cmaps=np.array(cmaps), date_string=date_string, ax=ax1, fig=fig,
                                          add_event=add_feature, masks=mask_plot,
                                          bool_rms=bool_rms, date_start=date_start, date_end=date_end,
                                          var_str=var_name)

            plt.tight_layout()
            plt.show()

        return df
    finally:
        ds_file.close()


if __name__ == "__main__":
    # Demonstration runs on the TUD-5d product (the default file is resolved inside extract_time_series).
    var_name = r'ewh_hf_lgd'
    masks = 'Ganges - Bramaputra'

    # Plot settings shared by the point, box and basin examples.
    basin_plot_kwargs = dict(plot=True, date_string="03/08/2007", region=(71, 101, 15, 35),
                             method='linear', mask_plot=masks, vmin=-35, vmax=35, cmap='Spectral')

    # 1) single grid point, 2) rectangular box
    for sel in ((90, 25), (89.5, 93, 22, 26)):
        extract_time_series("03/01/2003", "03/09/2016", var_name, sel,
                            unweighted=False, **basin_plot_kwargs)

    # 3) whole river basin (selection = basin name)
    extract_time_series("03/01/2003", "03/09/2016", var_name, masks, **basin_plot_kwargs)

    # GRACE RMS over the full period, global map
    extract_time_series("01/01/2003", "31/08/2016", var_name, (90.1, 25.1),
                        plot=True, bool_rms=True, region=(-180, 180, -90, 90), method='linear')
