# Calculate anomaly of total precipitation (ERA5)
# Compare era5 with climatology
# Plot anomaly of total precipitation
# Period for climatology -> Copernicus: 1993-2016
#
#
# Script version 1.1
#
# Script works with:
#   Python version: 3.10.13
#
#   Package version
#     numpy:  1.26.4
#     pandas:  2.2.3
#     netCDF4:  1.7.1
#     scipy:  1.14.0
#
# Ver1.0.0: Created by Mariko Koseki, 27.11.2024
# ver1.1: updated by Mariko, 20.10.2025



import netCDF4
import numpy as np
import pandas as pd
import datetime
from dateutil.relativedelta import relativedelta
import calendar
import os
import sys
import platform
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import scipy
from scipy.interpolate import RegularGridInterpolator as RGI


# Report the interpreter and package versions of this run, so it can be compared
# with the tested versions listed in the header.
print('--- Python version ---')
print(platform.python_version())

print('--- Package version ---')
for _pkg_label, _pkg in (('numpy: ', np), ('pandas: ', pd), ('netCDF4: ', netCDF4), ('scipy: ', scipy)):
    print(_pkg_label, _pkg.__version__)




##--- Function -----------------------------------------##
def box_smooth_2D(field,n_ih,n_jh,lat_wgt=True,latitude=None):
    """
    Smooth a 2d field w/ dimensions (i,j) by applying a simple box filter (2d-equivalent to a running-mean).

    The j-coordinate is treated as cyclic (e.g., longitude for global data): the field is padded with
    wrapped-around columns, so every column, including the first and the last one, is smoothed.
    The i-coordinate is not cyclic: the first and last n_ih rows have no complete box around them
    and keep their un-smoothed values.
    note: n_ih = 1, n_jh = 1 will create averages of the central points and the 8 points around each

    INPUT:
        field       : (numpy.array) with 2 dimensions (i,j)
        n_ih        : (int) number of grid points to take into account in positive & negative i-direction of the central point
        n_jh        : (int) number of grid points to take into account in positive & negative j-direction of the central point
        lat_wgt     : (boolean) whether to weight each row of the box by cos(latitude); if True, need to define latitude
        latitude    : (numpy.array) of length field.shape[0] (length of i-dimension) with latitudes in degrees used for weighting
    OUTPUT:
        smooth_field : (numpy.array) smoothed copy of field with the same shape; integer input is returned as float64
    """
    ## The function is adapted from:
    ## https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/tree/main/cf_monthly_forecast
    ## https://github.com/SeasonalForecastingEngine/cf_monthly_forecast/blob/main/notebooks/002_forecast_plots.py

    if lat_wgt:
        assert latitude is not None, 'need to specify latitude vector of length N_i to apply latitude weighting'

    field = np.asarray(field)
    if not np.issubdtype(field.dtype, np.inexact):
        # An integer result array would silently truncate the averages.
        field = field.astype(np.float64)

    N_i, N_j = field.shape
    n_ki = 2 * n_ih + 1  # box height (i-direction)
    n_kj = 2 * n_jh + 1  # box width (j-direction)

    # One weight per row (i). The weights must vary along i, i.e. be applied as a column vector
    # against the box; a plain 1-D vector would broadcast along j instead.
    if lat_wgt:
        row_weights_all = np.cos(np.deg2rad(np.asarray(latitude, dtype=np.float64)))
    else:
        row_weights_all = np.ones(N_i)

    # Pad the field cyclically in j (longitude): n_jh wrapped columns on each side.
    field_padded = np.concatenate([field[:, N_j - n_jh:], field, field[:, :n_jh]], axis=1)

    # Use a copy of the original and overwrite; rows closer than n_ih to the i-edges keep their values.
    smooth_field = field.copy()
    for i in range(n_ih, N_i - n_ih):
        rows = slice(i - n_ih, i + n_ih + 1)
        row_wgts = row_weights_all[rows]
        # Weighted sum over the n_ki rows of the box, for every padded column.
        col_sums = (row_wgts[:, np.newaxis] * field_padded[rows, :]).sum(axis=0)
        # Running sum over n_kj neighbouring columns: exactly N_j windows, one per original column.
        box_sums = np.lib.stride_tricks.sliding_window_view(col_sums, n_kj).sum(axis=-1)
        smooth_field[i, :] = box_sums / (row_wgts.sum() * n_kj)
    return smooth_field



