# Calculate anomaly of 2m temperature (NorCPM)
# Compare forecast with climatology
# Plot anomaly of 2m temperature
# Period for climatology -> Copernicus: 1993-2016
#
#
# Script version 1.1
#
# Script works with:
#   Python version: 3.10
#
#   Package version
#     numpy:  
#     pandas:  
#     netCDF4:  
#     scipy:  
#
# Ver1.0.0: Created by Mariko Koseki, 19.11.2024
# ver1.1: updated by Mariko, 20.10.2025



import netCDF4
import numpy as np
import pandas as pd
import datetime
from dateutil.relativedelta import relativedelta
import os
import sys
import platform
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import scipy
from scipy.interpolate import RegularGridInterpolator as RGI


# Report the interpreter and package versions this run uses, so that a
# figure can later be traced back to the software stack that produced it.
print('--- Python version ---')
print(platform.python_version())

print('--- Package version ---')
for label, module in (
    ('numpy: ', np),
    ('pandas: ', pd),
    ('netCDF4: ', netCDF4),
    ('scipy: ', scipy),
):
    print(label, module.__version__)




##--- Function -----------------------------------------##
def box_smooth_2D(field,n_ih,n_jh,lat_wgt=True,latitude=None):
    """
    Smooth a 2d field w/ dimensions (i,j) by applying a simple box filter
    (2d-equivalent to a running-mean).

    The j-coordinate is treated as cyclic (e.g., longitude for global data),
    so every column is smoothed using wrap-around neighbours. The
    i-coordinate is not cyclic: the first and last n_ih rows do not have a
    complete window and keep their original, un-smoothed values.

    note: n_ih = 1, n_jh = 1 averages each central point with the 8 points
    around it.

    Adapted from:
    https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/tree/main/cf_monthly_forecast
    https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/blob/main/notebooks/002_forecast_plots.py
    Changes relative to that version: the latitude weights are applied along
    the i-axis (they used to broadcast along j), and the rows/columns that
    have a complete window but were skipped by off-by-one bounds are now
    smoothed as well.

    INPUT:
        field       : (numpy.array) with 2 dimensions (i,j)
        n_ih        : (int) number of grid points to take into account in positive & negative i-direction of the central point
        n_jh        : (int) number of grid points to take into account in positive & negative j-direction of the central point
        lat_wgt     : (boolean) whether to apply latitude weights in i-direction or not; if True, need to define latitude
        latitude    : (numpy.array) of length field.shape[0] (length of i-dimension) with latitudes in degrees used for weighting
    OUTPUT:
        (numpy.array) copy of field, same shape, with the smoothed values
        written in; the input array is not modified
    RAISES:
        AssertionError : lat_wgt is True but latitude is None
        ValueError     : negative half-widths, n_jh larger than the number
                         of columns, or latitude not of length field.shape[0]
    """

    if lat_wgt:
        assert latitude is not None, 'need to specify latitude vector of length N_i to apply latitude weighting'

    N_i,N_j = field.shape

    if n_ih < 0 or n_jh < 0 or n_jh > N_j:
        raise ValueError('n_ih and n_jh must be >= 0 and n_jh must not exceed the number of columns')

    if lat_wgt:
        cos_lat = np.cos(np.asarray(latitude, dtype=float)/180*np.pi)
        if cos_lat.ndim != 1 or cos_lat.shape[0] != N_i:
            raise ValueError('latitude must be a 1d vector of length field.shape[0]')

    # pad the field in j direction (longitude) so that the window can wrap around:
    field_padded = np.concatenate([field[:,N_j-n_jh:],field,field[:,:n_jh]],axis=1)

    box_kernel = np.ones([n_ih*2+1,n_jh*2+1]) # un-normalized kernel
    # use a copy of the original and overwrite; rows closer than n_ih to the
    # top/bottom edge are never visited and therefore keep un-smoothed values
    smooth_field = field.copy()
    for i in range(n_ih, N_i-n_ih):
        rows = slice(i-n_ih, i+n_ih+1)
        if lat_wgt:
            # column vector: one weight per window row, constant along j
            wgts = cos_lat[rows, np.newaxis] * box_kernel
        else:
            wgts = box_kernel
        wgt_sum = wgts.sum() # depends on the row only, so computed once per row
        for j in range(N_j):
            # column j of the original field sits at j+n_jh in the padded
            # field, so its window is padded columns j .. j+2*n_jh
            smooth_field[i,j] = (wgts*field_padded[rows,j:j+2*n_jh+1]).sum()/wgt_sum
    return smooth_field



