# ============================================================================
# Imports
# ============================================================================

import os
import sys
import glob
import cftime
import datetime
import dateutil.relativedelta
import netCDF4
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import iris
import warnings

# ============================================================================
# Input
# ============================================================================

args = sys.argv
if len(args) < 3:
    print('Usage: python iod_plot.py <norcpm version> <start date>')
    print()
    print('Example: python iod_plot.py norcpm-cf-system1 20230515')
    print()
    sys.exit(1)

# Map each supported NorCPM system name to the ID used in forecast/plot file names.
SYSTEM_IDS = {
    'norcpm-cf-system1': '1',
    'norcpm-cf-system2': '2',
}
SYSTEM = args[1]
if SYSTEM not in SYSTEM_IDS:
    print('No predefined SYSTEM_ID for system ', SYSTEM)
    sys.exit(1)
SYSTEM_ID = SYSTEM_IDS[SYSTEM]

START_DATE = args[2]
# Reject malformed dates here rather than mis-slicing a short string below.
if len(START_DATE) != 8 or not START_DATE.isdigit():
    print('Start date must be given as YYYYMMDD, got', START_DATE)
    sys.exit(1)
START_YEAR = int(START_DATE[0:4])
START_MONTH = int(START_DATE[4:6])
START_DAY = int(START_DATE[6:8])
# A start date after the 1st of a month maps to the following month as the
# first forecast month. START_YEAR/START_MONTH are updated in place, so after
# this block they hold the forecast year/month.
if START_DAY > 1:
    if START_MONTH == 12:
        START_MONTH = 1
        START_YEAR += 1
    else:
        START_MONTH += 1
FORECAST_DATE = f'{START_YEAR:04d}{START_MONTH:02d}'
print(FORECAST_DATE)

ERA5_CLIMATOLOGY_FILE = '/projects/NS9039K/data/external/reanalysis/ECMWF/ERA5/original/monthly_single_level/monthly_averaged_reanalysis/sst/ERA5_034_197901-202112.nc'

NORCPM_CLIMATOLOGY_PATH = '/projects/NS9039K/www/seasonal_forecast/climate_futures/hindcast/monthly/'

# Set WARNINGS = True to see Python warnings; otherwise all are suppressed.
WARNINGS = False
if not WARNINGS:
    warnings.filterwarnings("ignore")
# ============================================================================
# Derived input
# ============================================================================

forecast_year = FORECAST_DATE[0:4]
forecast_month = FORECAST_DATE[4:6]
forecast_date = datetime.datetime(year=int(forecast_year), month=int(forecast_month), day=1)

# Recent ERA5 SST file (monthly data from 202101 onward). Prefer the file that
# ends in the month before the forecast month; if it does not exist yet, fall
# back to the file ending two months before. relativedelta handles the
# roll-back across the year boundary (January/February forecasts).
ERA5_RECENT_PATH = '/projects/NS9873K/norcpm/validation/reanalysis/ECMWF/ERA5/original/monthly_single_level/monthly_averaged_reanalysis/sst/'
for months_back in (1, 2):
    recent_datetime = forecast_date - dateutil.relativedelta.relativedelta(months=months_back)
    recent_year = f'{recent_datetime.year:04d}'
    recent_month = f'{recent_datetime.month:02d}'
    RECENT_DATE = recent_year + recent_month
    ERA5_RECENT_FILE = ERA5_RECENT_PATH + 'ERA5_034_202101-' + RECENT_DATE + '.nc'
    if os.path.isfile(ERA5_RECENT_FILE):
        break

FORECAST_PATH = '/projects/NS9873K/www/norcpm/forecast/monthly/'
FORECAST_FILE = FORECAST_PATH + f'{FORECAST_DATE}/sea_surface_temperature_bccr_{SYSTEM_ID}_{forecast_year}_{forecast_month}.nc'

OUTPUT_DIR = f'/projects/NS9873K/www/norcpm/forecast/plots/{FORECAST_DATE}'
OUTPUT_FILE = OUTPUT_DIR + f'/bccr_{SYSTEM_ID}_iod.png'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================================================================
# Functions
# ============================================================================