##--- Set year, month for climatology ------------------------------------##
# Set 'start_year_clm' and 'end_year_clm' (climatology period).
# Copernicus: 1993-2016
start_year_clm = 1993 # 1993
end_year_clm = 2016 # 2016



##--- Set path ------------------------------------------##
## Set path to input files ##
# Set 'era5_monthly' (directory of the yearly ERA5 monthly total-precipitation NetCDF files)
# and 'era5_clim' (directory of the total-precipitation climatology NetCDF files).
era5_monthly = '/nird/projects/NS9873K/norcpm/validation/reanalysis/ECMWF/ERA5/original/monthly_single_level/monthly_averaged_reanalysis/tp/'
era5_clim  = '/nird/projects/NS9873K/norcpm/validation/reanalysis/ECMWF/ERA5/clim/total_precipitation/'




## Set path to output files ##
# Set 'out_fig_dir' (root directory for the output figures; created if it does not exist).
out_fig_dir  = '/nird/projects/NS9873K/www/norcpm/forecast/plots/version3/reference/'
# exist_ok=True avoids the race between an exists() check and makedirs().
os.makedirs(out_fig_dir, exist_ok=True)




##--- Input ----------------------------------------------##
# Usage: python <script> <yyyymm> <area>
#   <yyyymm> : year and month of the ERA5 field to plot, e.g. 202312
#   <area>   : 'europe' or 'global'
args = sys.argv
valid_areas = ('europe', 'global')


def _usage_exit():
    """Print how to call the script and exit with status 1 (invalid arguments)."""
    print('')
    print('---How to use this script---')
    print('python ', args[0], '<yyyymm> <area>')
    print('<yyyymm> = Start date, <area> = europe or global')
    print('Example: python ', args[0], '202312 europe')
    print('')
    sys.exit(1)


if len(args) != 3: #ver1.1
    _usage_exit()

input_date = args[1]
area = args[2]

# <yyyymm> must be six ASCII digits with a month of 01-12. The area must be one
# that the plotting section handles; checking it here avoids failing only after
# all the data have been read and processed.
if len(input_date) != 6 or not (input_date.isascii() and input_date.isdigit()):
    _usage_exit()
yyyy = int(input_date[0:4])
mm = int(input_date[4:6])
dd = 1
if not 1 <= mm <= 12 or area not in valid_areas:
    _usage_exit()


start_year = '{:0>4d}'.format(yyyy)
start_month = '{:0>2d}'.format(mm)
start_date = '{:0>2d}'.format(dd)

issue_date = start_year + start_month + start_date
print('')
print('Issued: ', issue_date)
print('')

era5_start_date = '{:0>2d}/{:0>2d}/{:0>4d}'.format(dd,mm,yyyy)
print('ERA5 start: ', era5_start_date)
print('')

# 0-based index of the requested month, used to select the record in the
# yearly ERA5 file and in the climatology file.
mon = mm - 1


# Number of days in the month; the daily rate is multiplied by this to get a monthly total.
# February is counted as 28 days even in leap years (the leap day is ignored).
# NOTE(review): presumably to match a no-leap calendar convention -- confirm.
# The override applies to February only: no other month changes length in a leap year.
ndate = calendar.monthrange(yyyy, mm)[1]
if mm == 2 and calendar.isleap(yyyy):
    ndate = 28



