# ============================================================================
# Imports
# ============================================================================

import os
import sys
import glob
import numpy as np
import iris
import warnings
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import datetime
from dateutil.relativedelta import relativedelta
from cartopy.util import add_cyclic_point
import shapely

# ============================================================================
# Input
# ============================================================================

# Command-line arguments: <norcpm version> <region> <start date YYYYMMDD>.
args = sys.argv
# args[0] is the script name, so three user arguments means len(args) == 4.
if len(args) < 4:
    print('Usage: python quantile_plots.py <norcpm version> global|europe <start date>')
    print()
    print('Example: python quantile_plots.py norcpm-cf-system2 global 20240215')
    print()
    sys.exit(1)

# Numeric system identifier used in forecast file and plot names.
SYSTEM_IDS = {'norcpm-cf-system1': '1', 'norcpm-cf-system2': '2'}
SYSTEM = args[1]
if SYSTEM not in SYSTEM_IDS:
    print('No predefined SYSTEM_ID for system ',SYSTEM)
    sys.exit(1)
SYSTEM_ID = SYSTEM_IDS[SYSTEM]

REGION = args[2]
if REGION not in ('global', 'europe'):
    print('Unknown region ', REGION, " - expected 'global' or 'europe'")
    sys.exit(1)

START_DATE = args[3]
try:
    _start = datetime.datetime.strptime(START_DATE, '%Y%m%d')
except ValueError:
    print('Start date must be given as YYYYMMDD, got ', START_DATE)
    sys.exit(1)
INITIAL_YEAR = _start.year
INITIAL_MONTH = _start.month
INITIAL_DAY = _start.day
# A start date after the 1st is rolled forward to the following calendar
# month; that month is used for file names and the "issued" label below.
if INITIAL_DAY > 1:
    if INITIAL_MONTH == 12:
        INITIAL_MONTH = 1
        INITIAL_YEAR += 1
    else:
        INITIAL_MONTH += 1

FORECAST_PATH = '/projects/NS9873K/www/norcpm/forecast/monthly/'
THRESHOLD_PATH = '/projects/NS9873K/www/norcpm/forecast/thresholds/'
OUTPUT_PATH = '/projects/NS9873K/www/norcpm/forecast/plots/'
VARIABLE_LIST = ['2m_temperature', '10m_wind_speed', 'mean_sea_level_pressure',
                 'u_component_of_wind_850hPa', 'v_component_of_wind_850hPa',
                 'total_precipitation', 'sea_surface_temperature']

# When False, all Python warnings are silenced before processing starts.
WARNINGS = False

# ============================================================================
# Functions
# ============================================================================

def callback(cube, field, filename):
    """Attach an 'initialisation' auxiliary coordinate to *cube*.

    The label is the text after the last underscore in the name of the
    directory containing *filename* (e.g. ``.../run_20240215/x.nc`` gives
    ``'20240215'``).  The ``(cube, field, filename)`` signature presumably
    makes this usable as an iris load callback; *field* is not used.
    """
    parent_directory = filename.split('/')[-2]
    init_label = parent_directory.split('_')[-1]
    cube.add_aux_coord(
        iris.coords.AuxCoord(init_label, long_name='initialisation'))

def plot_probability(field, lats, lons, levels, save_name=None,
                     cmapvar='viridis', title=''):
    """Contour a single probability field on a PlateCarree map.

    Args:
        field: values on the (lats, lons) grid passed to ``contourf``.
        lats, lons: coordinate arrays used as contour y/x.
        levels: contour levels.
        save_name: output path; when falsy the figure is shown instead.
        cmapvar: matplotlib colormap name.
        title: axes title.
    """
    cmap = plt.get_cmap(cmapvar)
    ax = plt.axes(projection=ccrs.PlateCarree())
    contour_plot = ax.contourf(lons, lats, field, levels=levels,
                               cmap=cmap,
                               transform=ccrs.PlateCarree())

    ax.coastlines(resolution='50m')
    fig = plt.gcf()
    fig.set_size_inches(10, 10)
    fig.colorbar(contour_plot, orientation='vertical')
    ax.set_title(title)
    if save_name:
        fig.savefig(save_name, dpi=300)
        # Close the figure so repeated calls neither leak memory nor let the
        # next plt.axes() draw onto this figure.
        plt.close(fig)
    else:
        plt.show()