def cftime_to_datetime(cftime):
    """Convert a cftime-like date to a naive ``datetime.datetime`` at midnight.

    Only the year, month and day are carried over; any time-of-day component
    is dropped.

    Args:
        cftime: any object exposing ``year``, ``month`` and ``day`` attributes
            (e.g. a cftime or datetime instance). The parameter name shadows
            the ``cftime`` module only inside this function.

    Returns:
        datetime.datetime for that calendar day.
    """
    year, month, day = cftime.year, cftime.month, cftime.day
    return datetime.datetime(year, month, day)

def number_to_month(number):
    """Return the abbreviated month name (strftime ``%b``) for a month number.

    *number* is converted with ``str()`` and parsed with ``%m``, so ``5``,
    ``'5'`` and ``'05'`` all give the same month. Values ``strptime`` cannot
    parse as a month raise ``ValueError``.
    """
    return format(datetime.datetime.strptime(str(number), "%m"), "%b")

def calculate_month_climatology(in_array, month):
    """Return the mean of every 12th element of *in_array*, starting at ``month - 1``.

    For a monthly series whose first element is a January, this is the
    climatological mean of calendar month *month*.

    Args:
        in_array: 1-D monthly sequence (e.g. a numpy array).
        month: calendar month number, 1 (January) to 12 (December).

    Returns:
        The mean of ``in_array[month - 1::12]``.

    Raises:
        ValueError: if *month* is outside 1..12. Otherwise a month of 0 or
            less would give a negative slice start that silently wraps to
            the end of the array.
    """
    if not 1 <= month <= 12:
        raise ValueError(f'month must be in 1..12, got {month!r}')
    return np.mean(in_array[(month - 1)::12])

def calculate_iod(in_cube):
    """Return the Dipole Mode Index series of *in_cube*: west box minus east box.

    West box: 10S-10N, 50E-70E. East box: 10S-0, 90E-110E.

    Grids whose minimum longitude is not below -10 (taken as a 0..360
    convention) are averaged with ``weighted_average_cube``, which
    half-weights the boundary cells. Other grids use a plain mean over the
    last two axes.

    Args:
        in_cube: iris cube with latitude/longitude coordinates. The
            ``axis=(1, 2)`` means assume (time, lat, lon) ordering; confirm
            for new inputs.

    Returns:
        1-D array with one DMI value per time step.
    """
    lowest_longitude = np.amin(in_cube.coord('longitude').points)
    use_edge_weights = not (lowest_longitude < -10)

    def box_mean(lat_min, lat_max, lon_min, lon_max):
        # Extract the lat/lon box and reduce it to one value per time step.
        box = in_cube.extract(
            iris.Constraint(latitude=lambda cell: lat_min <= cell <= lat_max)
            & iris.Constraint(longitude=lambda cell: lon_min <= cell <= lon_max)
        )
        # The ERA5 cell layout calls for half weights on the outer cells.
        if use_edge_weights:
            return weighted_average_cube(box)
        return np.mean(box.data, axis=(1, 2))

    west = box_mean(-10, 10, 50, 70)
    east = box_mean(-10, 0, 90, 110)
    return west - east

def weighted_average_cube(constrained_cube):
    """Return the spatial mean of *constrained_cube* with half-weighted edges.

    Cells on the outer latitude rows and longitude columns get half weight
    (corners a quarter); interior cells get full weight. The weighted sum over
    latitude and longitude is divided by the summed weights of the cells that
    actually hold data.

    Masked cells (e.g. land points in an SST field) are excluded from both the
    numerator and the normalisation. Dividing by the mean weight of *all*
    cells instead would scale the result down whenever some cells are masked.

    Args:
        constrained_cube: cube-like object with ``latitude``/``longitude``
            coordinates and ``.data``. Assumes data ordered
            (time, lat, lon); confirm for new inputs.

    Returns:
        1-D masked array with one value per time step. A time step is masked
        when all of its cells are masked.
    """
    n_lat = constrained_cube.coord('latitude').shape[0]
    n_lon = constrained_cube.coord('longitude').shape[0]
    weights_array = np.ones((n_lat, n_lon))
    weights_array[0, :] /= 2
    weights_array[-1, :] /= 2
    weights_array[:, 0] /= 2
    weights_array[:, -1] /= 2

    data = np.ma.asarray(constrained_cube.data)
    # Per-cell weights for every time step; masked wherever the data is
    # masked, so invalid cells contribute no weight.
    full_weights = np.broadcast_to(weights_array, data.shape).copy()
    weights = np.ma.array(full_weights, mask=np.ma.getmaskarray(data))
    return (data * weights).sum(axis=(1, 2)) / weights.sum(axis=(1, 2))