##--- Read netCDF files ----------------------------------##
## Open NetCDF files ##
'''
## Example: "ERA5_228_1993.nc" ##
dimensions:
	valid_time = 12 ;
	latitude = 721 ;
	longitude = 1440 ;
variables:
	int64 number ;
		number:long_name = "ensemble member numerical id" ;
		number:units = "1" ;
		number:standard_name = "realization" ;
	int64 valid_time(valid_time) ;
		valid_time:long_name = "time" ;
		valid_time:standard_name = "time" ;
		valid_time:units = "seconds since 1970-01-01" ;
		valid_time:calendar = "proleptic_gregorian" ;
	double latitude(latitude) ;
		latitude:_FillValue = NaN ;
		latitude:units = "degrees_north" ;
		latitude:standard_name = "latitude" ;
		latitude:long_name = "latitude" ;
		latitude:stored_direction = "decreasing" ;
	double longitude(longitude) ;
		longitude:_FillValue = NaN ;
		longitude:units = "degrees_east" ;
		longitude:standard_name = "longitude" ;
		longitude:long_name = "longitude" ;
	string expver(valid_time) ;
	float tp(valid_time, latitude, longitude) ;
		tp:_FillValue = NaNf ;
		tp:GRIB_paramId = 228LL ;
		tp:GRIB_dataType = "fc" ;
		tp:GRIB_numberOfPoints = 1038240LL ;
		tp:GRIB_typeOfLevel = "surface" ;
		tp:GRIB_stepUnits = 1LL ;
		tp:GRIB_stepType = "avgad" ;
		tp:GRIB_gridType = "regular_ll" ;
		tp:GRIB_uvRelativeToGrid = 0LL ;
		tp:GRIB_NV = 0LL ;
		tp:GRIB_Nx = 1440LL ;
		tp:GRIB_Ny = 721LL ;
		tp:GRIB_cfName = "unknown" ;
		tp:GRIB_cfVarName = "tp" ;
		tp:GRIB_gridDefinitionDescription = "Latitude/Longitude Grid" ;
		tp:GRIB_iDirectionIncrementInDegrees = 0.25 ;
		tp:GRIB_iScansNegatively = 0LL ;
		tp:GRIB_jDirectionIncrementInDegrees = 0.25 ;
		tp:GRIB_jPointsAreConsecutive = 0LL ;
		tp:GRIB_jScansPositively = 0LL ;
		tp:GRIB_latitudeOfFirstGridPointInDegrees = 90. ;
		tp:GRIB_latitudeOfLastGridPointInDegrees = -90. ;
		tp:GRIB_longitudeOfFirstGridPointInDegrees = 0. ;
		tp:GRIB_longitudeOfLastGridPointInDegrees = 359.75 ;
		tp:GRIB_missingValue = 3.40282346638529e+38 ;
		tp:GRIB_name = "Total precipitation" ;
		tp:GRIB_shortName = "tp" ;
		tp:GRIB_totalNumber = 0LL ;
		tp:GRIB_units = "m" ;
		tp:long_name = "Total precipitation" ;
		tp:units = "m" ;
		tp:standard_name = "unknown" ;
		tp:GRIB_surface = 0. ;
		tp:coordinates = "number valid_time latitude longitude expver" ;
'''



### ERA5 ###
nc_tprec = era5_monthly + 'ERA5_228_' + start_year + '.nc'
print('ERA5 file name: ', nc_tprec)


### Climatology ###
# The period in the file name comes from start_year_clm/end_year_clm, so the
# configured climatology period is the one that is actually read.
#nc_clim_tprec = era5_clim + 'era5_Total_precipitation_clim_mean_{}_{}.nc'.format(start_year_clm, end_year_clm)
nc_clim_tprec = era5_clim + 'era5_Total_precipitation_clim_median_{}_{}.nc'.format(start_year_clm, end_year_clm)
print('climatology file name: ', nc_clim_tprec)


