# ============================================================================
# Imports
# ============================================================================

import os
import sys
import glob
import numpy as np
import iris
import warnings
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import datetime
from dateutil.relativedelta import relativedelta
from cartopy.util import add_cyclic_point

# ============================================================================
# Input
# ============================================================================

# Command line: <norcpm version> <region> <start date as YYYYMMDD>.
args = sys.argv
if len(args) < 4:
    print('Usage: python mean_anomaly_plots.py <norcpm version> global|europe <start date>')
    print()
    print('Example: python mean_anomaly_plots.py norcpm-cf-system2 global 20240215')
    print()
    # Non-zero exit status so wrapper scripts can detect the bad invocation.
    sys.exit(1)

SYSTEM = args[1]
# NorCPM system name -> id used in the forecast and plot file names.
_SYSTEM_IDS = {'norcpm-cf-system1': '1', 'norcpm-cf-system2': '2'}
if SYSTEM not in _SYSTEM_IDS:
    print('No predefined SYSTEM_ID for system ', SYSTEM)
    sys.exit(1)
SYSTEM_ID = _SYSTEM_IDS[SYSTEM]

REGION = args[2]
# Downstream code only distinguishes 'europe' and 'global'; reject anything
# else instead of silently producing global plots under another name.
if REGION not in ('global', 'europe'):
    print('Unknown region ', REGION, '(expected global or europe)')
    sys.exit(1)

START_DATE = args[3]
try:
    _start = datetime.datetime.strptime(START_DATE, '%Y%m%d')
except ValueError:
    print('Invalid start date ', START_DATE, '(expected YYYYMMDD)')
    sys.exit(1)

# INITIAL_DAY keeps the day as given; a start date after the 1st is rolled
# forward to the following month for INITIAL_YEAR/INITIAL_MONTH.
INITIAL_DAY = _start.day
if INITIAL_DAY > 1:
    _start = _start.replace(day=1) + relativedelta(months=1)
INITIAL_YEAR = _start.year
INITIAL_MONTH = _start.month

FORECAST_PATH = '/projects/NS9873K/www/norcpm/forecast/monthly/'
THRESHOLD_PATH = '/projects/NS9873K/www/norcpm/forecast/thresholds/'
OUTPUT_PATH = '/projects/NS9873K/www/norcpm/forecast/plots/'
VARIABLE_LIST = ['2m_temperature', '10m_wind_speed', 'mean_sea_level_pressure',
                 'u_component_of_wind_850hPa', 'v_component_of_wind_850hPa',
                 'total_precipitation', 'sea_surface_temperature']

# When False, all Python warnings are suppressed further down.
WARNINGS = False

# ============================================================================
# Functions
# ============================================================================

def find_levels(min_value, max_value, increment=1):
    """Return evenly spaced contour levels starting at ``min_value``.

    The spacing starts at ``increment``. It is doubled while more than 10
    levels would be needed to span ``[min_value, max_value]``. It is then
    halved while fewer than 6 would be needed.

    Parameters
    ----------
    min_value, max_value : float
        Bounds of the data range to cover. A zero-width (or inverted) range
        is widened to one ``increment`` so that a spacing can be found.
    increment : float, optional
        Initial spacing between levels; must be positive.

    Returns
    -------
    list of float
        ``ceil(n) + 1`` values ``min_value + i * step``, where ``n`` is the
        final level count. The last level may therefore lie above
        ``max_value``.

    Raises
    ------
    ValueError
        If a bound is not finite or ``increment`` is not positive.
    """
    if not (np.isfinite(min_value) and np.isfinite(max_value)):
        raise ValueError(
            f'find_levels needs finite bounds, got {min_value!r} and {max_value!r}')
    if increment <= 0:
        raise ValueError(f'increment must be positive, got {increment!r}')
    # With a zero-width range every step needs exactly one level, so the
    # halving loop below could never reach 6; widen the range by one step.
    if max_value <= min_value:
        max_value = min_value + increment

    def count_levels(step):
        # Levels needed to cover [min_value, max_value] with spacing ``step``,
        # both end points included. The same formula is used for every step
        # (the previous version used ``+ 1`` inside the loops, miscounting
        # whenever the step was not 1).
        return (max_value - min_value + step) * (1 / step)

    n_levels = count_levels(increment)
    while n_levels > 10:
        increment = increment * 2
        n_levels = count_levels(increment)
    while n_levels < 6:
        increment = increment / 2
        n_levels = count_levels(increment)

    n_levels = int(np.ceil(n_levels))
    return [min_value + i * increment for i in range(n_levels + 1)]

