import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np

# Input grid file. Later code reads its lon/lat points, the per-corner
# lon_bnds/lat_bnds arrays (indexed by nv), and FR_LAND/FR_LAKE.
FORCING_FILE = '/nird/projects/forcing_file.nc'
ds = xr.open_dataset(FORCING_FILE)

# Figure 1: one panel per cell-corner index nv=0..3. Each panel shows that
# corner's lon_bnds/lat_bnds points over the lon/lat points (black; presumably
# cell centres) and a single reference location (orange), zoomed to a small
# lon/lat window.
corner_colors = ('b', 'r', 'g', 'c')  # marker colour for nv = 0, 1, 2, 3

fig, axs = plt.subplots(2, 2, figsize=(12, 10),
                        subplot_kw={'projection': ccrs.Orthographic(0, 90)})

for ii in range(4):
    aa = axs.flatten()[ii]
    # Draw corner nv=ii in the same (row-major) panel that receives the
    # 'nv=ii' title below, so every label matches the data it sits on.
    # Corners are drawn first so centres and the reference point stay on top.
    aa.scatter(ds.lon_bnds.isel(nv=ii).values, ds.lat_bnds.isel(nv=ii).values,
               s=10, c=corner_colors[ii], transform=ccrs.PlateCarree())
    aa.scatter(ds.lon.values, ds.lat.values, s=10, c='k',
               transform=ccrs.PlateCarree())
    aa.scatter(41.236, 51.120, s=10, c='orange', transform=ccrs.PlateCarree())
    aa.gridlines(draw_labels=True)
    aa.coastlines()
    aa.set_extent([40.9, 41.6, 51.1, 51.5], crs=ccrs.PlateCarree())
    aa.set_title('nv=' + str(ii), loc='left')

fig.suptitle('Four grid cell corners', fontweight='bold')

# Save before showing: a blocking show() closes the figure when its window is closed.
fig.savefig('grid_cell_vertices.png')
plt.show()


# Split the corner-bounds arrays into one float64 array per corner index
# nv=0..3; the names nvK_lon / nvK_lat are used by the grid code below.
nv0_lon, nv1_lon, nv2_lon, nv3_lon = [
    ds.lon_bnds.isel(nv=corner).values.astype('float64') for corner in range(4)
]
nv0_lat, nv1_lat, nv2_lat, nv3_lat = [
    ds.lat_bnds.isel(nv=corner).values.astype('float64') for corner in range(4)
]

# Both print (582, 577) for the forcing file used here.
for corner_array in (nv0_lon, nv0_lat):
    print(corner_array.shape)

# (rows, cols) of the cell grid; the staggered grids add one row / one column.
nv_shape = nv0_lat.shape

def _lon_midpoint(lon_a, lon_b):
    """Return the element-wise midpoint of two longitude arrays (degrees).

    For pairs less than 180 degrees apart this is exactly 0.5 * (lon_a + lon_b).
    Pairs more than 180 degrees apart straddle the 0/360 (or -180/180) seam;
    their plain average lands on the opposite side of the globe (359.9 and 0.1
    give 180.0). For those pairs, the result is shifted by 180 degrees to the
    short-arc midpoint. The result is correct modulo 360 but is not re-wrapped
    into the inputs' longitude convention.
    """
    mid = 0.5 * (lon_a + lon_b)
    crosses_seam = np.abs(lon_b - lon_a) > 180.0
    # Shift toward zero so [0, 360) and [-180, 180) data mostly stay in range.
    shifted = np.where(mid >= 0.0, mid - 180.0, mid + 180.0)
    return np.where(crosses_seam, shifted, mid)


# V-grid: midpoint of the nv0-nv1 edge of every cell, plus one extra row taken
# from the opposite (nv2-nv3) edge of the last row of cells -> (rows+1, cols).
plat_v = np.zeros((nv_shape[0]+1, nv_shape[1]))
plat_v[:-1, :] = 0.5 * (nv0_lat + nv1_lat)
plat_v[-1, :] = 0.5 * (nv2_lat[-1, :] + nv3_lat[-1, :])

plon_v = np.zeros((nv_shape[0]+1, nv_shape[1]))
plon_v[:-1, :] = _lon_midpoint(nv0_lon, nv1_lon)
plon_v[-1, :] = _lon_midpoint(nv2_lon[-1, :], nv3_lon[-1, :])

