# Calculate probability of precipitation
# Compare forecast with climatology
# Plot probability of precipitation
# Period for climatology: Copernicus (1993-2016)
#
#
# Script version 1.2
#
# Script works with:
#   Python version: 3.10
#
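#   Package version
#     numpy:  1.26.4
#     pandas:  2.2.2
#     netCDF4:  1.6.5
#     scipy:  1.13.1
#   Also requires (versions not recorded): matplotlib, basemap (mpl_toolkits.basemap), python-dateutil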
#
# Ver1.0.0: Created by Mariko Koseki, 11.09.2024
# Ver1.0.1: Updated by Mariko, 13.09.2024
# ver1.1: updated by Mariko, 08.12.2025
# ver1.2: updated by Mariko, 02.03.2026


import netCDF4
import numpy as np
import pandas as pd
import datetime
from dateutil.relativedelta import relativedelta
import calendar
import os
import sys
import platform
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import scipy
from scipy.interpolate import RegularGridInterpolator as RGI


# Report the interpreter and package versions used for this run (useful for reproducibility).
print('--- Python version ---')
print(platform.python_version())

print('--- Package version ---')
for pkg_label, pkg_version in (
    ('numpy: ', np.__version__),
    ('pandas: ', pd.__version__),
    ('netCDF4: ', netCDF4.__version__),
    ('scipy: ', scipy.__version__),
):
    print(pkg_label, pkg_version)




##--- Function -----------------------------------------##
def box_smooth_2D(field,n_ih,n_jh,lat_wgt=True,latitude=None):
    """
    Smooth a 2-D field w/ dimensions (i, j) with a simple box filter (the 2-D equivalent of a running mean).

    The j-coordinate is treated as cyclic (e.g. longitude for global data): the field is
    padded with wrap-around columns, so every column, including the first and the last,
    is smoothed. Rows closer than n_ih to either i-edge have no complete window and keep
    their un-smoothed values.
    Note: n_ih = 1, n_jh = 1 averages each central point with the 8 points around it.

    Adapted from:
    https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/tree/main/cf_monthly_forecast
    https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/blob/main/notebooks/002_forecast_plots.py

    INPUT:
        field       : (numpy.array) with 2 dimensions (i, j)
        n_ih        : (int) number of grid points to take into account in positive & negative i-direction of the central point
        n_jh        : (int) number of grid points to take into account in positive & negative j-direction of the central point
        lat_wgt     : (boolean) whether to weight the rows (i-direction) by cos(latitude); if True, latitude must be given
        latitude    : (numpy.array) of length field.shape[0] (length of i-dimension), in degrees, used for weighting
    OUTPUT:
        (numpy.array) smoothed field with the same shape as `field` (floating point)
    RAISES:
        ValueError  : if lat_wgt is True and latitude is None
    """
    if lat_wgt and latitude is None:
        raise ValueError('need to specify latitude vector of length N_i to apply latitude weighting')

    # Smoothed values are weighted means; work in floating point so integer input is not truncated.
    if not np.issubdtype(field.dtype, np.floating):
        field = field.astype(float)

    N_i, N_j = field.shape

    # Pad the field in the j direction (longitude) with wrap-around columns, so that the
    # window of every original column is complete.
    field_padded = np.concatenate([field[:, N_j-n_jh:], field, field[:, :n_jh]], axis=1)

    box_kernel = np.ones([n_ih*2+1, n_jh*2+1])  # un-normalized kernel
    if lat_wgt:
        # cos(latitude) for every row, computed once instead of per grid point.
        cos_lat = np.cos(np.asarray(latitude, dtype=float)/180*np.pi)

    # Start from a copy of the original; rows without a full i-window keep their values.
    smooth_field = field.copy()
    for i in range(n_ih, N_i - n_ih):
        if lat_wgt:
            # Weight each ROW of the window by its cos(lat). The [:, None] is required:
            # a bare 1-D vector would broadcast along the last axis, i.e. weight the columns.
            wgts = cos_lat[i-n_ih:i+n_ih+1, None] * box_kernel
        else:
            wgts = box_kernel
        wgt_sum = wgts.sum()
        for j in range(N_j):
            # Original column j is centred at padded column j + n_jh, so its window
            # spans padded columns j .. j + 2*n_jh.
            window = field_padded[i-n_ih:i+n_ih+1, j:j+2*n_jh+1]
            smooth_field[i, j] = (wgts*window).sum()/wgt_sum
    return smooth_field



##--- Set years for climatology ------------------------------------------##
# First and last year (inclusive) of the climatology period. Copernicus: 1993-2016.
start_year_clm = 1993
end_year_clm = 2016



##--- Set paths ------------------------------------------##
## Paths to input files ## # ver1.1
# norcpm_monthly: directory with the monthly forecast NetCDF files (one sub-directory per start month, yyyymm)
# norcpm_clim   : directory with the precipitation median climatology NetCDF file
norcpm_monthly = '/nird/datapeak/NS9873K/www/norcpm/forecast/monthly/' # ver1.1
norcpm_clim = '/nird/datapeak/NS9873K/norcpm/validation/precipitation/median/clim/' # ver1.1


## Path to output files ## # ver1.1
# out_fig_dir: directory in which the figures are archived (created if it does not exist)
#out_fig_dir  = '/nird/datapeak/NS9873K/norcpm/validation/precipitation/median/fig/smooth/clm_' + str(start_year_clm) + '_' + str(end_year_clm) + '/' # <--- tmp folder
out_fig_dir  = '/nird/datapeak/NS9873K/www/norcpm/forecast/plots/version2/' # ver1.1 <--- folder for archiving figures
# exist_ok avoids the check-then-create race when several runs start at the same time.
os.makedirs(out_fig_dir, exist_ok=True)


## Select NorCPM system ##
# system_num: NorCPM system number used in the forecast file name ('1' or '2')
system_num = '1'
#system_num = '2'

print('')
print('system: ', system_num)
print('')

##--- Input ----------------------------------------------##
# Command line: python <script> <yyyymm> <area>
#   <yyyymm> : forecast start year and month, e.g. 202312
#   <area>   : map domain, 'europe' or 'global'
args = sys.argv


def _usage_and_exit():
    """Print how to call this script and exit with a non-zero status."""
    print('')
    print('---How to use this script---')
    print('python ', args[0], '<yyyymm> <area>')
    print('<yyyymm> = Forecast start date, <area> = europe or global')
    print('Example: python ', args[0], '202312 europe')
    print('')
    sys.exit(1)


if len(args) != 3: #ver1.2
    _usage_and_exit()

input_date = args[1]
area = args[2]

# <yyyymm> must be exactly six digits with a month between 01 and 12.
if not (len(input_date) == 6 and input_date.isdecimal() and 1 <= int(input_date[4:6]) <= 12):
    _usage_and_exit()

# Only these two domains have a map definition; any other value cannot be plotted.
if area not in ('europe', 'global'):
    print('')
    print('Unknown area: ', area)
    _usage_and_exit()

yyyy = int(input_date[0:4])
mm = int(input_date[4:6])
dd = 1 # forecasts start on the first day of the month


start_year = '{:0>4d}'.format(yyyy)
start_month = '{:0>2d}'.format(mm)
start_date = '{:0>2d}'.format(dd)

issue_date = start_year + start_month + start_date
print('')
print('Issued: ', issue_date)
print('')

forecast_start_date = '{:0>2d}/{:0>2d}/{:0>4d}'.format(dd,mm,yyyy)
print('Forecast start: ', forecast_start_date)
print('')

mon = mm - 1 # 0-based start-month index into the climatology 'month' dimension

# Number of days in the start month (used to convert m/s -> mm/month).
ndate = calendar.monthrange(yyyy, mm)[1]
# Leap years only change February, so only February is forced to 28 days, i.e. 29 February is
# ignored (presumably to match a no-leap calendar of the data -- TODO confirm). Other months of a
# leap year keep their real length.
if mm == 2:
    ndate = 28



##--- Read netCDF files ----------------------------------##
## Forecast file: <norcpm_monthly><yyyymm>/total_precipitation_bccr_<system>_<yyyy>_<mm>.nc ##
nc_forecast = (f'{norcpm_monthly}{start_year}{start_month}'
               f'/total_precipitation_bccr_{system_num}_{start_year}_{start_month}.nc')
print('forecast file name: ', nc_forecast)
nc = netCDF4.Dataset(nc_forecast, 'r')


## Median climatology file for the climatology period ##
# NOTE(review): the system number is fixed to 1 here, whereas the forecast file name uses
# system_num -- confirm the system-1 climatology is intended for every system.
nc_clim_prec = f'{norcpm_clim}bccr_1_total_precipitation_clim_median_{start_year_clm}_{end_year_clm}.nc'
print('climatology file name: ', nc_clim_prec)
nc_clim = netCDF4.Dataset(nc_clim_prec, 'r')


## Dimension sizes (typical: 360 longitudes, 180 latitudes, 6 lead times, 60 members, 12 months) ##
nlon, nlat, nlead, nmem = (len(nc.dimensions[dim]) for dim in ('longitude', 'latitude', 'time', 'number'))
nmonth = len(nc_clim.dimensions['month'])


## Coordinate variables ##
months = nc_clim.variables['month'][:] # calendar months of the climatology: 1-12

time_var = nc.variables['time']
time = time_var[:] # "hours since 1900-01-01 00:00:00.0"
timeunits = time_var.getncattr('units')
timecalendar = time_var.getncattr('calendar')
# 'YYYY-MM' label of every forecast time step
date_prec = [f'{d.year}-{d.month:02}' for d in netCDF4.num2date(time, timeunits, timecalendar)]

lons = nc.variables['longitude'][:] # Longitude: -179.5 - 179.5
lats = nc.variables['latitude'][:] # Latitude: 89.5 - -89.5


##--- Loop over all lead times --------------------------------##
# For every lead time: percentage of ensemble members at or above the climatological median
# -> box smoothing -> interpolation onto a finer grid -> map saved as PNG.

## Settings that do not depend on the lead time ##
nsplines = 4 # refinement factor of the interpolation grid
lons4 = np.linspace(lons[0], lons[-1], lons.shape[0]*nsplines)
lats4 = np.linspace(lats[0], lats[-1], lats.shape[0]*nsplines)
# Target points for the interpolator, in (lon, lat) order
lon_new, lat_new = np.meshgrid(lons4, lats4, indexing='ij', sparse=True)
# 2-D (lat, lon) grid of the interpolated field, used for plotting
lon_grid, lat_grid = np.meshgrid(lons4, lats4)

clevs = np.arange(0, 110, 10) # contour levels: 0, 10, ..., 100 %
cmap = plt.cm.colors.ListedColormap(['#363608', '#686933', '#9c9d61', '#cbcd9a', '#eaebd7', '#d3dee8', '#95afc9', '#5781a9', '#295488', '#27275e'])

figs_dir = os.path.join(out_fig_dir, start_year + start_month, '')
os.makedirs(figs_dir, exist_ok=True)

for lead in range(nlead):

    ### precipitation: (ensemble member, latitude, longitude) for this lead time ###
    tprate_np = np.array(nc.variables['tprate'][lead])

    ### median precipitation climatology (latitude, longitude) for this lead time and start month ###
    clim_prec_np = np.array(nc_clim.variables['clim_prec'][lead, mon])

    ### lead time as stored in the climatology file (used in the figure file name) ###
    lead_num = nc_clim.variables['lead'][lead]

    ## Change unit: m/s -> mm/month ##
    # NOTE(review): ndate is the number of days of the START month and is used for every lead;
    # confirm this matches the scaling used when the climatology file was produced.
    tprate2 = tprate_np * 1000 * 3600 * 24 * ndate


    ##--- Calculate probability of precipitation ------------------------##
    # Percentage of members whose precipitation is >= the climatological median at each grid point.
    # The ensemble size is taken from the data (not a fixed 60) so other ensemble sizes are correct.
    n_members = tprate2.shape[0]
    count = np.count_nonzero(tprate2 >= clim_prec_np, axis=0)
    compare_clim_2d = count / n_members * 100


    ##--- Smoothing: 3x3 box, latitude-weighted ---------------------------##
    smooth_compare_clim_2d = box_smooth_2D(compare_clim_2d, 1, 1, latitude=np.array(lats))


    ##--- Interpolate with RegularGridInterpolator -----------------------##
    # (nlat, nlon) -> (nlat*nsplines, nlon*nsplines). The interpolator works in (lon, lat)
    # order, hence the transposes; points outside the grid get the default fill value.
    r = RGI((lons, lats), smooth_compare_clim_2d.T, method='linear', bounds_error=False)
    smooth_compare_clim_2d_interp = r((lon_new, lat_new)).T


    ##--- Plot probability of precipitation -----------------------------##
    fig = plt.figure(figsize=(11,9))
    ax = fig.add_subplot(111)

    if area == 'europe': #ver1.2
        m = Basemap(projection='cyl', llcrnrlon=-20, urcrnrlon=50, llcrnrlat=35, urcrnrlat=72, resolution = 'l', fix_aspect = False) # <---Europe
    elif area == 'global':
        m = Basemap() # <---Global
    else:
        # Without this branch `m` would be undefined for any other area name.
        raise ValueError("area must be 'europe' or 'global', got " + repr(area))

    plt.subplots_adjust(left=0.03, right=0.97, bottom=0.07, top=0.9)

    #m.drawmeridians(np.arange(-180, 180, 10)) # Draw longitude
    #m.drawparallels(np.arange(-90, 90, 10)) # Draw latitude
    m.drawcoastlines()
    m.drawcountries(color = 'grey')

    ### Map coordinates of the interpolated grid ###
    x, y = m(lon_grid, lat_grid)

    ### Draw map ###
    cf = m.contourf(x, y, smooth_compare_clim_2d_interp, levels = clevs, cmap = cmap)

    ### Add colorbar ###
    cbar = m.colorbar(cf, location='bottom', pad=0.1, size='3%')
    cbar.set_label('Probability (%)', fontsize = 15)
    cbar.ax.xaxis.set_major_locator(ticker.FixedLocator([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]))
    cbar.ax.tick_params(labelsize=15)

    ### Add title: target month = start month + lead ###
    forecast_date = datetime.datetime(year = int(start_year), month = int(start_month), day=1) + relativedelta(months = lead)
    forecast_month = forecast_date.strftime('%B')
    forecast_year = forecast_date.strftime('%Y')

    fig_title = 'Estimated probability that ' + forecast_month + ' ' + forecast_year + ' will be wetter than normal'
    ax_title = '(forecast start: ' + forecast_start_date +')'

    fig.suptitle(fig_title, fontsize = 18)
    ax.set_title(ax_title, loc="right", fontsize = 16)


    ##--- Save figure --------------------------------------------------------##
    # File name: bccr_<system>_precipitation_exceedq50_<area>_s<yyyymmdd>_1m_lead<lead>.png
    fig_name1 = figs_dir + 'bccr_' + system_num + '_precipitation_exceedq50_' + area + '_s' + issue_date + '_1m_lead' + str(lead_num) + '.png' #<--- archiving figure #ver1.2
    fig.savefig(fig_name1, dpi=300, format='png')

    plt.show()
    # Release the figure; otherwise every lead's figure stays open for the whole run.
    plt.close(fig)


##--- Close the forecast and climatology NetCDF files --------------------##
for dataset in (nc, nc_clim):
    dataset.close()

# Signal successful completion of the whole run.
print('')
print('Completed!!')
