# Compute Nino 3.4 index (Nino 3.4 SST anomaly) from model output
#
# Period for climatology: 1993-2016 (Copernicus)
#
#
# Script version 1.2
#
# Script works with:
#   Python version 3.10
#
#   Package version
#     numpy:  
#     pandas: 
#
#
#
# Ver1.0: Created by Mariko Koseki, 03.07.2025
# ver1.1: updated by Mariko, 04.07.2025
# ver1.2: updated by Mariko, 04.07.2025



##--- Import modules ------------------------------##
import netCDF4
import numpy as np
import pandas as pd
import math
import datetime
from datetime import timedelta
from dateutil.relativedelta import relativedelta
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as ptick
import os
import glob
import sys
import platform
import calendar



# Report the interpreter and package versions used for this run, so that the
# log records the environment the results were produced with.
print('--- Python version ---')
print(platform.python_version())

print('--- Package version ---')
for pkg_label, pkg_module in (('numpy: ', np), ('pandas: ', pd)):
    print(pkg_label, pkg_module.__version__)
print('-------------------------')
print('')




##--- Input -----------------------------------------##
def _print_usage_and_exit(script_name, example_year):
    '''
    Print the command-line help and stop the script.

    script_name:  name of the script as invoked (sys.argv[0])
    example_year: year shown in the single-year example line

    Exits with status 1 so that a calling shell can detect the bad invocation.
    '''
    print('---How to use this script---')
    print('python', script_name ,'<yyyy>')
    print('or')
    print('python', script_name ,'<yyyy1> <yyyy2>')
    print('<yyyy1> = start year; <yyyy2> = end year')
    print('')
    print('Example: python', script_name , example_year)
    print('Example: python', script_name ,'1995 2023')
    print('')
    sys.exit(1)


def _parse_year(text):
    '''
    Return 'text' as an integer year when it is exactly four ASCII digits,
    otherwise return None (instead of letting int() raise a traceback).
    '''
    if len(text) == 4 and text.isascii() and text.isdigit():
        return int(text)
    return None


args = sys.argv
if len(args) == 2:
    # One argument: process that single year.
    start_year = _parse_year(args[1])
    end_year = start_year
    if start_year is None:
        _print_usage_and_exit(args[0], '2025')

elif len(args) == 3:
    # Two arguments: process the inclusive range <yyyy1>..<yyyy2>.
    start_year = _parse_year(args[1])
    end_year = _parse_year(args[2])
    if start_year is None or end_year is None:
        _print_usage_and_exit(args[0], '2023')

    # A reversed range would otherwise give an empty year list and the script
    # would finish without processing anything.
    if start_year > end_year:
        print('Error: start year', start_year, 'is later than end year', end_year)
        print('')
        _print_usage_and_exit(args[0], '2023')

else:
    _print_usage_and_exit(args[0], '2024')


print('start year', start_year)
print('end year', end_year)
print('')





##--- Functions -----------------------------------------##
def nino34_norcpm_grid(sst_norcpm, lats_norcpm, lons_norcpm):
    '''
    Compute the area average of SST in the Nino 3.4 region.

    Nino 3.4 region: 5S-5N, 170W-120W.  Grid points are selected by their
    coordinate values, so the longitude axis may follow either the
    -180..180 or the 0..360 convention.

    Parameters
    ----------
    sst_norcpm  : 2-D array indexed as [latitude, longitude].  Masked
                  entries of a masked array are treated as missing.
    lats_norcpm : 1-D array of latitudes in degrees north.
    lons_norcpm : 1-D array of longitudes in degrees east.

    Returns
    -------
    NaN-ignoring mean over latitude, then over longitude, of the selected
    grid points.  The mean is unweighted (no cos(latitude) weights).

    Raises
    ------
    ValueError : when no grid point falls inside the Nino 3.4 box.
    '''
    lats = np.asarray(lats_norcpm, dtype=float)
    lons = np.asarray(lons_norcpm, dtype=float)

    ## Bring longitudes onto -180..180 so one test covers both conventions ##
    lons_180 = ((lons + 180.0) % 360.0) - 180.0

    ## Check which grid points lie in the Nino3.4 area ##
    in_lat = (lats >= -5) & (lats <= 5)
    in_lon = (lons_180 >= -170) & (lons_180 <= -120)
    if not in_lat.any() or not in_lon.any():
        raise ValueError('no grid points inside the Nino 3.4 region (5S-5N, 170W-120W)')

    ## Masked values become NaN so that nanmean skips them ##
    sst = sst_norcpm
    if isinstance(sst, np.ma.MaskedArray):
        sst = sst.astype(float).filled(np.nan)
    else:
        sst = np.asarray(sst)

    ## Extract Nino3.4 area ##
    nino_sst_norcpm = sst[np.ix_(in_lat, in_lon)]

    ## Compute Nino 3.4 area average of SST ##
    nino34_norcpm = np.nanmean(np.nanmean(nino_sst_norcpm, 0), 0)
    return nino34_norcpm



