# ============================================================================
# Imports
# ============================================================================

import os
import glob
import iris
import warnings

# ============================================================================
# Input
# ============================================================================

# Root directory that is searched for the monthly hindcast files.
HINDCAST_DIRECTORY = '/projects/NS9039K/www/seasonal_forecast/climate_futures/hindcast/monthly/'

# Directory that receives the generated threshold files.
# Alternative output location, currently disabled:
#OUTPUT_PATH = '/projects/NS9039K/projects/climate_futures/seasonal_forecast/thresholds/'
OUTPUT_PATH = '/projects/NS9873K/www/norcpm/forecast/thresholds/'

# When False, every Python warning is silenced before processing starts.
WARNINGS = False

# Variables for which thresholds are generated (one output file per variable
# and month).
VARIABLE_LIST = [
    '2m_temperature',
    '10m_wind_speed',
    'mean_sea_level_pressure',
    'u_component_of_wind_850hPa',
    'v_component_of_wind_850hPa',
    'sea_surface_temperature',
    'total_precipitation',
]

# ============================================================================
# Functions
# ============================================================================

def callback(cube, field, filename):
    """Tag a freshly loaded cube with the initialisation date of its file.

    Written to be passed as the ``callback`` argument of ``iris.load``.
    The date is the text after the last underscore in the name of the
    directory that directly contains ``filename`` (or the whole directory
    name when it has no underscore).

    Args:
        cube: Cube being loaded; it is modified in place.
        field: Unused; present only to satisfy the callback signature.
        filename: '/'-separated path of the file the cube was read from.
    """
    parent_directory = filename.rsplit('/', 2)[-2]
    date_label = parent_directory.rsplit('_', 1)[-1]
    # A scalar auxiliary coordinate keeps the date attached to the cube.
    cube.add_aux_coord(
        iris.coords.AuxCoord(date_label, long_name='initialisation')
    )

# ============================================================================
# Main
# ============================================================================

# Maps each variable's short name (as used in file names) to the cube name
# that is passed to iris.load as the load constraint.
variable_dictionary = {
    '2m_temperature': '2 metre temperature',
    '10m_wind_speed': '10 metre wind speed',
    'mean_sea_level_pressure': 'Mean sea level pressure',
    'u_component_of_wind_850hPa': 'U velocity at 850hPa',
    'v_component_of_wind_850hPa': 'V velocity at 850hPa',
    'sea_surface_temperature': 'Sea surface temperature',
    'snowfall': 'Mean total snowfall rate',
    'total_precipitation': 'Mean total precipitation rate',
}

# Silence every warning for the rest of the run unless warnings were
# explicitly requested through the WARNINGS switch.
if not WARNINGS:
    warnings.filterwarnings(action="ignore")

# Initialisation dates (as parsed from the name of the directory containing
# each file) whose files are left out of the threshold calculation.
# NOTE(review): the glob pattern below only matches six-character directory
# names ('????' plus a two-digit month), whereas these entries have eight
# characters, so as written this filter cannot remove any file -- confirm the
# intended directory naming.
exclude_list = ['19921015', '19921115', '19921215', '20210115', '20210215', '20210315']

# Two-digit month labels '01' .. '12' used in directory and output file names.
month_labels = [f'{x:02d}' for x in range(1, 13)]

# Percentile levels requested from iris (in this order) and the name given to
# the bound found at the same position: quintile bounds q12, q23, q34, q45
# and tercile bounds t12, t23.
percentile_names = ['q12', 't12', 'q23', 'q34', 't23', 'q45']
percentile_levels = [100/5, 100/3, 200/5, 300/5, 200/3, 400/5]

# Order in which the statistics are written to each output file.
output_names = ['mean', 'median', 't12', 't23', 'q12', 'q23', 'q34', 'q45']

# Coordinates over which the statistics are taken.
collapse_coords = ('initialisation', 'ensemble_member')

for variable in VARIABLE_LIST:
    for lead_month in month_labels:
        print(f'Calculating thresholds for {variable} {lead_month}')
        pattern = os.path.join(HINDCAST_DIRECTORY, f'????{lead_month}/{variable}*.nc')
        file_list = [ifile for ifile in glob.glob(pattern)
                     if ifile.split('/')[-2].split('_')[-1] not in exclude_list]
        if not file_list:
            # Fail with an explicit message instead of handing an empty list
            # to iris and failing later on an unrelated-looking error.
            raise FileNotFoundError(
                f'No hindcast files found for {variable} {lead_month} (pattern: {pattern})')

        cube_list = iris.load(file_list, variable_dictionary[variable], callback=callback)
        iris.util.equalise_attributes(cube_list)

        for member_cube in cube_list:
            # Replace the points of the 'time' coordinate with a 1-based
            # index (1 .. N). According to the original author's note, the
            # stored values differ between leap and non-leap years, which
            # makes it awkward to merge the cubes with iris.
            time_coord = member_cube.coord('time')
            time_coord.points = list(range(1, len(time_coord.points) + 1))

        cube = cube_list.merge_cube()

        # One CubeList per statistic; each receives one cube per window.
        statistics = {name: iris.cube.CubeList() for name in output_names}

        # Loop over three-step windows of the time index:
        # (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6).
        for month_group in range(0, 4):
            # NOTE(review): assumes the merged cube has five dimensions with
            # 'time' as the second one -- TODO confirm.
            icube = cube[:, month_group:(month_group + 3), :, :, :]
            # Mean over the window (to get seasonal means). Per the original
            # author's note, the collapsed cube keeps enough information
            # about the window for the later merge to order the windows
            # correctly.
            imcube = icube.collapsed('time', iris.analysis.MEAN)

            statistics['median'].append(
                imcube.collapsed(collapse_coords, iris.analysis.MEDIAN))
            statistics['mean'].append(
                imcube.collapsed(collapse_coords, iris.analysis.MEAN))

            # All tercile and quintile bounds in a single call; the leading
            # index of the result is used to pick out each requested level.
            ipercentiles_cube = imcube.collapsed(collapse_coords,
                                                 iris.analysis.PERCENTILE,
                                                 percent=percentile_levels)
            for index, name in enumerate(percentile_names):
                statistics[name].append(ipercentiles_cube[index])

        # Merge the windows of every statistic into one cube, name it after
        # the statistic and collect everything into one CubeList.
        output_data = iris.cube.CubeList()
        for name in output_names:
            merged = statistics[name].merge_cube()
            merged.rename(name)
            output_data.append(merged)

        # Save the output to one file per lead time per variable.
        variable_name = variable.replace('_', '-')
        output_file = os.path.join(OUTPUT_PATH, f'thresholds_3month_{variable_name}_{lead_month}.nc')
        print(f'Saving thresholds to file: {output_file}...')
        iris.save(output_data, output_file)

print('Threshold generation complete.')