##--- Set year, month for climatology ------------------------------------##
# Set 'start_year_clm' and 'end_year_clm': first and last year of the
# climatology reference period. They are used to build the climatology
# file name.
# Copernicus: 1993-2016
start_year_clm = 1993 # 1993
end_year_clm = 2016 # 2016



##--- Set path ------------------------------------------##
## Set path to input files ##
# Set 'norcpm_monthly' (directory holding the monthly forecast NetCDF files,
# one sub-directory per <yyyymm>) and 'norcpm_clim' (directory holding the
# 2m temperature climatology NetCDF file).
norcpm_monthly = '/nird/projects/NS9873K/www/norcpm/forecast/monthly/'
norcpm_clim = '/nird/projects/NS9873K/norcpm/validation/2m_temperature/median/clim/'
#norcpm_clim = '/nird/projects/NS9873K/norcpm/validation/2m_temperature/mean/clim/' #ver1.1


## Set path to output files ##
# Set 'out_fig_dir' (root directory for the output figures).
out_fig_dir  = '/nird/projects/NS9873K/www/norcpm/forecast/plots/version3/anomaly/' #ver1.1
# exist_ok=True replaces the former "check, then create" pair, which could
# fail with FileExistsError when two runs (e.g. europe and global) start at
# the same time and both find the directory missing.
os.makedirs(out_fig_dir, exist_ok=True)


## Select NorCPM system ##
# Set system number 'system_num' (kept as a string because it is inserted
# into file names).
system_num = '1'
#system_num = '2'

print('')
print('system: ', system_num)
print('')

##--- Input ----------------------------------------------##
def _exit_with_usage(script_name):
    """Print the command-line usage text and stop with a non-zero exit status.

    A non-zero status lets a calling shell script or scheduler detect that
    the script did not run (the usage branches previously exited with 0).
    """
    print('')
    print('---How to use this script---')
    print('python ', script_name, '<yyyymm> <area>')
    print('<yyyymm> = Forecast start date, <area> = europe or global')
    print('Example: python ', script_name, '202312 europe')
    print('')
    sys.exit(1)


# Expected call: python <script> <yyyymm> <area>
args = sys.argv
if len(args) != 3: #ver1.1
    _exit_with_usage(args[0])

input_date = args[1]
area = args[2]

# <yyyymm> must be exactly six ASCII digits; anything else would otherwise
# end in a bare int() traceback instead of the usage text.
if len(input_date) != 6 or not (input_date.isascii() and input_date.isdigit()):
    _exit_with_usage(args[0])

# Only these two map domains are handled by the plotting section; reject
# other values here rather than failing later when the map is created.
if area not in ('europe', 'global'):
    _exit_with_usage(args[0])

yyyy = int(input_date[0:4])
mm = int(input_date[4:6])
dd = 1 # forecasts always start on the first day of the month

# The month is used (as mm - 1) to index the 12-entry month axis of the
# climatology, and the year is passed to datetime, which requires year >= 1.
if yyyy < 1 or not 1 <= mm <= 12:
    _exit_with_usage(args[0])


# Zero-padded string forms used in file names and titles.
start_year = f'{yyyy:0>4d}'
start_month = f'{mm:0>2d}'
start_date = f'{dd:0>2d}'

issue_date = start_year + start_month + start_date
print('')
print('Issued: ', issue_date)
print('')

forecast_start_date = f'{dd:0>2d}/{mm:0>2d}/{yyyy:0>4d}'
print('Forecast start: ', forecast_start_date)
print('')

# 0-based index of the start month on the climatology 'month' axis.
mon = mm - 1




##--- Read netCDF files ----------------------------------##
## Open NetCDF files ##
### Forecast ###
# One file per forecast start month, stored in a <yyyymm> sub-directory.
nc_forecast = (
    f'{norcpm_monthly}{start_year}{start_month}/2m_temperature_bccr_'
    f'{system_num}_{start_year}_{start_month}.nc'
)
print('forecast file name: ', nc_forecast)
nc = netCDF4.Dataset(nc_forecast, 'r')


### Climatology ###
# NOTE(review): the file name uses 'bccr_1' regardless of system_num -
# confirm this is intended when system_num is changed.
#nc_clim_2m_temp = norcpm_clim + 'bccr_' + system_num + '_2m_temperature_clim_median_1993_2020.nc'
nc_clim_2m_temp = (
    f'{norcpm_clim}bccr_1_2m_temperature_clim_median_'
    f'{start_year_clm}_{end_year_clm}.nc'
)
#nc_clim_2m_temp = norcpm_clim + 'bccr_1_2m_temperature_clim_mean_' + str(start_year_clm) + '_' + str(end_year_clm) + '.nc' #ver1.1
print('climatology file name: ', nc_clim_2m_temp)
nc_clim = netCDF4.Dataset(nc_clim_2m_temp, 'r')