# ============================================================================
# Main
# ============================================================================

# NorCPM hindcast climatology of the DMI, one value per lead month: the mean
# over all selected hindcast files of their ensemble-mean DMI.
# NOTE(review): files are filtered only by the date exclusions below, not by
# start month or SYSTEM_ID. Every start month is therefore pooled, so
# norcpm_climatology[i] is not specific to the forecast's calendar months.
# Confirm whether the hindcasts should be restricted to those starting in
# forecast_month.
file_list_raw = glob.glob(os.path.join(NORCPM_CLIMATOLOGY_PATH, '*/sea_surface*'))

# Exclude start dates outside 199301-202012 so the climatology covers exactly
# that period. Dates are matched as YYYYMM substrings of the file path. The
# 2021+ exclusions are generated through the forecast year, so newer files
# cannot silently enter the climatology.
exclusion_list = ['199210', '199211', '199212'] + [
    f'{year:04d}{month:02d}'
    for year in range(2021, max(2024, int(forecast_year)) + 1)
    for month in range(1, 13)
]
file_list = sorted(
    ifile for ifile in file_list_raw
    if not any(excluded in ifile for excluded in exclusion_list)
)
if not file_list:
    sys.exit(f'No NorCPM hindcast files found under {NORCPM_CLIMATOLOGY_PATH}')

norcpm_iod_rows = []
for ifile in file_list:
    icube = iris.load_cube(ifile, 'Sea surface temperature')
    icube_ensemble_mean = icube.collapsed('ensemble_member', iris.analysis.MEAN)
    # getdata accepts both plain and masked results from calculate_iod.
    norcpm_iod_rows.append(np.ma.getdata(calculate_iod(icube_ensemble_mean)))
# One row per hindcast file, one column per lead month; float64 accumulation.
norcpm_climatology_array = np.array(norcpm_iod_rows, dtype=np.float64)
norcpm_climatology = np.mean(norcpm_climatology_array, axis=0)

# Load the long ERA5 SST record used for the observed climatology.
era5_climatology_raw = iris.load_cube(ERA5_CLIMATOLOGY_FILE)

# Select the climatology period by index, because cftime-based iris time
# constraints did not work for this file. The file name indicates it starts
# in 197901, so indices 24:372 cover 198101-200912, and index 0 of the slice
# is a January (calculate_month_climatology relies on this).
# NOTE(review): this is a 29-year period (a 1981-2010 normal would be 24:384).
# It also differs from the NorCPM 199301-202012 hindcast period; confirm intent.
era5_climatology = era5_climatology_raw[24:372]

# Observed DMI (IOD) for every month of the climatology period:
# 10S-10N/50E-70E box minus 10S-0/90E-110E box.
iod_climatology = calculate_iod(era5_climatology)

# Load the recent ERA5 months (observed DMI leading up to the forecast).
era5_new_raw = iris.load_cube(ERA5_RECENT_FILE)
era5_new_origin = era5_new_raw.coord('time').units.origin
# NOTE(review): era5_new_origin_datetime is not used further in this file.
era5_new_origin_split = era5_new_origin.split(' ')[2].split('-')
era5_new_origin_datetime = datetime.datetime(year=int(era5_new_origin_split[0]), month=int(era5_new_origin_split[1]), day=int(era5_new_origin_split[2]))

# Keep the five months immediately preceding the forecast month (fewer if the
# recent file does not contain all of them).
era5_datetime_list = [(forecast_date - dateutil.relativedelta.relativedelta(months=i)) for i in range(1, 6)]
time_constraint = iris.Constraint(time=lambda cell: cell.point in era5_datetime_list)
era5_new = era5_new_raw.extract(time_constraint)

iod_new = calculate_iod(era5_new)