def callback(cube, field, filename):
    """Attach an ``initialisation`` auxiliary coordinate to ``cube``.

    Presumably meant as an iris load callback (``cube, field, filename``);
    ``field`` is not used. The coordinate value is the text after the last
    ``'_'`` in the name of the file's parent directory. For example,
    ``'.../run_20240201/t2m.nc'`` gives ``'20240201'``. ``filename`` is
    split on ``'/'``, so it must contain at least one ``'/'``.
    """
    parent_directory = filename.split('/')[-2]
    date_text = parent_directory.rsplit('_', 1)[-1]
    cube.add_aux_coord(
        iris.coords.AuxCoord(date_text, long_name='initialisation'))

def plot_mean_anomalies(mean_array, lats, lons, levels,
                        variable, region, units, land_mask=False,
                        title='', save_name=None, cmapvar='viridis'):
    """Draw a filled-contour map of ``mean_array`` on a PlateCarree projection.

    Parameters
    ----------
    mean_array : array-like
        Field to contour on the ``lats`` x ``lons`` grid.
    lats, lons : array-like
        Grid coordinates handed to ``contourf``.
    levels : sequence of float
        Contour levels; data outside them is shown by the extended ends.
    variable : str
        Not used by this function; kept for interface compatibility.
    region : str
        For ``'global'`` a cyclic longitude column is appended first.
    units : str
        Label of the horizontal colour bar.
    land_mask : bool, optional
        Draw cartopy's LAND feature (zorder 2) above the contours (zorder 1).
    title : str, optional
        Axes title.
    save_name : str or None, optional
        If truthy, save the figure there at 300 dpi and close it; otherwise
        display it with ``plt.show()``.
    cmapvar : str, optional
        Name of the matplotlib colormap.
    """
    colour_map = plt.get_cmap(cmapvar)
    figure, ax = plt.subplots(nrows=1, ncols=1,
                              subplot_kw={'projection': ccrs.PlateCarree()})

    if land_mask:
        ax.add_feature(cfeature.LAND, zorder=2)

    # On global maps the wrapped longitude column removes the white seam at
    # longitude 0; regional maps use the longitudes unchanged.
    if region != 'global':
        plot_lons = lons[:]
    else:
        mean_array, plot_lons = add_cyclic_point(mean_array, coord=lons)

    filled = ax.contourf(plot_lons, lats, mean_array,
                         levels=levels, cmap=colour_map,
                         transform=ccrs.PlateCarree(),
                         zorder=1, extend='both')
    figure.colorbar(filled, orientation='horizontal', ax=ax,
                    extend='both', label=units)
    ax.coastlines(resolution='50m')

    figure.set_size_inches(10, 10)
    ax.set_title(title)

    if not save_name:
        plt.show()
        return

    plt.savefig(save_name, dpi=300)
    # Tear the figure down so repeated calls do not accumulate open figures.
    plt.clf()
    plt.cla()
    plt.close()

# ============================================================================
# Main
# ============================================================================

# Per-variable [long name, unit label]. The long name selects the cube when
# loading the forecast file and appears in the plot title; the unit label is
# written on the colour bar.
variable_dictionary = {
    '2m_temperature': ['2 metre temperature', 'K'],
    '10m_wind_speed': ['10 metre wind speed', 'm/s'],
    'mean_sea_level_pressure': ['Mean sea level pressure', 'hPa'],
    'u_component_of_wind_850hPa': ['U velocity at 850hPa', 'm/s'],
    'v_component_of_wind_850hPa': ['V velocity at 850hPa', 'm/s'],
    'sea_surface_temperature': ['Sea surface temperature', 'K'],
    'snowfall': ['Mean total snowfall rate', 'm s**-1'],
    'total_precipitation': ['Mean total precipitation rate', 'm s**-1'],
}

# Silence every Python warning unless the WARNINGS switch is turned on.
if not WARNINGS:
    warnings.filterwarnings("ignore")


