import os
import calendar
import glob
from netCDF4 import Dataset
import numpy as np
import cftime
import datetime
import time
import dask.array as da
import dask.distributed
import xesmf as xe


def regrid_field(field, *, method='conservative', periodic=False):
    """Regrid ``field`` from the hard-coded model grid to a 1 x 1 degree grid.

    Input grid: cell centres at longitudes 0, 2.5, ..., 357.5 and at
    latitudes spaced 180/95 degrees apart from -90 to 90 (144 x 96 cells).
    Output grid: cell centres at every integer longitude 0..359 and every
    integer latitude -90..90 (360 x 181 cells). Both grids are built with
    ``xe.util.grid_2d`` from the bounds and resolutions defined below.

    Parameters
    ----------
    field : array-like
        Data on the input grid. Assumed (TODO confirm) to have latitude and
        longitude as its last two dimensions, in that order, as the main
        script's chunking ``(..., 96, 144)`` suggests.
    method : str, keyword-only
        xESMF regridding method name passed to ``xe.Regridder``; defaults to
        ``'conservative'``.
    periodic : bool, keyword-only
        Passed to ``xe.Regridder``. Useful when ``method`` is not
        ``'conservative'`` and the grid is global; defaults to ``False``.

    Returns
    -------
    field_out
        ``field`` regridded onto the output grid (presumably lazy when
        ``field`` is a dask array -- depends on the xESMF version).
    lats : numpy.ndarray
        1-D output latitude cell centres.
    lons : numpy.ndarray
        1-D output longitude cell centres.
    """
    # Input grid: resolution (lon, lat) and outer cell bounds, i.e. the first
    # and last cell centres extended by half a cell.
    # NOTE(review): the outermost latitude bounds lie half a cell beyond the
    # poles (about +/-90.95); the output bounds likewise reach +/-90.5.
    rlat_in = 180 / 95
    rlon_in = 2.5
    res_in = (rlon_in, rlat_in)
    lon_bounds_in = (0 - rlon_in / 2, 357.5 + rlon_in / 2)
    lat_bounds_in = (-90 - rlat_in / 2, 90 + rlat_in / 2)

    # Output grid: 1-degree cells centred on integer longitudes/latitudes.
    res_out = (1, 1)
    lon_bounds_out = (-0.5, 359.5)
    lat_bounds_out = (-90.5, 90.5)

    grid_in = xe.util.grid_2d(lon_bounds_in[0], lon_bounds_in[1], res_in[0],
                              lat_bounds_in[0], lat_bounds_in[1], res_in[1])
    grid_out = xe.util.grid_2d(lon_bounds_out[0], lon_bounds_out[1], res_out[0],
                               lat_bounds_out[0], lat_bounds_out[1], res_out[1])

    # reuse_weights=True presumably lets repeated calls load previously
    # computed weights from disk instead of recomputing them (xESMF feature).
    regridder = xe.Regridder(grid_in, grid_out, method, periodic=periodic,
                             reuse_weights=True)

    # TODO: masked/NaN input values are not treated specially. One option is
    # to regrid a 0/1 validity mask alongside the field (with invalid values
    # zeroed) and set output cells that are less than 50% valid to NaN.
    field_out = regridder(field)

    # grid_2d yields 2-D coordinate arrays; on this regular grid every row
    # shares the same longitudes and every column the same latitudes.
    lons = np.asarray(grid_out.lon)[0, :]
    lats = np.asarray(grid_out.lat)[:, 0]

    # TODO: to obtain sensible values on the 181-latitude output grid, the
    # first and last output latitude rows could be replaced with the zonal
    # mean of the corresponding input pole rows.

    return field_out, lats, lons

if __name__ == '__main__':
    start_time = time.time()

    infile = '/projects/NS9039K/shared/ClimateFutures/raw/norcpm-cf-system1/norcpm-cf-system1_hindcast1/norcpm-cf-system1_hindcast1_19921015/norcpm-cf-system1_hindcast1_19921015_mem01-60/atm/hist/norcpm-cf-system1_hindcast1_19921015_mem01-60.cam2.h2.1992-10-15-21600.nc'

    # Use the dask client as a context manager so the local cluster it starts
    # is shut down once the work below is finished.
    with dask.distributed.Client():
        # Read only the two leading records that are regridded, rather than
        # loading the whole TREFHT variable and slicing it afterwards. The
        # file is closed as soon as the read finishes, even on error.
        with Dataset(infile, 'r') as rootgrp:
            field = rootgrp.variables['TREFHT'][0:2]

        field02 = da.from_array(field, chunks=(30, 29, 96, 144))
        field_out, lats, lons = regrid_field(field02)

        # The regridded result may be a lazy dask array; evaluate it here so
        # the reported time covers the regridding itself. dask.compute returns
        # non-dask inputs unchanged.
        (field_out,) = dask.compute(field_out)

    print("--- %s seconds ---" % (time.time() - start_time))
