# Calculate seasonal probability of 2m temperature (three-month average)
# Compare forecast with climatology
# Plot probability of 2m temperature
# Period for climatology: 1993-2020 -> Copernicus: 1993-2016 #ver1.1.0
#
#
# Script version 2.0.0
#
# Script works with:
#   Python version: 3.10.13
#
#   Package version
#     numpy:  1.26.3
#     pandas:  2.2.0
#     netCDF4:  1.6.4
#     scipy:  1.12.0
#
# Ver1.0.0: Created by Mariko Koseki, 02.04.2024
# Ver1.0.1: updated by Mariko, 13.05.2024
# Ver1.1.0: updated by Mariko, 15.05.2024
# Ver2.0.0: Updated by Mariko, 12.06.2024


import netCDF4
import numpy as np
import pandas as pd
import datetime
from dateutil.relativedelta import relativedelta
import os
import sys
import platform
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import scipy
from scipy.interpolate import RegularGridInterpolator as RGI


# Report the interpreter version and the versions of the key packages so a
# run log can be compared with the versions listed in the file header.
print('--- Python version ---')
print(platform.python_version())

print('--- Package version ---')
for label, module in (
    ('numpy: ', np),
    ('pandas: ', pd),
    ('netCDF4: ', netCDF4),
    ('scipy: ', scipy),
):
    print(label, module.__version__)




##--- Function -----------------------------------------##
def box_smooth_2D(field,n_ih,n_jh,lat_wgt=True,latitude=None):
    """
    Smooth a 2d field w/ dimensions (i,j) by applying a simple box filter
    (2d-equivalent to a running-mean).

    The j-coordinate is treated as cyclic (e.g., longitude for global data), so
    every column is smoothed, including the first and the last one. The
    i-coordinate is not cyclic: the first and last n_ih rows do not have a
    complete window and keep their un-smoothed values.

    note: n_ih = 1, n_jh = 1 will create averages of the central points and the
    8 points around each

    INPUT:
        field       : (numpy.array) with 2 dimensions (i,j)
        n_ih        : (int) number of grid points to take into account in positive & negative i-direction of the central point
        n_jh        : (int) number of grid points to take into account in positive & negative j-direction of the central point
        lat_wgt     : (boolean) whether to apply latitude weights in i-direction or not; if True, need to define latitude
        latitude    : (numpy.array) of length field.shape[0] (length of i-dimension) with latitudes (degrees) used for weighting
    OUTPUT:
        (numpy.array) copy of field (same shape and dtype) holding the smoothed values;
        the input array is not modified
    RAISES:
        AssertionError : lat_wgt is True but latitude is None
        ValueError     : latitude does not have exactly field.shape[0] entries
    """

    ## The function is adapted from:
    ## https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/tree/main/cf_monthly_forecast
    ## https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/blob/main/notebooks/002_forecast_plots.py

    N_i,N_j = field.shape

    if lat_wgt:
        assert latitude is not None, 'need to specify latitude vector of length N_i to apply latitude weighting'
        # one cos(latitude) weight per row of the field
        row_weights = np.cos(np.asarray(latitude)/180*np.pi)
        if row_weights.shape != (N_i,):
            raise ValueError('latitude must be a vector of length field.shape[0]')
    else:
        row_weights = np.ones(N_i)

    # pad the field in j direction (longitude) so that windows can wrap around:
    field_padded = np.concatenate([field[:,N_j-n_jh:],field,field[:,:n_jh]],axis=1)

    box_kernel = np.ones([n_ih*2+1,n_jh*2+1]) # un-normalized kernel

    # use a copy of the original and overwrite, note that this means that the
    # first/last n_ih rows will keep un-smoothed values!
    smooth_field = field.copy()

    # only rows with a complete window in i-direction are smoothed
    for i in range(n_ih, N_i-n_ih):
        # The weights depend on the row only, so build them once per row.
        # The latitude weights must vary along the i-axis of the window, hence
        # the explicit column vector: a plain 1-D vector would be broadcast
        # along the j-axis instead.
        wgts = row_weights[i-n_ih:i+n_ih+1, np.newaxis] * box_kernel
        wgts_sum = wgts.sum()
        rows = field_padded[i-n_ih:i+n_ih+1]
        for j in range(N_j):
            # column j of the field is column j+n_jh of the padded field, so
            # the window centred on it spans padded columns j .. j+2*n_jh
            smooth_field[i,j] = (wgts*rows[:,j:j+2*n_jh+1]).sum()/wgts_sum
    return smooth_field