# U-grid: midpoint of the nv0-nv3 edge of every cell, plus one extra column
# taken from the opposite (nv1-nv2) edge of the last column -> (rows, cols+1).
plat_u = np.zeros((nv_shape[0], nv_shape[1]+1))
plat_u[:, :-1] = 0.5 * (nv0_lat + nv3_lat)
plat_u[:, -1] = 0.5 * (nv1_lat[:, -1] + nv2_lat[:, -1])

plon_u = np.zeros((nv_shape[0], nv_shape[1]+1))
plon_u[:, :-1] = _lon_midpoint(nv0_lon, nv3_lon)
plon_u[:, -1] = _lon_midpoint(nv1_lon[:, -1], nv2_lon[:, -1])


# Print the staggered-grid shapes next to the cell-grid shapes for a quick
# sanity check.
for shaped in (plat_v, plon_v, plat_u, plon_u, nv0_lat, nv0_lon):
    print(shaped.shape)

# Shapes seen in an interactive session with the forcing file:
#   plat_v, plon_v    -> (583, 577)   one extra row
#   plat_u, plon_u    -> (582, 578)   one extra column
#   nv0_lat, nv0_lon  -> (582, 577)
#   ds.lat, ds.lon    -> (582, 577)

# Figure 2: cell corners vs. the derived V- and U-point locations.
# Panels (row-major): all corners | V points | U points | everything combined.
fig, axs = plt.subplots(2, 2, figsize=(12, 10),
                        subplot_kw={'projection': ccrs.Orthographic(0, 90)})

titles = ['All corners', 'V-grid', 'U-grid', 'Altogether']

for ii in range(4):
    aa = axs.flatten()[ii]
    aa.scatter(ds.lon.values, ds.lat.values, s=10, c='k',
               transform=ccrs.PlateCarree())
    aa.gridlines(draw_labels=True)
    aa.coastlines()
    aa.set_extent([40.9, 41.6, 51.1, 51.5], crs=ccrs.PlateCarree())
    aa.scatter(41.236, 51.120, s=10, c='orange', transform=ccrs.PlateCarree())
    aa.set_title(titles[ii], loc='left')

corner_colors = ('b', 'r', 'g', 'c')  # marker colour for nv = 0, 1, 2, 3

# Per-axes draw order matches the original: corners, then V, then U.
# NOTE: V points reuse nv0's blue and U points reuse nv1's red, so in the
# 'Altogether' panel they are told apart by position only.
for ax in (axs[0, 0], axs[1, 1]):
    for corner, color in enumerate(corner_colors):
        ax.scatter(ds.lon_bnds.isel(nv=corner).values,
                   ds.lat_bnds.isel(nv=corner).values,
                   s=10, c=color, transform=ccrs.PlateCarree())
for ax in (axs[0, 1], axs[1, 1]):
    ax.scatter(plon_v, plat_v, s=10, c='b', transform=ccrs.PlateCarree())
for ax in (axs[1, 0], axs[1, 1]):
    ax.scatter(plon_u, plat_u, s=10, c='r', transform=ccrs.PlateCarree())

fig.suptitle('All grid points', fontweight='bold')

# Save BEFORE showing: a blocking plt.show() closes the figure, and a
# plt.savefig() issued afterwards writes a new, empty figure.
fig.savefig('all_grid_points_bottom_right.png')
plt.show()


# Save the derived V/U point coordinates plus land/lake masks to a new file
# for the SCRIP step.
# Each coordinate array keeps its own pair of dimension names, exactly as
# before, so the output file's dimension layout is unchanged. (Previously eight
# np.arange index arrays were also built here but never used; they are gone.)
ds = ds.assign(
    XLAT_V=(('plat_v_shp_0', 'plat_v_shp_1'), plat_v),
    XLONG_V=(('plon_v_shp_0', 'plon_v_shp_1'), plon_v),
    XLAT_U=(('plat_u_shp_0', 'plat_u_shp_1'), plat_u),
    XLONG_U=(('plon_u_shp_0', 'plon_u_shp_1'), plon_u),
)

# Binary masks: 1.0 where the fraction is > 0, else 0.0 (NaN compares False,
# so it becomes 0.0). Dims are taken from the source variable rather than
# hard-coded, so the mask always lines up with FR_LAND / FR_LAKE.
ds['LANDMASK'] = (ds.FR_LAND.dims, (ds.FR_LAND.values > 0).astype(float))
ds['LAKEMASK'] = (ds.FR_LAKE.dims, (ds.FR_LAKE.values > 0).astype(float))

# Check the updated dataset
print(ds)

ds.to_netcdf('all_grid_points_for_scrip_file.nc')