# Values that depend only on the command-line arguments are computed once,
# not once per variable and forecast month.
issue_date = f'{INITIAL_YEAR}{INITIAL_MONTH:02d}'
forecast_path = os.path.join(FORECAST_PATH, issue_date)
outdir = os.path.join(OUTPUT_PATH, issue_date)

# Regional latitude constraint; for Europe the longitudes are cut afterwards
# with Cube.intersection.
if REGION == 'europe':
    lat_constraint = iris.Constraint(latitude=lambda cell: 34 < cell < 73)
else:
    # Accepts every latitude, i.e. no restriction for global plots.
    lat_constraint = iris.Constraint(latitude=lambda cell: cell < 999)

for variable in VARIABLE_LIST:
    print(f"Plotting variable: {variable}...")
    long_name, units = variable_dictionary[variable]
    variable_name = variable.replace('_', '-')

    forecast_file = os.path.join(
        forecast_path,
        f'{variable}_bccr_{SYSTEM_ID}_{INITIAL_YEAR}_{INITIAL_MONTH:02d}.nc')
    threshold_file = os.path.join(
        THRESHOLD_PATH,
        f'mean_thresholds_1month_{variable_name}_{INITIAL_MONTH:02d}.nc')

    # Load the forecast and the 'mean' climatology once per variable; both are
    # indexed by forecast month in the loop below.
    forecast_cube = iris.load_cube(forecast_file, (long_name & lat_constraint))
    mean_cube = iris.load_cube(threshold_file, ('mean' & lat_constraint))
    if REGION == 'europe':
        forecast_cube = forecast_cube.intersection(longitude=(-49.5, 45.15))
        mean_cube = mean_cube.intersection(longitude=(-49.5, 45.15))

    # Only the SST plots get the land overlay.
    land_mask = variable == 'sea_surface_temperature'

    lats = forecast_cube.coord('latitude').points
    lons = forecast_cube.coord('longitude').points

    # Precipitation is rescaled from m to mm below, so its colour-bar label
    # must say mm as well (the dictionary gives the unscaled 'm s**-1').
    if variable == 'total_precipitation':
        units = 'mm s**-1'

    for month in range(6):
        # Ensemble-mean forecast for this lead month minus the climatology.
        # NOTE(review): assumes forecast dims are
        # (month, ensemble_member, lat, lon) — confirm against the files.
        imforecast = forecast_cube[month, :, :, :].collapsed(
            'ensemble_member', iris.analysis.MEAN)
        mean_anomalies = (imforecast - mean_cube[month]).data

        if variable == 'total_precipitation':
            mean_anomalies = mean_anomalies * 1000  # m -> mm
        elif variable == 'mean_sea_level_pressure':
            mean_anomalies = mean_anomalies / 100  # Pa -> hPa (1 hPa = 100 Pa)

        # The colour bar spans the 10th-90th percentiles rather than min/max,
        # so outliers do not stretch it. np.percentile does not honour
        # masked-array masks, so only unmasked points are used (for a plain
        # ndarray this is simply all values).
        valid_values = np.ma.compressed(mean_anomalies)
        p10, p90 = np.percentile(valid_values, [10, 90])
        if variable == 'total_precipitation':
            p10_value, p90_value, increment = round(p10, 1), round(p90, 1), 0.1
        else:
            p10_value, p90_value, increment = np.floor(p10), np.ceil(p90), 1

        levels = find_levels(p10_value, p90_value, increment=increment)

        os.makedirs(outdir, exist_ok=True)
        mean_anomalies_file = os.path.join(
            outdir,
            f'bccr_{SYSTEM_ID}_LM{month + 1}_{variable_name}_mean-anomalies_{REGION}.png')

        # Title names the calendar month being forecast (issue month + lead).
        month_datetime = (datetime.datetime(year=INITIAL_YEAR, month=INITIAL_MONTH, day=1)
                          + relativedelta(months=month))
        month_string = month_datetime.strftime('%b')
        title = f'Mean anomaly for {long_name}. Issued {issue_date}. Month {month_string}. {REGION.capitalize()}.'

        plot_mean_anomalies(mean_anomalies, lats, lons, levels, variable, REGION, units,
                            land_mask=land_mask, title=title,
                            save_name=mean_anomalies_file)