def plot_probability_terciles(above_array, normal_array, below_array,
                              lats, lons, levels, variable, region, land_mask=False,
                              title='',
                              save_name=None, cmapvar='viridis'):
    """Plot upper, middle and lower tercile probabilities as three stacked maps.

    Args:
        above_array, normal_array, below_array: probability fields on the
            (lats, lons) grid for the upper, middle and lower tercile.
        lats, lons: coordinate arrays used as contour y/x.
        levels: contour levels shared by all three panels.
        variable: variable key; underscores become spaces in panel titles.
        region: when 'global', a cyclic longitude column is appended first.
        land_mask: if True, draw the LAND feature above the contours.
        title: figure super-title.
        save_name: output path; when falsy the figure is shown instead.
        cmapvar: matplotlib colormap name.
    """
    readable = variable.replace('_', ' ')
    panel_titles = [f"Probability of {band} tercile {readable}"
                    for band in ('upper', 'middle', 'lower')]

    colour_map = plt.get_cmap(cmapvar)
    fig, axes = plt.subplots(nrows=3, ncols=1,
                             subplot_kw={'projection': ccrs.PlateCarree()})
    axes = axes.flatten()

    if land_mask:
        for axis in axes:
            axis.add_feature(cfeature.LAND, zorder=2)

    fields = (above_array, normal_array, below_array)
    if region != 'global':
        plot_lons = lons[:]
    else:
        # Append a wrap-around column so no blank seam is drawn at longitude 0.
        padded = []
        for raw in fields:
            data, plot_lons = add_cyclic_point(raw, coord=lons)
            padded.append(data)
        fields = tuple(padded)

    for axis, data in zip(axes, fields):
        mappable = axis.contourf(plot_lons, lats, data, levels=levels,
                                 cmap=colour_map,
                                 transform=ccrs.PlateCarree(),
                                 zorder=1)

    # Every colour bar is built from the last panel's contour set; all panels
    # use the same levels and colormap.
    for axis in axes:
        fig.colorbar(mappable, orientation='vertical', ax=axis)
    for axis in axes:
        axis.coastlines(resolution='50m')
    for axis, text in zip(axes, panel_titles):
        axis.set_title(text)

    fig.set_size_inches(10, 10)
    fig.suptitle(title)
    if not save_name:
        plt.show()
        return
    plt.savefig(save_name, dpi=300)
    plt.clf()
    plt.cla()
    plt.close()

def _draw_quintile_panels(plot_lons, lats, above_array, below_array, levels, cmap):
    """Create a two-row PlateCarree figure and contour one field per panel.

    Returns ``(fig, axs, contour_plot)``, where ``contour_plot`` is the contour
    set of the lower panel.  If contouring raises, the new figure is closed
    before the exception propagates, so a retry does not leave it open.
    """
    fig, axs = plt.subplots(nrows=2, ncols=1,
                            subplot_kw={'projection': ccrs.PlateCarree()})
    axs = axs.flatten()
    try:
        axs[0].contourf(plot_lons, lats, above_array, levels=levels,
                        cmap=cmap,
                        transform=ccrs.PlateCarree(),
                        zorder=1)
        contour_plot = axs[1].contourf(plot_lons, lats, below_array, levels=levels,
                                       cmap=cmap,
                                       transform=ccrs.PlateCarree(),
                                       zorder=1)
    except Exception:
        plt.close(fig)
        raise
    return fig, axs, contour_plot