##--- Set year, month for climatology ------------------------------------##
# Set 'start_year_clm' and 'end_year_clm'.
# Copernicus: 1993-2016
# Both values are part of the climatology file name built below.
start_year_clm = 1993
end_year_clm = 2016




##--- Set path ------------------------------------------##
## Set path to input files ##
# 'norcpm_monthly': directory prefix of the monthly forecast NetCDF files.
# 'norcpm_clim':    SST climatology NetCDF file for the period set above.
norcpm_monthly = '/nird/projects/NS9873K/www/norcpm/forecast/monthly/'
norcpm_clim = (
    '/nird/projects/NS9873K/norcpm/validation/sst/clim/'
    f'bccr_1_sea_surface_temperature_clim_{start_year_clm}_{end_year_clm}.nc'
)




## Set path to output files ##
# 'out_dir': directory the output files are written to.
out_dir  = '/nird/projects/NS9873K/norcpm/validation/enso/norcpm1/'
# exist_ok=True creates the directory when needed without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(out_dir, exist_ok=True)





##--- Read netCDF files ----------------------------------##
#################
## Climatology ##
#################
print('climatology of NorCPM: ', norcpm_clim)
print('')

# Everything needed is read inside the context manager, so the file is closed
# even when a dimension or variable lookup fails.
with netCDF4.Dataset(norcpm_clim, 'r') as nc_norcpm_clim:

    ## Check length of variables ##
    nlon_clm = len(nc_norcpm_clim.dimensions['longitude']) # Longitude: 360
    nlat_clm = len(nc_norcpm_clim.dimensions['latitude']) # Latitude: 180
    nmonth_clm = len(nc_norcpm_clim.dimensions['month']) # month: 12
    nlead_clm = len(nc_norcpm_clim.dimensions['lead']) # lead: 6

    ## Read variables ##
    ### month ###
    months_clm = nc_norcpm_clim.variables['month'][:] # month: 1-12

    ### latitude, longitude ###
    lons_clm = nc_norcpm_clim.variables['longitude'][:] # Lon: 0.5 - 359.5
    lats_clm = nc_norcpm_clim.variables['latitude'][:] # Lat: 89.5 - -89.5

    ### SST ###
    # Indexed below as [lead, month, latitude, longitude].
    sst_clim = nc_norcpm_clim.variables['clim_sst'][:]


## Convert variables into Numpy ndarray ##
# Masked (missing) cells become NaN so the NaN-aware area mean skips them;
# a plain np.array() would keep the raw fill values instead.
# Assumes a floating-point SST variable.
sst_clim_np = np.ma.filled(sst_clim, np.nan)

# nino34_norcpm_grid selects 170W-120W as -170..-120, so longitudes given as
# 0..360 are mapped onto -180..180 first.  Values already in -180..180 are
# left unchanged (apart from +180 becoming -180).
lons_clm = ((np.asarray(lons_clm, dtype=float) + 180.0) % 360.0) - 180.0


### Call function ###
# Compute area average of SST in the Nino 3.4 region for every
# (month, lead) pair.  Result is indexed as [month - 1, lead].
nino34_norcpm_clm_all = np.zeros((nmonth_clm, nlead_clm))
for mon in range(nmonth_clm):
    for lead in range(nlead_clm):
        nino34_norcpm_clm_all[mon, lead] = nino34_norcpm_grid(
            sst_clim_np[lead, mon, :, :], lats_clm, lons_clm)



############
## NorCPM ##
############
# For every requested year and every calendar month: read the forecast SST
# file, compute the Nino 3.4 area mean per lead time and ensemble member,
# subtract the climatology for that month, and save the anomalies as CSV.

### Create list of year-month ###
y_str_list = ["{}".format(y) for y in range(start_year,end_year+1)]
m_str_list = ["{:02}".format(m) for m in range(1,13)]