## Check length of variables ##
# Forecast file: longitude (360), latitude (180), lead time (6),
# ensemble member (60) - sizes as noted by the original author.
nlon, nlat, nlead, nmem = (
    len(nc.dimensions[dim_name])
    for dim_name in ('longitude', 'latitude', 'time', 'number')
)
#nlead = len(nc_clim.dimensions['lead'])
# Climatology file: month (12)
nmonth = len(nc_clim.dimensions['month'])


## Read variables ##
### lead time ###
#leads = nc_clim.variables['lead'][:] # lead time: 1-6

### month ###
months = nc_clim.variables['month'][:] # month: 1-12

### time ###
# Decode the numeric time axis with the units/calendar stored in the file
# and keep one 'YYYY-MM' label per lead time.
time_var = nc.variables['time']
time = time_var[:] # "hours since 1900-01-01 00:00:00.0"
timeunits = time_var.getncattr('units')
timecalendar = time_var.getncattr('calendar')
date_t2m = [
    f'{stamp.year}-{stamp.month:02}'
    for stamp in netCDF4.num2date(time, timeunits, timecalendar)
]

### latitude, longitude ###
lons = nc.variables['longitude'][:] # Longitude: -179.5 - 179.5
lats = nc.variables['latitude'][:] # Latitude: 89.5 - -89.5


##--- Loop over all lead time --------------------------------##
# For every lead time: ensemble-mean 2m temperature minus climatology,
# drawn as a filled-contour map and saved as a PNG.
# Everything that does not depend on the lead time is prepared once here.

# Fail early with a clear message; otherwise the map object 'm' would be
# undefined inside the loop.
if area not in ('europe', 'global'):
    raise ValueError("area must be 'europe' or 'global', got " + repr(area))

### Create grid data ###
lon_grid, lat_grid = np.meshgrid(lons, lats) # <--- Original grid points
#lon_grid, lat_grid = np.meshgrid(lons4, lats4) # <--- Smoothing & Interpolation

### Set contour levels ###
#clevs = np.arange(-4.0,4.8,0.8)
clevs = np.arange(-2.0,2.4,0.4)

### Set colour map ###
# Ten colours for the ten intervals between the contour levels, plus
# separate colours for values beyond either end (extend='both').
cmap = plt.cm.colors.ListedColormap(['#071e6e', '#054b89', '#2b7ea8', '#7db1cb', '#d0e1e9', '#edd2c4', '#d8a07e', '#c5703d', '#a43900', '#6f0500'])
cmap.set_over('#590007')
cmap.set_under('#001261')

### Colour-bar ticks ###
# Written out explicitly (rather than reusing clevs) so that the labels are
# free of floating-point noise from np.arange.
#cbar_ticks = [-4.0, -3.2, -2.4, -1.6, -0.8, 0, 0.8, 1.6, 2.4, 3.2, 4.0]
cbar_ticks = [-2.0, -1.6, -1.2, -0.8, -0.4, 0, 0.4, 0.8, 1.2, 1.6, 2.0]

### Output directory: <out_fig_dir>/<yyyymm>/ ###
figs_dir = os.path.join(out_fig_dir, str(start_year) + str(start_month), '')
os.makedirs(figs_dir, exist_ok=True)

### First day of the forecast start month (lead 0) ###
forecast_start = datetime.datetime(year = int(start_year), month = int(start_month), day=1)