def plot_probability_quintiles(above_array_in, below_array_in,
                              lats, lons, levels, variable, region, land_mask=False,
                              title='',
                              save_name=None, cmapvar='viridis'):
    """Plot upper and lower outer-quintile probabilities as two stacked maps.

    Args:
        above_array_in, below_array_in: probability fields on the (lats, lons)
            grid for the upper and lower quintile.
        lats, lons: coordinate arrays used as contour y/x.
        levels: contour levels shared by both panels.
        variable: variable key; underscores become spaces in panel titles.
        region: when 'global', a cyclic longitude column is appended first.
        land_mask: if True, draw the LAND feature above the contours.
        title: figure super-title.
        save_name: output path; when falsy the figure is shown instead.
        cmapvar: matplotlib colormap name.
    """
    readable = variable.replace('_', ' ')
    above_title = f"Probability of upper quintile {readable}"
    below_title = f"Probability of lower quintile {readable}"

    cmap = plt.get_cmap(cmapvar)
    if region == 'global':
        # Adding cyclic point removes the white line at longitude=0.
        above_array, clons = add_cyclic_point(above_array_in, coord=lons)
        below_array, clons = add_cyclic_point(below_array_in, coord=lons)
        try:
            fig, axs, contour_plot = _draw_quintile_panels(
                clons, lats, above_array, below_array, levels, cmap)
        except Exception:
            # Contouring sometimes fails on the cyclic-padded grid; retry on
            # the unpadded data (the seam at longitude 0 may then reappear).
            fig, axs, contour_plot = _draw_quintile_panels(
                lons[:], lats, above_array_in, below_array_in, levels, cmap)
    else:
        # No padding here, so a retry would repeat the identical call; let
        # any failure propagate.
        fig, axs, contour_plot = _draw_quintile_panels(
            lons[:], lats, above_array_in, below_array_in, levels, cmap)

    # Both panels use the same levels and colormap, so one contour set serves
    # both colour bars.
    fig.colorbar(contour_plot, orientation='vertical', ax=axs[0])
    fig.colorbar(contour_plot, orientation='vertical', ax=axs[1])

    for ax in axs:
        ax.coastlines(resolution='50m')
        if land_mask:
            ax.add_feature(cfeature.LAND, zorder=2)

    axs[0].set_title(above_title)
    axs[1].set_title(below_title)

    fig.set_size_inches(10, 10)
    fig.suptitle(title)
    if save_name:
        fig.savefig(save_name, dpi=300)
        plt.close(fig)
    else:
        plt.show()


# ============================================================================
# Main
# ============================================================================

# Maps each variable key to the cube name loaded from its forecast file.
variable_dictionary = {'2m_temperature' : '2 metre temperature',
                       '10m_wind_speed' : '10 metre wind speed',
                       'mean_sea_level_pressure' : 'Mean sea level pressure',
                       'u_component_of_wind_850hPa' : 'U velocity at 850hPa',
                       'v_component_of_wind_850hPa' : 'V velocity at 850hPa',
                       'sea_surface_temperature' : 'Sea surface temperature',
                       'snowfall' : 'Mean total snowfall rate',
                       'total_precipitation' : 'Mean total precipitation rate'}

if not WARNINGS:
    warnings.filterwarnings("ignore")

# Longitude window (degrees) cut out for the 'europe' region.
EUROPE_LON_RANGE = (-49.5, 45.15)
# Overlapping 3-month windows plotted per variable: LM01-03 ... LM04-06.
N_MONTH_GROUPS = 4
# Threshold fields read from each threshold file.
THRESHOLD_NAMES = ('t12', 't23', 'q12', 'q45')

issue_date = f'{INITIAL_YEAR}{INITIAL_MONTH:02d}'
outdir = os.path.join(OUTPUT_PATH, issue_date)
os.makedirs(outdir, exist_ok=True)
levels = np.linspace(0, 1, 6)

