# Compute climatologies of SST from NorCPM output
# Period: 1993-2020 -> Copernicus: 1993-2016 #ver1.1.0
#
#
# Script version 1.1.0
#
# Script works with:
#   Python version 3.10
#
#   Package version
#     numpy: 1.26.3
#     pandas: 2.2.0
#
#
#
# Ver1.0.0: Created by Mariko Koseki, 19.04.2024
# Ver1.1.0: updated by Mariko, 15.05.2024


##--- Import modules ------------------------------##
import netCDF4
import numpy as np
from numpy import dtype
import pandas as pd
import math
import datetime
from datetime import timedelta
#from dateutil.relativedelta import relativedelta
import os
import glob
import sys
import platform
import calendar



# Report the interpreter and key package versions at start-up so a run can
# be matched against the tested environment listed in the file header.
print('--- Python version ---')
print(platform.python_version())

print('--- Package version ---')
for _label, _module in (('numpy: ', np), ('pandas: ', pd)):
    print(_label, _module.__version__)
print('-------------------------')
print()



##--- Set path ------------------------------------------##
## Set path to input files ##
# norcpm_monthly: root directory of the NorCPM monthly forecast NetCDF files.
# Input files are found as norcpm_monthly + YYYYMM + '/<file>.nc', so the
# path must end with '/'.
norcpm_monthly = '/nird/projects/NS9873K/www/norcpm/forecast/monthly/'


## Set path to output files ##
# out_dir: directory the climatology NetCDF file is written to (must end
# with '/', it is joined by string concatenation).
out_dir  = '/nird/projects/NS9873K/norcpm/validation/sst/clim/'
# exist_ok=True creates the directory if needed without the race between a
# separate os.path.exists() check and os.makedirs().
os.makedirs(out_dir, exist_ok=True)


## Select NorCPM system ##
# system_num is inserted into both the input and the output file names.
# To process the second system, use the commented alternative instead.
system_num = '1'
#system_num = '2'

print()
print('system: ', system_num)
print()




##--- Set year, month ------------------------------------##
# Hindcast period. NOTE: end_year is EXCLUSIVE -- range(start_year, end_year)
# yields 1993..2016, the Copernicus reference period (ver1.1.0).
start_year = 1993
end_year = 2017
y_str_list = [str(yr) for yr in range(start_year, end_year)]    # '1993', ..., '2016'
m_str_list = [f'{mon:02d}' for mon in range(1, 13)]             # '01', ..., '12'


## Set length of variables ##
nyear, nmonth = len(y_str_list), len(m_str_list)
nlead = 6              # lead times read from each file
nmem = 60              # ensemble members read from each file
nlat, nlon = 180, 360  # grid size

# 1-based coordinate values written to the output 'lead' and 'month' axes.
leads = np.arange(nlead) + 1
months = np.arange(nmonth) + 1