##--- Set year, month for climatology ------------------------------------## #ver1.1.0
# Set 'start_year_clm' and 'end_year_clm': first and last year of the
# climatology period. They are used to build the climatology file name.
# Copernicus: 1993-2016
start_year_clm = 1993 # 1993
end_year_clm = 2016 # 2016



##--- Set path ------------------------------------------##
## Set path to input files ##
# Set 'norcpm_monthly' (directory with the monthly forecast NetCDF files) and
# 'norcpm_clim' (directory with the climatology NetCDF file).
# Both paths must end with '/' because file names are appended directly.
norcpm_monthly = '/nird/projects/NS9873K/www/norcpm/forecast/monthly/'
norcpm_clim = '/nird/projects/NS9873K/norcpm/validation/2m_temperature/median/clim/'

## Set path to output files ##
# Set 'out_fig_dir' (directory the figures are written to)
#out_fig_dir  = '/nird/projects/NS9873K/norcpm/validation/2m_temperature/median/fig/smooth/clm_' + str(start_year_clm) + '_' + str(end_year_clm) + '/' #Ver1.1.0 # <--- tmp folder
#out_fig_dir  = '/nird/projects/NS9873K/norcpm/validation/2m_temperature/median/fig/'
out_fig_dir  = '/nird/projects/NS9873K/www/norcpm/forecast/plots/version2/' #ver2.0.0 <--- folder for archiving figures
# exist_ok avoids the check-then-create race when several runs start together
os.makedirs(out_fig_dir, exist_ok=True)


## Select NorCPM system ##
# Set system number 'system_num' (used in the forecast and figure file names)
system_num = '1'
#system_num = '2'

print('')
print('system: ', system_num)
print('')

##--- Input ----------------------------------------------##
def _exit_with_usage():
    """Print the command-line help and stop the script with exit status 1."""
    print('')
    print('---How to use this script---')
    print('python plot_seasonal_probability_2mt.py <yyyymm>')
    print('<yyyymm> = Forecast start date')
    print('Example: python plot_seasonal_probability_2mt.py 202403')
    print('')
    # non-zero status so that calling scripts can detect the failed run
    sys.exit(1)


# Exactly one argument is expected: the forecast start date as <yyyymm>.
args = sys.argv
if len(args) != 2:
    _exit_with_usage()

input_date = args[1]
# Require six ASCII digits so that int() below cannot fail.
if len(input_date) != 6 or not (input_date.isascii() and input_date.isdigit()):
    _exit_with_usage()

yyyy = int(input_date[0:4])
mm = int(input_date[4:6])
dd = 1 #ver1.0.1  (forecasts always start on the first day of the month)

# Reject impossible months here: month 00 would otherwise give mon = -1 and
# silently select the last month of the climatology.
if not 1 <= mm <= 12:
    _exit_with_usage()


start_year = '{:0>4d}'.format(yyyy)
start_month = '{:0>2d}'.format(mm)
start_date = '{:0>2d}'.format(dd) #ver1.0.1

issue_date = start_year + start_month + start_date #ver2.0.0
print('')
print('Issued: ', issue_date)
print('')

forecast_start_date = '{:0>2d}/{:0>2d}/{:0>4d}'.format(dd,mm,yyyy) #ver1.0.1
print('Forecast start: ', forecast_start_date)
print('')

# zero-based index of the start month (used to index the climatology 'month' axis)
mon = mm - 1



##--- Read netCDF files ----------------------------------##
## Open NetCDF files ##
### Forecast ###
# <norcpm_monthly><yyyy><mm>/2m_temperature_bccr_<system>_<yyyy>_<mm>.nc
nc_forecast = ''.join([
    norcpm_monthly, start_year, start_month,
    '/2m_temperature_bccr_', system_num, '_', start_year, '_', start_month, '.nc',
])
print('forecast file name: ', nc_forecast)
nc = netCDF4.Dataset(nc_forecast, 'r')


### Climatology ###
# NOTE(review): the file name always uses system 1 ('bccr_1_'), independent of
# 'system_num' -- confirm this is intended when system 2 is selected.
nc_clim_2m_temp = ''.join([
    norcpm_clim, 'bccr_1_2m_temperature_clim_median_',
    str(start_year_clm), '_', str(end_year_clm), '.nc',
]) #Ver1.1.0
print('climatology file name: ', nc_clim_2m_temp)
nc_clim = netCDF4.Dataset(nc_clim_2m_temp, 'r')