# Both datasets are opened in a with-statement, so they are closed even if a read fails.
with netCDF4.Dataset(nc_tprec, 'r') as nc, netCDF4.Dataset(nc_clim_tprec, 'r') as nc_clim:

    ## Check length of variables ##
    nlon = len(nc.dimensions['longitude']) # Longitude: 1440
    nlat = len(nc.dimensions['latitude']) # Latitude: 721
    nmonth = len(nc_clim.dimensions['month']) # month: 12
    print(nlon, nlat, nmonth)

    # Guard against a yearly ERA5 file that has fewer records than the requested month.
    n_valid_time = len(nc.dimensions['valid_time'])
    if mon >= n_valid_time:
        sys.exit('ERA5 file {} has only {} month(s); month {} is not available'.format(nc_tprec, n_valid_time, start_month))

    ## Read variables ##
    ### month ###
    months_clm = nc_clim.variables['month'][:] # month: 1-12
    months_era5 = nc.variables['valid_time'][:] # valid_time: seconds since 1970-01-01
    print('month clim: ', months_clm)
    print('month era5: ', months_era5)

    ### latitude, longitude ###
    lons = nc.variables['longitude'][:] # Longitude: 0 - 359.75
    lats = nc.variables['latitude'][:] # Latitude: 90 - -90

    ### Total precipitation for month index 'mon' and its climatology ###
    # tp is (valid_time, latitude, longitude) per the header above, so this is (lat, lon);
    # clim_tp_era5 is assumed to have the same (month, lat, lon) layout -- TODO confirm.
    tprec = nc.variables['tp'][mon]
    clim_tprec = nc_clim.variables['clim_tp_era5'][mon]


## Convert variables into Numpy ndarray ##
tprec_np = np.array(tprec)
clim_tprec_np = np.array(clim_tprec)


## Reorder longitudes from 0 ... 359.75 to -179.75 ... 180 ##
# lon_end1 is the first index east of 180 deg (721 for a 1440-point grid). Those
# columns (lon > 180, shifted by -360) are rolled to the front; every field gets the same roll.
lon_end1 = nlon // 2 + 1
lons_arr = np.asarray(lons, dtype=np.float64)
lons2 = np.where(lons_arr > 180, lons_arr - 360., lons_arr)
lons2_con = np.roll(lons2, -lon_end1)
tprec_np_con = np.roll(tprec_np, -lon_end1, axis=1)
clim_tprec_np_con = np.roll(clim_tprec_np, -lon_end1, axis=1)


##--- Calculate anomaly of total precipitation ------------------------##
# Vectorised over the whole (nlat, nlon) grid instead of a per-point Python loop.
# Unit change m/day -> mm/month: multiply by 1000 mm/m and by the ndate days of the month.
# Computed in float64, the dtype of the result array in the per-point version.
prec_mm = np.asarray(tprec_np_con, dtype=np.float64) * 1000 * ndate
clim_prec_mm = np.asarray(clim_tprec_np_con, dtype=np.float64) * 1000 * ndate

### Anomaly = monthly total minus climatological monthly total ###
compare_clim_2d = prec_mm - clim_prec_mm
print(compare_clim_2d)



    
##--- Call function: smoothing --------------------------------------##
# Box-smooth the anomaly with one neighbour in each direction (3x3 box) and
# cos(latitude) row weights; lats is handed over as a plain ndarray copy.
smooth_compare_clim_2d = box_smooth_2D(
    compare_clim_2d,
    n_ih=1,
    n_jh=1,
    lat_wgt=True,
    latitude=np.array(lats),
)
#smooth_compare_clim_2d = compare_clim_2d

# The string literals below are disabled code (4x interpolation of the smoothed
# field), kept for reference; they are never executed.
'''
##--- Interpolate ---------------------------------------------------##
nsplines = 4

lons4 = np.linspace(lons2_con[0],lons2_con[-1],lons2_con.shape[0]*nsplines)
lats4 = np.linspace(lats[0],lats[-1],lats.shape[0]*nsplines)
'''

'''
### Use interp2d ###
#//////////////////////////////////////////////////
#/// DeprecationWarning
#/// "interp2d" is deprecated in SciPy 1.10 and will be removed in SciPy 1.14.0.
#/// Use `RegularGridInterpolator` instead.
#//////////////////////////////////////////////////
    
f = interpolate.interp2d(lons, lats, smooth_compare_clim_2d, kind='cubic') # The code does not work
smooth_compare_clim_2d_interp = f(lons4,lats4)
'''


'''
### Use RegularGridInterpolator ###
'''
#(721, 1440) -> (2884, 5760)
'''
data_grid = tuple([lons2_con, lats])
r = RGI(data_grid, smooth_compare_clim_2d.T, method='linear', bounds_error=False)
lon_new, lat_new = np.meshgrid(lons4, lats4, indexing='ij', sparse=True)
inter_compare_clim_2d = r((lon_new, lat_new))
smooth_compare_clim_2d_interp = inter_compare_clim_2d.T
'''