for variable in VARIABLE_LIST:
    print(f"Plotting variable: {variable}...")
    # Get file paths.
    variable_name = variable.replace('_', '-')
    forecast_file = os.path.join(
        FORECAST_PATH, issue_date,
        f'{variable}_bccr_{SYSTEM_ID}_{INITIAL_YEAR}_{INITIAL_MONTH:02d}.nc')
    threshold_file = os.path.join(
        THRESHOLD_PATH,
        f'thresholds_3month_{variable_name}_{INITIAL_MONTH:02d}.nc')

    # Regional latitude constraint; the global one accepts every latitude.
    if REGION == 'europe':
        lat_constraint = iris.Constraint(latitude=lambda cell: 34 < cell < 73)
    else:
        lat_constraint = iris.Constraint(latitude=lambda cell: cell < 999)

    # Neither the forecast nor the thresholds depend on the 3-month window,
    # so read (and, for Europe, cut out) each of them once per variable.
    forecast_cube = iris.load_cube(
        forecast_file, (variable_dictionary[variable] & lat_constraint))
    threshold_cubes = {
        name: iris.load_cube(threshold_file, (name & lat_constraint))
        for name in THRESHOLD_NAMES}
    if REGION == 'europe':
        forecast_cube = forecast_cube.intersection(longitude=EUROPE_LON_RANGE)
        threshold_cubes = {
            name: cube.intersection(longitude=EUROPE_LON_RANGE)
            for name, cube in threshold_cubes.items()}

    # Draw land over the SST maps.
    land_mask = variable == 'sea_surface_temperature'

    lats = forecast_cube.coord('latitude').points
    lons = forecast_cube.coord('longitude').points

    for month_group in range(N_MONTH_GROUPS):
        # Time mean over the 3-month window.  The leading axis of the result is
        # the one counted over below (the ensemble members).
        window = forecast_cube[month_group:(month_group + 3), :, :, :]
        members = window.collapsed('time', iris.analysis.MEAN).data
        n_members = members.shape[0]

        it12 = threshold_cubes['t12'][month_group].data
        it23 = threshold_cubes['t23'][month_group].data
        iq12 = threshold_cubes['q12'][month_group].data
        iq45 = threshold_cubes['q45'][month_group].data

        # The builtin sum adds the leading-axis slices one by one; for masked
        # input the result is masked wherever any member is masked.
        t3_count = sum(members > it23)
        t2_count = sum((members >= it12) & (members <= it23))
        t1_count = sum(members < it12)
        q1_count = sum(members < iq12)
        q5_count = sum(members > iq45)

        # Every member must fall into exactly one tercile at every cell.
        if not ((t1_count + t2_count + t3_count) == n_members).all():
            raise ValueError('Incorrect division of ensemble members into threshold categories.')

        # Calculate probabilities.  np.ma.getdata returns the raw values of a
        # masked array (as .data did) and also accepts a plain ndarray, whose
        # .data attribute is a memoryview that cannot be divided.
        t3_array, t2_array, t1_array, q1_array, q5_array = (
            np.ma.getdata(count) / n_members
            for count in (t3_count, t2_count, t1_count, q1_count, q5_count))

        # Save filename.
        window_label = f'LM{month_group+1:02d}-{month_group+3:02d}'
        tercile_file = os.path.join(outdir, f'bccr_{SYSTEM_ID}_{window_label}_{variable_name}_3month-terciles_{REGION}.png')
        quintile_file = os.path.join(outdir, f'bccr_{SYSTEM_ID}_{window_label}_{variable_name}_3month-quintiles_{REGION}.png')

        # Plot title.
        # NOTE(review): LM01 is labelled with the month *after* INITIAL_MONTH;
        # confirm this matches the first time step of the forecast file.
        issue_month = datetime.datetime(year=INITIAL_YEAR, month=INITIAL_MONTH, day=1)
        month_string = '/'.join(
            (issue_month + relativedelta(months=month_group + offset)).strftime('%b')
            for offset in range(1, 4))

        tercile_title = f'Tercile categories. Issued {issue_date}. Period {month_string}. {REGION.capitalize()}.'
        quintile_title = f'Outer quintiles. Issued {issue_date}. Period {month_string}. {REGION.capitalize()}.'

        # Plotting happens here.
        plot_probability_terciles(t3_array, t2_array, t1_array,
                                  lats, lons, levels, variable, REGION, land_mask=land_mask,
                                  title=tercile_title,
                                  save_name=tercile_file)
        plot_probability_quintiles(q5_array, q1_array,
                                   lats, lons, levels, variable, REGION, land_mask=land_mask,
                                   title=quintile_title,
                                   save_name=quintile_file)