## Check length of variables ##
# Sizes of the forecast dimensions (author's notes: longitude 360,
# latitude 180, lead time 6, ensemble member 60).
nlon, nlat, nlead, nmem = (
    len(nc.dimensions[dim_name])
    for dim_name in ('longitude', 'latitude', 'time', 'number')
)
# Size of the climatology 'month' dimension (author's note: 12)
nmonth = len(nc_clim.dimensions['month'])


## Read variables ##
forecast_vars = nc.variables
clim_vars = nc_clim.variables

### lead time and month axes of the climatology ###
leads = clim_vars['lead'][:] # lead time: 1-6 (author's note)
months = clim_vars['month'][:] # month: 1-12 (author's note)

### time ###
# Decode the forecast time axis with the units/calendar stored in the file and
# keep one 'YYYY-MM' string per lead time.
time_var = forecast_vars['time']
time = time_var[:] # "hours since 1900-01-01 00:00:00.0" (author's note)
timeunits = time_var.getncattr('units')
timecalendar = time_var.getncattr('calendar')
date_t2m = [
    f'{stamp.year}-{stamp.month:02}'
    for stamp in netCDF4.num2date(time, timeunits, timecalendar)
]

### latitude, longitude ###
lons = forecast_vars['longitude'][:] # Longitude: -179.5 - 179.5 (author's note)
lats = forecast_vars['latitude'][:] # Latitude: 89.5 - -89.5 (author's note)


### 2 meter temperature ###
# indexed later as (lead time, ensemble member, latitude, longitude)
t2m = forecast_vars['t2m'][:]

### 2m temperature climatology ###
# indexed later as (lead time, month, latitude, longitude)
clim_t2m = clim_vars['clim_2m_temp'][:]



## Convert variables into Numpy ndarray ##
# NOTE(review): if netCDF4 returned masked arrays, np.array() drops the mask --
# confirm the files contain no missing values.
t2m_np = np.array(t2m)
clim_t2m_np = np.array(clim_t2m)



##--- Loop over all lead time --------------------------------##
# One figure is produced per window of three consecutive lead times. The
# windows start at lead index 0, 1, 2 and 3.
n_season_months = 3

# Ensemble size taken from the data (second axis of the forecast array)
# instead of a hard-coded 60, so the percentage stays correct if the number
# of members changes.
n_members = t2m_np.shape[1]

##--- Settings that do not change between figures -------------------##
### Target grid for the interpolation: 4 times as many points per axis ###
# (180, 360)->(720, 1440)
nsplines = 4
lons4 = np.linspace(lons[0],lons[-1],lons.shape[0]*nsplines)
lats4 = np.linspace(lats[0],lats[-1],lats.shape[0]*nsplines)
lon_new, lat_new = np.meshgrid(lons4, lats4, indexing='ij', sparse=True) # query points for the interpolator
lon_grid, lat_grid = np.meshgrid(lons4, lats4) # plotting grid

### Set contour levels ###
clevs = np.arange(0,110,10)

### Set colour map ###
cmap = plt.cm.colors.ListedColormap(['#071e6e', '#054b89', '#2b7ea8', '#7db1cb', '#d0e1e9', '#edd2c4', '#d8a07e', '#c5703d', '#a43900', '#6f0500'])

### Output folder for this forecast start: <out_fig_dir>/<yyyymm>/ ###
figs_dir = out_fig_dir + '/' + str(start_year) + str(start_month) + '/'
os.makedirs(figs_dir, exist_ok=True)