#print(np.amax(smooth_compare_clim_2d_interp, axis = (0, 1)))


##--- Plot anomaly of total precipitation -----------------------------##
# Area-dependent settings: Basemap arguments and contour levels (the levels
# are also used as the colorbar ticks).
area_settings = {
    'europe': {
        'basemap_kwargs': dict(projection='cyl', llcrnrlon=-20, urcrnrlon=50,
                               llcrnrlat=35, urcrnrlat=72, resolution='l',
                               fix_aspect=False),  # <---Europe
        'clevs': np.arange(-90, 108, 18),          # -90, -72, ..., 90
    },
    'global': {
        'basemap_kwargs': {},                      # <---Global: Basemap() with default arguments
        'clevs': np.arange(-300, 360, 60),         # -300, -240, ..., 300
    },
}
# Stop with a clear message rather than a NameError on 'm'/'clevs' further down.
if area not in area_settings:
    sys.exit("Unknown area '{}': expected 'europe' or 'global'".format(area))
settings = area_settings[area]

fig = plt.figure(figsize=(11,9))
ax = fig.add_subplot(111)

m = Basemap(**settings['basemap_kwargs'])

plt.subplots_adjust(left=0.03, right=0.97, bottom=0.07, top=0.9)

#m.drawmeridians(np.arange(-180, 180, 10)) # Draw longitude
#m.drawparallels(np.arange(-90, 90, 10)) # Draw latitude
m.drawcoastlines()
m.drawcountries(color = 'grey')


### Create grid data ###
lon_grid, lat_grid = np.meshgrid(lons2_con, lats) # <--- Original grid points
#lon_grid, lat_grid = np.meshgrid(lons4, lats4) # <--- Smoothing & Interpolation
x, y = m(lon_grid, lat_grid)


### Set contour levels ###
clevs = settings['clevs']


### Set colour map ###
# 10 colours, one per interval between the 11 contour levels; values outside use set_under/set_over.
cmap = plt.cm.colors.ListedColormap(['#263810', '#605c1a', '#a97731', '#d7a175', '#f1ded1', '#ebe4d1', '#a2b89c', '#549185', '#1e6775', '#093761'])
cmap.set_over('#101e4f')
cmap.set_under('#17230e')

### Draw map ###
cf = m.contourf(x, y, compare_clim_2d, levels = clevs, cmap = cmap, extend='both') # <--- Original grid points
#cf = m.contourf(x, y, smooth_compare_clim_2d_interp, levels = clevs, cmap = cmap, extend='both') # <--- Smoothing & Interpolation
#cf = m.contourf(x, y, smooth_compare_clim_2d, levels = clevs, cmap = cmap, extend='both') # <--- Smoothing

### Add colorbar ###
cbar = m.colorbar(cf, location='bottom', pad=0.1, size='3%')
cbar.set_label('Anomaly (mm)', fontsize = 15)
# One tick per contour level.
cbar.ax.xaxis.set_major_locator(ticker.FixedLocator(clevs))
cbar.ax.tick_params(labelsize=15)


### Add title ###
era5_date = datetime.datetime(year = int(start_year), month = int(start_month), day=1)
era5_month = era5_date.strftime('%B')
era5_year = era5_date.strftime('%Y')

fig_title = 'ERA5 Precipitation Anomaly ' + era5_month + ' ' + era5_year
#ax_title = '(Produced: ' + era5_start_date +')'
ax_title = '(Produced: ' + era5_month + ' ' + era5_year + ')'
fig.suptitle(fig_title, fontsize = 18)
ax.set_title(ax_title, loc="right", fontsize = 16)


### Save figures ###
# One sub-directory per year+month under out_fig_dir (created if missing).
figs_dir = os.path.join(out_fig_dir, start_year + start_month)
os.makedirs(figs_dir, exist_ok=True)


### File name ###
fig_name1 = os.path.join(figs_dir, 'era5_precipitation_anomaly_' + area + '_s' + issue_date + '.png') # <--- archiving figure
fig.savefig(fig_name1, dpi=300, format='png')


plt.show()


print('')
print('Completed!!')
