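# Calculate and save the probability of 2m temperature being at or above climatology,
# i.e. the percentage of ensemble members whose forecast is >= the climatological median
# Save results as NetCDF files (1, 1/2 and 1/4 degree grids)
# Period for climatology: Copernicus (1993-2016)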
#
#
# Script version 1.3
#
# Script works with:
#   Python version: 3.10.13
#
#   Package version
#     numpy:  1.26.4
#     pandas:  2.2.0
#     netCDF4:  1.6.5
#     scipy:  1.12.0
#
#
# Ver1.0.0: Created by Mariko Koseki, 11.04.2025
# Ver1.1: updated by Mariko, 05.06.2025
# ver1.2: updated by Mariko, 20.10.2025
# ver1.3: updated by Mariko, 08.12.2025



import netCDF4
import numpy as np
from numpy import dtype
import pandas as pd
import datetime
from dateutil.relativedelta import relativedelta
import calendar
import os
import sys
import platform
import scipy
from scipy.interpolate import RegularGridInterpolator as RGI



# Report the Python and package versions actually used for this run,
# so they can be compared with the tested versions listed in the header.
print('--- Python version ---')
print(platform.python_version())

print('--- Package version ---')
for _pkg_label, _pkg_version in (
    ('numpy: ', np.__version__),
    ('pandas: ', pd.__version__),
    ('netCDF4: ', netCDF4.__version__),
    ('scipy: ', scipy.__version__),
):
    print(_pkg_label, _pkg_version)




##--- Function -----------------------------------------##
def box_smooth_2D(field,n_ih,n_jh,lat_wgt=True,latitude=None):
    """
    Smooth a 2d field w/ dimensions (i,j) by applying a simple box filter (2d-equivalent to a running-mean).

    Adapted from:
    https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/tree/main/cf_monthly_forecast
    https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/blob/main/notebooks/002_forecast_plots.py

    The j-coordinate is treated as cyclic (e.g., longitude for global data): every column is
    smoothed, using wrap-around neighbours at the first and last columns. The i-coordinate is
    not padded, so the first and last n_ih rows (where a full window does not fit) keep their
    un-smoothed values.
    note: n_ih = 1, n_jh = 1 averages each central point with the 8 points around it.

    INPUT:
        field       : (numpy.array) with 2 dimensions (i,j)
        n_ih        : (int) number of grid points to take into account in positive & negative i-direction of the central point
        n_jh        : (int) number of grid points to take into account in positive & negative j-direction of the central point
        lat_wgt     : (boolean) whether to weight each row of the window by cos(latitude); if True, need to define latitude
        latitude    : (numpy.array) of length field.shape[0] (length of i-dimension) with latitudes in degrees used for weighting
    OUTPUT:
        smooth_field: (numpy.array) copy of field (same shape and dtype) holding the smoothed values
    """
    if lat_wgt:
        assert latitude is not None, 'need to specify latitude vector of length N_i to apply latitude weighting'

    N_i, N_j = field.shape
    win_i = 2*n_ih + 1
    win_j = 2*n_jh + 1

    # use a copy of the original and overwrite only the rows that have a full window
    smooth_field = field.copy()
    n_rows = N_i - 2*n_ih  # number of rows whose full i-window lies inside the field
    if n_rows <= 0:
        return smooth_field

    # pad the field cyclically in j direction (longitude) so every column has a full window
    field_padded = np.pad(np.asarray(field), ((0, 0), (n_jh, n_jh)), mode='wrap')

    # row weights: cos(latitude) must vary along i (rows) only, the same for every column
    if lat_wgt:
        row_wgt = np.cos(np.deg2rad(np.asarray(latitude, dtype=float)))
    else:
        row_wgt = np.ones(N_i)
    weighted = row_wgt[:, np.newaxis] * field_padded

    # accumulate the weighted window sum by shifting over the (win_i x win_j) kernel offsets;
    # output row r / column c corresponds to field row r+n_ih / padded column c+n_jh
    numer = np.zeros((n_rows, N_j))
    denom = np.zeros(n_rows)
    for di in range(win_i):
        denom += row_wgt[di:di + n_rows]
        for dj in range(win_j):
            numer += weighted[di:di + n_rows, dj:dj + N_j]
    denom *= win_j  # every row weight appears once per column of the window

    smooth_field[n_ih:N_i - n_ih, :] = numer / denom[:, np.newaxis]
    return smooth_field






##--- Set year, month for climatology ------------------------------------##
# Climatology period: 'start_year_clm' to 'end_year_clm' (Copernicus: 1993-2016)
start_year_clm = 1993 # 1993
end_year_clm = 2016 # 2016



##--- Set path ------------------------------------------##
## Set path to input files ## # ver1.3
# norcpm_monthly: directory holding the monthly forecast NetCDF files
# norcpm_clim:    directory holding the 2m temperature climatology (median) NetCDF files
norcpm_monthly = '/nird/datapeak/NS9873K/www/norcpm/forecast/monthly/'
norcpm_clim = '/nird/datapeak/NS9873K/norcpm/validation/2m_temperature/median/clim/'


## Set path to output files ##
# out_dir: root directory for the output NetCDF files (created if it does not exist yet)
#out_dir = '/nird/datalake/NS9039K/users/mariko/shared/norcpm/data/2m_temperature/'
out_dir = '/nird/datalake/NS9039K/users/mariko/shared/norcpm/data2/2m_temperature/' #ver1.1
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(out_dir, exist_ok=True)



## Select NorCPM system ##
# Set the NorCPM system number 'system_num' ('1' or '2')
system_num = '1'
#system_num = '2'

print('')
print('system: ', system_num)
print('')

##--- Input ----------------------------------------------##
def _print_usage_and_exit():
    """Print how to call this script and exit with status 1 (invalid command line)."""
    print('')
    print('---How to use this script---')
    print('python save_probability_2mt.py <yyyymm>')
    print('<yyyymm> = Forecast start date')
    print('Example: python save_probability_2mt.py 202312')
    print('')
    # non-zero status so that callers (shell scripts, schedulers) can detect the failure
    sys.exit(1)


args = sys.argv
# Exactly one argument is expected: the forecast start date as six digits, yyyymm.
if len(args) != 2:
    _print_usage_and_exit()
input_date = args[1]
if len(input_date) != 6 or not input_date.isdecimal():
    _print_usage_and_exit()
yyyy = int(input_date[0:4])
mm = int(input_date[4:6])
if not 1 <= mm <= 12:
    _print_usage_and_exit()
dd = 1  # the start day is always the 1st of the month


start_year = '{:0>4d}'.format(yyyy)
start_month = '{:0>2d}'.format(mm)
start_date = '{:0>2d}'.format(dd)

issue_date = start_year + start_month + start_date
print('')
print('Issued: ', issue_date)
print('')

forecast_start_date = '{:0>2d}/{:0>2d}/{:0>4d}'.format(dd,mm,yyyy)
print('Forecast start: ', forecast_start_date)
print('')

# 0-based start-month index into the climatology 'month' axis
# (assumes that axis is ordered January..December — confirm against the climatology file)
mon = mm - 1




##--- Read netCDF files ----------------------------------##
## Open the forecast and climatology NetCDF files (read-only) ##
### Forecast: <norcpm_monthly><yyyymm>/2m_temperature_bccr_<system_num>_<yyyy>_<mm>.nc ###
nc_forecast = (f'{norcpm_monthly}{start_year}{start_month}'
               f'/2m_temperature_bccr_{system_num}_{start_year}_{start_month}.nc')
print('forecast file name: ', nc_forecast)
nc = netCDF4.Dataset(nc_forecast, 'r')


### Climatology (start_year_clm-end_year_clm) ###
# NOTE(review): the climatology file name is fixed to 'bccr_1' whatever system_num is
# (the commented-out line below used system_num) — confirm this is intended.
#nc_clim_2m_temp = norcpm_clim + 'bccr_' + system_num + '_2m_temperature_clim_median_1993_2020.nc'
nc_clim_2m_temp = f'{norcpm_clim}bccr_1_2m_temperature_clim_median_{start_year_clm}_{end_year_clm}.nc'
print('climatology file name: ', nc_clim_2m_temp)
nc_clim = netCDF4.Dataset(nc_clim_2m_temp, 'r')


## Dimension sizes (typical values in the comments) ##
forecast_dims = nc.dimensions
nlon = len(forecast_dims['longitude'])      # Longitude: 360
nlat = len(forecast_dims['latitude'])       # Latitude: 180
nlead = len(forecast_dims['time'])          # Lead time: 6
nmem = len(forecast_dims['number'])         # Ensemble member: 60
nmonth = len(nc_clim.dimensions['month'])   # month: 12


## Coordinate variables ##
months = nc_clim.variables['month'][:]      # month: 1-12
lons = nc.variables['longitude'][:]         # Longitude: -179.5 - 179.5
lats = nc.variables['latitude'][:]          # Latitude: 89.5 - -89.5


##--- Settings shared by all lead times ----------------------------------##
### time units / calendar of the forecast file ### #ver1.1
timeunits = nc.variables['time'].getncattr('units')
timecalendar = nc.variables['time'].getncattr('calendar')

### Interpolation target grids ###
# 1 degree -> 1/4 degree and 1 degree -> 1/2 degree: nsplines times as many points,
# spanning the same first..last coordinates as the input grid.
# NOTE(review): np.linspace(first, last, n*k) spaces the points by (last-first)/(n*k-1),
# so the spacing is only approximately 1/k degree and the grid does not extend beyond
# lons[0]..lons[-1] / lats[0]..lats[-1] — confirm this is intended.
nsplines4 = 4
lons4 = np.linspace(lons[0], lons[-1], lons.shape[0]*nsplines4)
lats4 = np.linspace(lats[0], lats[-1], lats.shape[0]*nsplines4)
nlon4 = len(lons4); nlat4 = len(lats4)

nsplines2 = 2
lons2 = np.linspace(lons[0], lons[-1], lons.shape[0]*nsplines2)
lats2 = np.linspace(lats[0], lats[-1], lats.shape[0]*nsplines2)
nlon2 = len(lons2); nlat2 = len(lats2)

# target points ordered (longitude, latitude) to match the interpolator's grid order
lon_new4, lat_new4 = np.meshgrid(lons4, lats4, indexing='ij', sparse=True)
lon_new2, lat_new2 = np.meshgrid(lons2, lats2, indexing='ij', sparse=True)

### Output directories (one per resolution) ###
data_dir4 = out_dir + '/quarter/' #ver1.1
data_dir2 = out_dir + '/half/'
data_dir = out_dir + '/one/'
for _out_sub_dir in (data_dir4, data_dir2, data_dir):
    os.makedirs(_out_sub_dir, exist_ok=True)


def _write_prob_netcdf(outfile, lat_values, lon_values, prob_2d, time_value, time_units, time_calendar):
    """Write one 2m temperature probability field to a new NETCDF4 file.

    The file gets an unlimited 'time' dimension/variable set to time_value, float32
    'latitude'/'longitude' coordinate variables, and the float32 'temp_2m_prob'
    variable with dimensions (latitude, longitude) and units 'percentage'.
    The context manager closes the file even if writing fails.
    """
    with netCDF4.Dataset(outfile, 'w', format="NETCDF4") as ds:
        ## Define dimensions ##
        ds.createDimension('latitude', len(lat_values))
        ds.createDimension('longitude', len(lon_values))
        ds.createDimension('time', None) #ver1.1

        ## Define variables ##
        ### time ### #ver1.1
        time_var = ds.createVariable('time', dtype('double').char, ('time',))
        time_var.long_name = 'time'
        time_var.units = time_units
        time_var.calendar = time_calendar
        time_var[:] = time_value

        ### latitude ###
        lat_var = ds.createVariable('latitude', dtype('float32').char, ('latitude',))
        lat_var.units = 'degrees_north'
        lat_var.long_name = 'latitude'
        lat_var[:] = lat_values

        ### longitude ###
        lon_var = ds.createVariable('longitude', dtype('float32').char, ('longitude',))
        lon_var.units = 'degrees_east'
        lon_var.long_name = 'longitude'
        lon_var[:] = lon_values

        ### 2m temperature probability ###
        prob_var = ds.createVariable('temp_2m_prob', dtype('float32').char, ('latitude', 'longitude',))
        prob_var.units = 'percentage'
        prob_var.long_name = '2m temperature probability'
        prob_var[:,:] = prob_2d


##--- Loop over all lead time --------------------------------##
for lead in range(0, nlead):

    ### 2 meter temperature of all ensemble members ###
    t2m = nc.variables['t2m'][lead] # t2m dims: lead time, ensemble member, latitude, longitude

    ### 2m temperature climatology (start_year_clm-end_year_clm) for the start month ###
    clim_t2m = nc_clim.variables['clim_2m_temp'][lead, mon] # dims: lead time, month, latitude, longitude

    ### lead time ###
    lead_num = nc_clim.variables['lead'][lead] # lead time: 1-6

    ### time ### #ver1.1
    tim = nc.variables['time'][lead] # "hours since 1900-01-01 00:00:00.0"

    ## Convert variables into Numpy ndarray ##
    t2m_np = np.array(t2m)
    clim_t2m_np = np.array(clim_t2m)


    ##--- Calculate probability of 2m temperature ------------------------##
    # Percentage of ensemble members whose 2m temperature is at or above the
    # climatology, at every grid point. The climatology is broadcast over the
    # member axis; the member count comes from the data instead of a fixed 60.
    n_members = t2m_np.shape[0]
    count = np.count_nonzero(t2m_np - clim_t2m_np >= 0, axis=0)
    compare_clim_2d = count / n_members * 100


    ##--- Call function: smoothing --------------------------------------##
    smooth_compare_clim_2d = box_smooth_2D(compare_clim_2d, 1, 1, latitude=np.array(lats))


    ##--- Interpolate (linear) to the 1/4 and 1/2 degree grids ----------##
    # The interpolator is built on (longitude, latitude), hence the transposes.
    interp = RGI((lons, lats), smooth_compare_clim_2d.T, method='linear', bounds_error=False)
    smooth_compare_clim_2d_interp4 = interp((lon_new4, lat_new4)).T  # (nlat4, nlon4)
    smooth_compare_clim_2d_interp2 = interp((lon_new2, lat_new2)).T  # (nlat2, nlon2)


    ##--- Save data --------------------------------------------------------##
    # NOTE(review): the output file names are fixed to 'bccr_1' even when
    # system_num is '2' — confirm this is intended.
    ### quarter (1/4) degree ###
    outfile4 = data_dir4 + 'bccr_1_2m_temp_prob_1m_lead' + str(lead_num) +'_025.nc'
    _write_prob_netcdf(outfile4, lats4, lons4, smooth_compare_clim_2d_interp4,
                       tim, timeunits, timecalendar)

    ### half (1/2) degree ###
    outfile2 = data_dir2 + 'bccr_1_2m_temp_prob_1m_lead' + str(lead_num) +'_050.nc'
    _write_prob_netcdf(outfile2, lats2, lons2, smooth_compare_clim_2d_interp2,
                       tim, timeunits, timecalendar)

    ### one (1) degree ###
    outfile1 = data_dir + 'bccr_1_2m_temp_prob_1m_lead' + str(lead_num) +'_100.nc'
    _write_prob_netcdf(outfile1, lats, lons, smooth_compare_clim_2d,
                       tim, timeunits, timecalendar)



##--- Close the input NetCDF files (forecast and climatology) ------------##
for _input_dataset in (nc, nc_clim):
    _input_dataset.close()


print('')
print('Completed!!')