for lead_mon in range(0, 4):
    # lead-time indices of the three months in this window
    lead_mon_range = range(lead_mon, lead_mon + n_season_months)

    ##--- Calculate probability of 2m temperature ------------------------##
    # For each month of the window: percentage of ensemble members whose
    # temperature is greater than or equal to the climatology value for that
    # lead time and start month.
    t2m_prob_2d = np.zeros((len(lead_mon_range), nlat, nlon))
    for llead, lead in enumerate(lead_mon_range):
        print(lead, llead)

        # (member, lat, lon) minus (lat, lon): broadcast over the members
        diff = t2m_np[lead] - clim_t2m_np[lead, mon]

        ### Count number of members ###
        count = np.sum(diff >= 0, axis=0)

        ### Calculate percentage ###
        t2m_prob_2d[llead, :, :] = count/n_members * 100

    # average of the three monthly probabilities
    mean_t2m_prob = np.mean(t2m_prob_2d, axis=0 )


    ## (disabled) Calculate average between 40N and 70N
    # lat40 = 40; lat70 = 70
    # lat_40_ind = np.where(lats < lat40)[0][0] #index = 50
    # lat_70_ind = np.where(lats <= lat70)[0][0] #index = 20
    # area_40_70 = compare_clim_2d[lat_70_ind : lat_40_ind, :]
    # ave_40_70 = round(np.mean(area_40_70))
    # print('Average between 40N and 70N: ', ave_40_70)
    # print('')

    ##--- Call function: smoothing --------------------------------------##
    smooth_mean_t2m_prob = box_smooth_2D(mean_t2m_prob, 1, 1, latitude=np.array(lats))


    ##--- Interpolate ---------------------------------------------------##
    ### Use RegularGridInterpolator ###
    # The interpolator is built on (lon, lat) axes, so the (lat, lon) field is
    # transposed on the way in and transposed back on the way out.
    data_grid = (lons, lats)
    r = RGI(data_grid, smooth_mean_t2m_prob.T, method='linear', bounds_error=False)
    inter_mean_t2m_prob = r((lon_new, lat_new))
    smooth_mean_t2m_prob_interp = inter_mean_t2m_prob.T


    ##--- Plot probability of 2m temperature -----------------------------##
    fig = plt.figure(figsize=(11,9))
    ax = fig.add_subplot(111)

    m = Basemap(projection='cyl', llcrnrlon=-20, urcrnrlon=50, llcrnrlat=35, urcrnrlat=72, resolution = 'l', fix_aspect = False) # <---Europe
    #m = Basemap() # <---Global

    plt.subplots_adjust(left=0.03, right=0.97, bottom=0.07, top=0.9) #ver1.0.1

    #m.drawmeridians(np.arange(-180, 180, 10)) # Draw longitude
    #m.drawparallels(np.arange(-90, 90, 10)) # Draw latitude
    m.drawcoastlines()
    m.drawcountries(color = 'grey')

    ### Create grid data ###
    x, y = m(lon_grid, lat_grid) # <--- Smoothing & Interpolation grid

    ### Draw map ###
    #cf = m.contourf(x, y, mean_t2m_prob, levels = clevs, cmap = cmap) # <--- Original grid points
    cf = m.contourf(x, y, smooth_mean_t2m_prob_interp, levels = clevs, cmap = cmap) # <--- Smoothing & Interpolation


    ### Add colorbar ###
    cbar = m.colorbar(cf, location='bottom', pad=0.1, size='3%')
    cbar.set_label('Probability (%)', fontsize = 15)
    cbar.ax.xaxis.set_major_locator(ticker.FixedLocator([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]))
    cbar.ax.tick_params(labelsize=15)

    ### Add title ###
    # Month names of the three months in the window, taken from the forecast
    # time axis. Indexed via lead_mon explicitly rather than via the variable
    # left over from the inner loop.
    month1, month2, month3 = (
        datetime.datetime.strptime(date_t2m[k], '%Y-%m').strftime('%B')
        for k in lead_mon_range
    )

    fig_title = 'Estimated probability that ' + month1 + '/' + month2 + '/' + month3 + ' will be warmer than normal'
    ax_title = '(forecast start: ' + forecast_start_date +')' #ver1.0.1

    fig.suptitle(fig_title, fontsize = 18) #ver1.0.1
    ax.set_title(ax_title, loc="right", fontsize = 16) #ver1.0.1


    ## (disabled) Add text
    # fig_text = 'Issued: ' + issue_date
    # ax_pos = ax.get_position()
    # fig.text(ax_pos.x1 - 0.08, ax_pos.y1 - 0.91, fig_text, fontsize = 10)


    ### Save figures ###
    # first and last lead time of the window, as labelled in the climatology file
    lead_start = str(leads[lead_mon])
    lead_end = str(leads[lead_mon + n_season_months - 1])


    ### File name ###
    # <bccr system number>_<variable name>_<aggregation>_<lead time>
    #fig_name1 = figs_dir + 'bccr_' + system_num + '_LM' + lead_start + '-' + lead_end + '_2m_temperature_exceedq50_europe.png' # <--- temp folder
    fig_name1 = figs_dir + 'bccr_' + system_num + '_2m_temperature_exceedq50_europe_s' + issue_date + '_3m_lead' + lead_start + '.png' # #ver2.0.0--- archiving figure
    fig.savefig(fig_name1, dpi=300, format='png')



    plt.show()
    # release the figure so the four figures do not accumulate in memory
    plt.close(fig)


##--- Close netCDF file ---------------------------------------##
# Release both input datasets: the forecast first, then the climatology.
for open_dataset in (nc, nc_clim):
    open_dataset.close()


print('Completed!!')