for yyyy in y_str_list:
    print('year: ', yyyy)

    for mm in m_str_list:
        print('month: ', mm)

        nc_file_forecast = norcpm_monthly + yyyy + mm + '/sea_surface_temperature_bccr_1_' + yyyy + '_' + mm + '.nc'
        print('')
        print('forecast file name: ', nc_file_forecast)

        if not os.path.exists(nc_file_forecast):
            # NOTE(review): the message says "directory" although the test is
            # on the file, and the exit status is 0.  Both are left as they
            # were: this exit is presumably reached in normal operation when
            # later months of the current year do not exist yet -- confirm
            # before changing either.
            print('no such a directory!!!')
            sys.exit()

        print('Data exists!')
        print('')


        ## Read netCDF files ##
        # The context manager closes the file even if a lookup below fails.
        with netCDF4.Dataset(nc_file_forecast, 'r') as nc_sst:

            ## Check length of variables ##
            nmem_sst = len(nc_sst.dimensions['number']) # Ensemble member: 60
            nlead_sst = len(nc_sst.dimensions['time']) # Lead time: 6

            ## Read variables ##
            ### lead time ###
            time_var = nc_sst.variables['time']
            time_sst = time_var[:]
            timeunits = time_var.getncattr('units')
            timecalendar = time_var.getncattr('calendar')

            ### latitude, longitude ###
            lons_sst = nc_sst.variables['longitude'][:] # Longitude: -179.5 - 179.5
            lats_sst = nc_sst.variables['latitude'][:] # Latitude: 89.5 - -89.5

            ### SST ###
            # Indexed below as [time, number, latitude, longitude].
            sst_forecast = nc_sst.variables['sst'][:]

        ### Valid date of every lead time as 'yyyy-mm-dd' ###
        dtime = netCDF4.num2date(time_sst,timeunits,timecalendar)
        dtime = [ f'{i.year}-{i.month:02}-{i.day:02}' for i in dtime]

        print('')
        print(dtime)


        ## Convert variables into Numpy ndarray ##
        # Masked (missing) cells become NaN so the NaN-aware area mean skips
        # them; a plain np.array() would keep the raw fill values instead.
        # Assumes a floating-point SST variable.
        sst_forecast_np = np.ma.filled(sst_forecast, np.nan)


        ### Call functions ###
        # Compute area average of SST in the Nino 3.4 region for every
        # lead time and ensemble member.
        nino34_forecast_all = np.zeros((nlead_sst, nmem_sst))
        for lead in range(nlead_sst):
            for mem in range(nmem_sst):
                nino34_forecast_all[lead, mem] = nino34_norcpm_grid(
                    sst_forecast_np[lead, mem, :, :], lats_sst, lons_sst)


        ### Compute ensemble mean ###
        nino34_forecast_ens_mean = np.mean(nino34_forecast_all, axis=1)
        print('')
        print('ensemble mean: ', nino34_forecast_ens_mean)


        ## Calculate percentiles ##
        # One value per lead time, taken across the ensemble members.
        percentile_10_all = [np.percentile(nino34_forecast_all[lead, :], 10) for lead in range(nlead_sst)]
        percentile_90_all = [np.percentile(nino34_forecast_all[lead, :], 90) for lead in range(nlead_sst)]

        print('')
        print('10 percentile', percentile_10_all)
        print('90 percentile', percentile_90_all)


        ##--- Compute Nino3.4 index ----------------------------------------------##
        # Climatology row for this month, one value per lead time.
        nino34_clm_sst = nino34_norcpm_clm_all[int(mm)-1, :]
        print('climatology: ', nino34_clm_sst)

        nino34_idx = nino34_forecast_ens_mean - nino34_clm_sst
        nino34_idx_per10 = np.asarray(percentile_10_all) - nino34_clm_sst
        nino34_idx_per90 = np.asarray(percentile_90_all) - nino34_clm_sst

        # Anomaly of every member, shape (lead, member): the per-lead
        # climatology is broadcast across the member axis.
        nino34_idx_ens_mem = nino34_forecast_all - nino34_clm_sst[:, np.newaxis]


        ## Convert np array into Pandas DataFrame ##
        # Row labels and member column names follow the sizes read from the
        # file, so files with another number of leads or members also work.
        leadtime = [f'lead{lead + 1}' for lead in range(nlead_sst)]
        ens_mem_columns = [f'ens_mem{mem + 1:02d}' for mem in range(nmem_sst)]

        # column_stack mixes text and numbers, so every cell ends up as a
        # string; kept as is so the CSV layout does not change.
        all_np = np.column_stack([dtime, leadtime, nino34_idx_ens_mem, nino34_idx, nino34_idx_per10, nino34_idx_per90]) #ver1.1

        df = pd.DataFrame(
            all_np,
            columns=["dtime", "lead_time", *ens_mem_columns, "ens_mean", "percentile_10", "percentile_90"],
        ) #ver1.1, ver1.2

        print('')
        print(df)


        ##--- Save results as CSV --------------------------------------##
        ### Save DataFrame as CSV ###
        df.to_csv(out_dir + 'nino34_idx_norcpm1_' + str(yyyy) + str(mm) + '.csv', index=False)
        print('Saved Nino3.4 index!')
        print('')



print('')
print('Completed!!')
print('')
