# ============================================================================
# Imports
# ============================================================================

import os
import glob
import iris
import warnings

# ============================================================================
# Input
# ============================================================================

# Root directory that is searched for the monthly hindcast NetCDF files.
HINDCAST_DIRECTORY = '/projects/NS9039K/www/seasonal_forecast/climate_futures/hindcast/monthly/'

# Directory the resulting threshold files are written to.
OUTPUT_PATH = '/projects/NS9873K/www/norcpm/forecast/thresholds/'

# When False, every Python warning is silenced for the whole run.
WARNINGS = False

# Variables to process; each entry must be a key of the variable-name
# dictionary defined further down in this script.
VARIABLE_LIST = [
    '2m_temperature',
    '10m_wind_speed',
    'mean_sea_level_pressure',
    'u_component_of_wind_850hPa',
    'v_component_of_wind_850hPa',
    'sea_surface_temperature',
    'total_precipitation',
]

# ============================================================================
# Functions
# ============================================================================

def callback(cube, field, filename):
    """Attach the initialisation date of the source file to a loaded cube.

    Has the ``(cube, field, filename)`` signature of an ``iris.load``
    callback. The date is read from the name of the directory containing
    ``filename``: the text after its last underscore, or the whole directory
    name when it contains no underscore.

    Args:
        cube: Cube being loaded; an auxiliary coordinate is added in place.
        field: Not used; present only to satisfy the callback signature.
        filename: '/'-separated path of the file the cube was read from.
    """
    parent_directory = filename.rsplit('/', 2)[-2]
    date_token = parent_directory.rpartition('_')[2]
    cube.add_aux_coord(
        iris.coords.AuxCoord(date_token, long_name='initialisation')
    )

# ============================================================================
# Main
# ============================================================================

# Maps each short variable name (as used in file names) to the cube name
# that is passed to iris.load when selecting data from the hindcast files.
variable_dictionary = {
    '2m_temperature': '2 metre temperature',
    '10m_wind_speed': '10 metre wind speed',
    'mean_sea_level_pressure': 'Mean sea level pressure',
    'u_component_of_wind_850hPa': 'U velocity at 850hPa',
    'v_component_of_wind_850hPa': 'V velocity at 850hPa',
    'sea_surface_temperature': 'Sea surface temperature',
    'snowfall': 'Mean total snowfall rate',
    'total_precipitation': 'Mean total precipitation rate',
}

# Unless WARNINGS is switched on, install a catch-all "ignore" filter so that
# no Python warning is shown for the remainder of the run.
if not WARNINGS:
    warnings.simplefilter("ignore")

# Initialisation dates to leave out of the calculation. Built once, as a set,
# instead of being rebuilt for every variable / lead month.
# NOTE(review): the glob pattern below only matches six-character directory
# names ('????' + two-digit month), so the token compared against this set can
# never equal one of these eight-character dates and nothing is currently
# excluded. Confirm the intended directory naming.
excluded_initialisations = frozenset([
    '19921015', '19921115', '19921215',
    '20210115', '20210215', '20210315',
])

# Number of leading 'time' steps for which a mean is calculated.
n_months = 6

for variable in VARIABLE_LIST:
    for lead_month in [f'{x:02d}' for x in range(1, 13)]:
        print(f'Calculating thresholds for {variable} {lead_month}')

        # Collect the input files. Sorting makes the load order (and the
        # printed list) independent of the order the filesystem returns.
        pattern = os.path.join(HINDCAST_DIRECTORY, f'????{lead_month}/{variable}*.nc')
        file_list = sorted(
            ifile for ifile in glob.glob(pattern)
            if ifile.split('/')[-2].split('_')[-1] not in excluded_initialisations
        )
        print(file_list)
        if not file_list:
            # Fail here with a clear message rather than later, inside iris.
            raise FileNotFoundError(f'No hindcast files match {pattern}')

        cube_name = variable_dictionary[variable]
        cube_list = iris.load(file_list, cube_name, callback=callback)
        if not cube_list:
            raise ValueError(f'No cubes named {cube_name!r} found for {variable} {lead_month}')
        # Attributes that differ between cubes are removed so they can merge.
        iris.util.equalise_attributes(cube_list)

        for hindcast in cube_list:
            # Overwrite the 'time' points with a 1-based index (1..N).
            # Per the original author's note, the stored values are in hours,
            # which differ between leap and non-leap years and make the cubes
            # awkward to combine with iris.
            time_coord = hindcast.coord('time')
            time_coord.points = list(range(1, len(time_coord.points) + 1))

        # One cube holding every initialisation.
        merged = cube_list.merge_cube()

        # Mean over initialisations and ensemble members, one step at a time.
        # The positional index assumes a five-dimensional cube with 'time' as
        # the second dimension - TODO confirm against the input data.
        mean_list = [
            merged[:, month, :, :, :].collapsed(
                ('initialisation', 'ensemble_member'), iris.analysis.MEAN
            )
            for month in range(n_months)
        ]
        mean = iris.cube.CubeList(mean_list).merge_cube()
        mean.rename('mean')

        # Combine all output variables into one cube.CubeList.
        output_data = iris.cube.CubeList([mean])

        # Save the output to one file per lead time per variable.
        variable_name = variable.replace('_', '-')
        output_file = os.path.join(OUTPUT_PATH, f'mean_thresholds_1month_{variable_name}_{lead_month}.nc')
        print(f'Saving thresholds to file: {output_file}...')
        iris.save(output_data, output_file)

print('Threshold generation complete.')