##--- Read netCDF files ----------------------------------##
## Calculate climatology ##
# Input file layout (example: sea_surface_temperature_bccr_1_2024_04.nc):
#   dimensions:
#     number = 60 ; time = 6 ; latitude = 180 ; longitude = 360 ;
#   variables:
#     int   time(time)            units = "hours since 1900-01-01 00:00:00.0",
#                                 calendar = "gregorian"
#     float latitude(latitude)    units = "degrees_north"
#     float longitude(longitude)  units = "degrees_east"
#     int   number(number)        long_name = "ensemble_member"
#     float sst(time, number, latitude, longitude)
#           _FillValue = 1.e+20f, units = "K",
#           long_name = "Sea surface temperature",
#           initial_time = "01/04/2024 00:00"
#
# sst_clim_4d[lead, month] is the mean SST over all years in y_str_list and
# all ensemble members, for forecasts started in that month, at that lead
# (time index of the file).
#
# Implementation notes:
#  * Each file is opened once and all leads/members are read in one slice.
#  * With netCDF4's default auto-masking, points equal to _FillValue come
#    back masked. They are converted to NaN before averaging. np.array() on
#    the masked array would keep the raw 1e20 and it would be averaged in.
#  * A NaN-aware running sum/count gives the same mean as np.nanmean over
#    (year, member) without holding all years in memory at once.
sst_clim_4d = np.zeros((nlead, nmonth, nlat, nlon)) #(6, 12, 180, 360)
for mon_num, mm in enumerate(m_str_list):
    print('month: ', mm)

    # Sum and count of valid values over years and members, per lead/grid point.
    sst_sum = np.zeros((nlead, nlat, nlon))
    sst_cnt = np.zeros((nlead, nlat, nlon), dtype=np.int64)

    for yyyy in y_str_list:
        print('year: ', yyyy)
        nc_sst = norcpm_monthly + yyyy + mm + '/sea_surface_temperature_bccr_' + system_num + '_' + yyyy + '_' + mm + '.nc'
        print('file name: ', nc_sst)
        print('')

        # The with-statement closes the file even if a read fails.
        with netCDF4.Dataset(nc_sst, 'r') as nc:
            lons = nc.variables['longitude'][:] # Longitude
            lats = nc.variables['latitude'][:] # Latitude
            sst = nc.variables['sst'][:nlead, :nmem, :, :] # (lead, member, lat, lon)

        # Masked (missing) points -> NaN so the nan-aware reductions skip them.
        sst_np = np.ma.filled(np.ma.asarray(sst, dtype=np.float64), np.nan)
        if sst_np.shape != (nlead, nmem, nlat, nlon):
            raise ValueError(f'unexpected sst shape {sst_np.shape} in {nc_sst}, '
                             f'expected {(nlead, nmem, nlat, nlon)}')

        sst_sum += np.nansum(sst_np, axis=1)
        sst_cnt += np.count_nonzero(~np.isnan(sst_np), axis=1)

    # Mean over years and members; 0/0 gives NaN where no valid value exists.
    with np.errstate(invalid='ignore'):
        sst_clim_4d[:, mon_num, :, :] = sst_sum / sst_cnt

print('file shape of 4D', sst_clim_4d.shape) #(6, 12, 180, 360)


print('')
print('Calculation completed')
print('')
print('------------------------')
print('')



##--- Save results as a new NetCDF file ------------------------------##
print('Save climatology')

## File name ##
# end_year is exclusive (years come from range(start_year, end_year)), so the
# last year actually used is end_year - 1.
outfile = out_dir + 'bccr_' + system_num + '_sea_surface_temperature_clim_' + str(start_year) + '_' + str(end_year-1) + '.nc' #ver1.1.0

# Value stored for grid points without a valid climatology (e.g. land).
# It matches the _FillValue (1.e+20f) of the NorCPM input files.
sst_fill_value = 1.e+20

# The with-statement closes (and flushes) the file even if a write fails.
with netCDF4.Dataset(outfile ,'w', format="NETCDF4") as nc2:

    ## Define dimensions ##
    nc2.createDimension('latitude', nlat)
    nc2.createDimension('longitude', nlon)
    nc2.createDimension('lead', nlead)
    nc2.createDimension('month', nmonth)

    ## Define variables ##
    ### latitude ###
    lat = nc2.createVariable('latitude', dtype('float32').char, ('latitude',))
    lat.units = 'degrees_north'
    lat.long_name = 'latitude'
    lat[:] = lats

    ### longitude ###
    lon = nc2.createVariable('longitude', dtype('float32').char, ('longitude',))
    lon.units = 'degrees_east'
    lon.long_name = 'longitude'
    lon[:] = lons

    ### lead time ###
    lead = nc2.createVariable('lead', dtype('int32').char, ('lead',))
    lead.units = 'lead_month'
    lead.long_name = 'lead_time'
    lead[:] = leads

    ### month ###
    month = nc2.createVariable('month', dtype('int32').char, ('month',))
    month.units = 'month'
    month.long_name = 'month'
    month[:] = months

    ### climatology ###
    # fill_value writes a _FillValue attribute so readers treat missing points
    # as missing rather than as real temperatures.
    clim_sst = nc2.createVariable('clim_sst', dtype('float32').char, ('lead', 'month', 'latitude', 'longitude',), fill_value=sst_fill_value)
    clim_sst.units = 'K'
    clim_sst.long_name = 'climatology of sea surface temperature'
    # NaN marks points without valid data; masked values are stored as _FillValue.
    clim_sst[:,:,:,:] = np.ma.masked_invalid(sst_clim_4d)


print('')
print('Completed!!')