# Dates of the selected ERA5 months, decoded with the time coordinate's unit string.
era5_date_list = netCDF4.num2date(era5_new.coord('time').points, units=era5_new_origin)

# Observed anomaly: subtract the ERA5 climatological mean of the same calendar month.
era5_plot_list = [
    iod_value - calculate_month_climatology(iod_climatology, idate.month)
    for iod_value, idate in zip(iod_new, era5_date_list)
]

# Load the NorCPM forecast and compute the DMI of its ensemble mean.
forecast_raw = iris.load_cube(FORECAST_FILE, 'Sea surface temperature')
forecast_ensemble_mean = forecast_raw.collapsed('ensemble_member', iris.analysis.MEAN)
forecast_iod = calculate_iod(forecast_ensemble_mean)

# Initialisation date from the 'initial_time' attribute, read as
# 'DD/MM/YYYY<space>...' and rebuilt as 'YYYY-MM-DD 00:00'.
initial_time = forecast_raw.attributes['initial_time']
initial_fields = initial_time.split('/')
initial_day = initial_fields[0]
initial_month = initial_fields[1]
initial_year = initial_fields[2].split(' ')[0]
initial_date = f'{initial_year}-{initial_month}-{initial_day} 00:00'

# Re-base the time axis on the initialisation date: each step is placed at
# initial_date + (its offset from the first time point), so the first step
# plots at the initialisation date rather than at its own time stamp.
# NOTE(review): offsets are interpreted as hours, which assumes the time
# coordinate is in hours. Because month lengths differ, later steps may also
# drift a day or two from the 1st of the month (possibly into the previous
# month for some start months). Confirm against the file's time convention.
raw_time_points = forecast_raw.coord('time').points
time_points = [step - raw_time_points[0] for step in raw_time_points]

forecast_date_list = netCDF4.num2date(time_points, units=f'hours since {initial_date}')

# Ensemble-mean forecast anomaly per lead month, relative to the NorCPM
# hindcast climatology at the same lead.
n_leads = len(forecast_date_list)
forecast_plot_list = [forecast_iod[lead] - norcpm_climatology[lead] for lead in range(n_leads)]

# Plain datetimes for matplotlib: observed months followed by forecast months.
era5_date_list = [cftime_to_datetime(d) for d in era5_date_list]
forecast_date_list = [cftime_to_datetime(d) for d in forecast_date_list]
x_values = era5_date_list + forecast_date_list

fig = plt.figure(figsize=(12,6))
ax = fig.add_subplot(111)

# One red line per ensemble member: the observed ERA5 anomalies followed by
# that member's forecast anomalies. forecast_raw is indexed as
# [time, member, ...], so the member count comes from its second dimension.
# A hard-coded 60 raised IndexError for smaller ensembles and silently
# dropped members of larger ones.
n_members = forecast_raw.shape[1]
for j in range(n_members):
    forecast_iod_j = calculate_iod(forecast_raw[:, j])
    forecast_plot_list_j = [
        forecast_iod_j[i] - norcpm_climatology[i]
        for i in range(len(forecast_date_list))
    ]
    y_values_j = era5_plot_list + forecast_plot_list_j
    plt.plot(x_values, y_values_j, '-o', color='red')

# Ensemble-mean line: observations followed by the ensemble-mean forecast.
y_values = era5_plot_list + forecast_plot_list

plt.xticks(x_values)
plt.title('NorCPM seasonal DMI (IOD) forecast')
plt.xlabel('Date')
plt.ylabel('Temperature anomaly (K)')
plt.plot(x_values, y_values, '-o', color='blue')
plt.plot(list(era5_date_list), era5_plot_list, '-o', color='black')
# Dashed reference lines at +/-0.4 K (presumably the IOD event thresholds) and at zero.
plt.plot(x_values, [0.4]*len(y_values), '--', color='blue')
plt.plot(x_values, [0]*len(y_values), '--', color='grey')
plt.plot(x_values, [-0.4]*len(y_values), '--', color='blue')
# Month ticks/labels; this locator replaces the explicit xticks set above.
ax.xaxis.set_major_locator(mdates.MonthLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b'))
plt.savefig(OUTPUT_FILE)
#plt.show()