for lead in range(0, nlead):

    ### 2 meter temperature ###
    # One lead time: (ensemble member, latitude, longitude)
    t2m = nc.variables['t2m'][lead]

    ### 2m temperature climatology ###
    # Stored as (lead time, month, latitude, longitude); 'mon' is the
    # 0-based forecast start month.
    clim_t2m = nc_clim.variables['clim_2m_temp'][lead][mon]

    ### lead time ###
    lead_num = nc_clim.variables['lead'][lead] # lead time: 1-6


    ## Convert variables into Numpy ndarray ##
    # NOTE(review): np.array() drops the mask of a masked array; if the
    # files can contain missing values they would enter the mean - confirm.
    t2m_np = np.array(t2m)
    clim_t2m_np = np.array(clim_t2m)


    ##--- Calculate anomaly of 2m temperature ------------------------##
    # Ensemble mean over the member axis minus the climatology, for the
    # whole grid at once (replaces a Python loop over every grid point).
    # Accumulating in float64 keeps the result in double precision.
    compare_clim_2d = t2m_np.mean(axis=0, dtype=np.float64) - clim_t2m_np
    #print(compare_clim_2d)


    ##--- Optional: smoothing & interpolation (disabled) -------------##
    # Kept for reference. To plot a smoothed field on a 4x finer grid:
    #   smooth_compare_clim_2d = box_smooth_2D(compare_clim_2d, 1, 1, latitude=np.array(lats))
    #   nsplines = 4
    #   lons4 = np.linspace(lons[0],lons[-1],lons.shape[0]*nsplines)
    #   lats4 = np.linspace(lats[0],lats[-1],lats.shape[0]*nsplines)
    #   # interp2d is deprecated in SciPy 1.10 and will be removed in
    #   # SciPy 1.14.0 -> use RegularGridInterpolator: (180, 360)->(720, 1440)
    #   data_grid = tuple([lons, lats])
    #   r = RGI(data_grid, smooth_compare_clim_2d.T, method='linear', bounds_error=False)
    #   lon_new, lat_new = np.meshgrid(lons4, lats4, indexing='ij', sparse=True)
    #   smooth_compare_clim_2d_interp = r((lon_new, lat_new)).T
    # then build the grid from lons4/lats4 and pass
    # smooth_compare_clim_2d_interp to contourf below.


    ##--- Plot anomaly of 2m temperature -----------------------------##
    fig = plt.figure(figsize=(11,9))
    ax = fig.add_subplot(111)

    if area == 'europe': #ver1.1
        m = Basemap(projection='cyl', llcrnrlon=-20, urcrnrlon=50, llcrnrlat=35, urcrnrlat=72, resolution = 'l', fix_aspect = False) # <---Europe
    else:
        m = Basemap() # <---Global

    plt.subplots_adjust(left=0.03, right=0.97, bottom=0.07, top=0.9)

    #m.drawmeridians(np.arange(-180, 180, 10)) # Draw longitude
    #m.drawparallels(np.arange(-90, 90, 10)) # Draw latitude
    m.drawcoastlines()
    m.drawcountries(color = 'grey')

    ### Project grid onto the map ###
    x, y = m(lon_grid, lat_grid)

    ### Draw map ###
    cf = m.contourf(x, y, compare_clim_2d, levels = clevs, cmap = cmap, extend='both') # <--- Original grid points
    #cf = m.contourf(x, y, smooth_compare_clim_2d_interp, levels = clevs, cmap = cmap, extend='both') # <--- Smoothing & Interpolation

    ### Add colorbar ###
    cbar = m.colorbar(cf, location='bottom', pad=0.1, size='3%')
    cbar.set_label('Anomaly (\u2103)', fontsize = 15)
    cbar.ax.xaxis.set_major_locator(ticker.FixedLocator(cbar_ticks))
    cbar.ax.tick_params(labelsize=15)

    ### Add title ###
    # The title names the calendar month 'lead' months after the start month.
    forecast_date = forecast_start + relativedelta(months = lead)
    forecast_month = forecast_date.strftime('%B')
    forecast_year = forecast_date.strftime('%Y')

    fig_title = 'Temperature (2m) Ensemble Mean Anomaly ' + forecast_month + ' ' + forecast_year
    ax_title = '(forecast start: ' + forecast_start_date +')'

    fig.suptitle(fig_title, fontsize = 18)
    ax.set_title(ax_title, loc="right", fontsize = 16)


    ### File name ###
    # bccr_<system number>_2m_temperature_anomaly_<area>_s<issue date>_1m_lead<lead time>.png
    fig_name1 = figs_dir + 'bccr_' + system_num + '_2m_temperature_anomaly_' + area + '_s' + issue_date + '_1m_lead' + str(lead_num) + '.png' # <--- archiving figure #ver1.1

    ### Save figures ###
    fig.savefig(fig_name1, dpi=300, format='png')


    plt.show()

    # Release the figure; otherwise one figure per lead time stays alive
    # until the script ends.
    plt.close(fig)


##--- Close netCDF file ---------------------------------------##
# All lead times are done: release the forecast and climatology datasets.
for dataset in (nc, nc_clim):
    dataset.close()


# Blank line followed by the end-of-run marker.
print('')
print('Completed!!')
