import os, os.path

import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage.filters import gaussian_filter
from scipy import ndimage
from scipy.stats import nanmean

import gdal, osr

from nansat import Domain, Nansat, Figure, NSR

class ModisL2ImageSST(Nansat):
    '''
    Read and process MODIS L2 SST files (opened with the 'obpg_l2' mapper)

    On opening, catalog metadata is set, a 'mask' band is added (see
    add_mask) and the GCPs are reprojected to a stereographic projection
    centred on the mean GCP position.
    '''

    # Colours used to render mask codes: 1 = low quality / low resolution,
    # 2 = land (see add_mask). add_mask() never writes code 0; presumably it
    # marks no-data pixels after reprojection.
    _MASK_LUT = {0: [50, 50, 50], 1: [255, 255, 255], 2: [128, 128, 128]}
    # LUT used for the cropped products (no colour for code 0)
    _CROP_MASK_LUT = {1: [255, 255, 255], 2: [128, 128, 128]}
    # Mask value of valid pixels
    _MASK_VALID = 64

    def __init__(self, fileName, logLevel=30, **kwargs):
        '''
        Open a MODIS L2 file, set metadata, add mask and reproject GCPs

        Parameters
        ----------
        fileName : str
            input file, e.g. A2012104034000.L2_LAC.NorthNorwegianSeas.hdf
        logLevel : int
            passed to Nansat
        **kwargs
            accepted for compatibility; currently ignored

        Raises
        ------
        Exception
            if the first subdataset is not larger than 200x200 pixels
        '''
        # check the size of the first subdataset before initialising Nansat
        ds = gdal.Open(fileName)
        subds = gdal.Open(ds.GetSubDatasets()[0][0])
        isbig = subds.RasterXSize > 200 and subds.RasterYSize > 200
        # drop references to the GDAL datasets before Nansat reopens the file
        ds = subds = None
        if not isbig:
            print('Too little image!')
            raise Exception('Too little image!')

        Nansat.__init__(self, fileName, logLevel=logLevel,
                        mapperName='obpg_l2', GCP_COUNT=30)

        # metadata for ImagesCatalog (string values wrapped in single quotes)
        self.set_metadata('name',       "'%s'" % self.name)
        self.set_metadata('path',       "'%s'" % self.path)
        self.set_metadata('sensor',     "'modis_sst'")
        self.set_metadata('sensstart',  "'%s'" % self.get_time()[0].strftime('%Y-%m-%d %H:%M:%S.%f'))
        self.set_metadata('border',     self.get_border_postgis())

        # solar zenith angle from HDF or aux NPY file (None if unavailable)
        self.solz = self.get_solz()

        # add mask (needs geolocation, so done before removing it below)
        self.add_mask()

        self.vrt.remove_geolocationArray()

        # use TPS for reprojection
        self.vrt.tps = True

        # reproject GCPs to a stereographic projection centred on their mean
        gcps = self.vrt.dataset.GetGCPs()
        if not gcps:
            self.logger.warning('No GCPs found; GCPs are not reprojected')
            return
        clon = sum(gcp.GCPX for gcp in gcps) / float(len(gcps))
        clat = sum(gcp.GCPY for gcp in gcps) / float(len(gcps))
        srs = '+proj=stere +datum=WGS84 +ellps=WGS84 +lon_0=%f +lat_0=%f +no_defs' % (clon, clat)
        print(srs)
        self.vrt.reproject_GCPs(srs)

    def process_sst(self, opts=None):
        '''Reproject and write SST, SST-gradient, quicklook and zoomed maps

        Parameters
        ----------
        opts : dict or None
            'oDir' - output prefix (concatenated with the file name as is),
            'srs' - proj4 string and 'ext' - extent string of the target
            Domain. Missing keys take the default values set below.

        Returns
        -------
        int
            0 on success, 1 if no 'mask' band is available after reprojection
        '''
        options = {'oDir': './',
                   'srs': '+proj=stere +datum=WGS84 +ellps=WGS84 +lat_0=73 +lon_0=23 +no_defs',
                   'ext': '-te -1000000 -1000000 1000000 1000000 -tr 1000 1000',
                   }
        if opts is not None:
            options.update(opts)
        opts = options

        oBaseFileName = self.get_metadata('name').strip('"').strip("'")
        sstName = opts['oDir'] + oBaseFileName + '_sst.png'
        sstdName = opts['oDir'] + oBaseFileName + '_sstd.png'
        sstqName = opts['oDir'] + oBaseFileName + '_ssts.png'
        sstzName = opts['oDir'] + oBaseFileName + '_sstz.png'

        # SST in swath geometry with non-valid pixels set to NaN. The mask
        # band was already added in __init__; calling add_mask() again would
        # add a duplicate band.
        sst = self['SST']
        mask = self['mask']
        sst[mask < self._MASK_VALID] = np.nan

        # SST gradient magnitude from Sobel derivatives along both axes
        sx = ndimage.sobel(sst, axis=0, mode='constant')
        sy = ndimage.sobel(sst, axis=1, mode='constant')
        sstd = np.hypot(sx, sy)
        self.add_band(array=sstd, parameters={'name': 'sst_gradient', 'long_name': 'SST gradient'})

        self.reproject(Domain(opts['srs'], opts['ext']))

        mask = self['mask']
        if mask is None:
            return 1
        dateStr = self.get_time()[0].strftime('%Y-%m-%d')

        # reprojected SST
        sstr = self['SST']

        # SST map
        if not os.path.exists(sstName):
            self.write_figure(sstName, 'SST',
                              clim='hist',
                              ratio=0.9,
                              mask_array=mask,
                              mask_lut=self._MASK_LUT,
                              legend=True, fontSize=46, caption='SST, %s' % dateStr)

        # SST gradient map
        if not os.path.exists(sstdName):
            self.write_figure(sstdName, 'sst_gradient',
                              clim='hist',
                              ratio=0.8,
                              mask_array=mask,
                              mask_lut=self._MASK_LUT,
                              legend=True, fontSize=46, caption='SST gradient, %s' % dateStr)

        # SST quicklook: every 10th row and column
        if not os.path.exists(sstqName):
            f = Figure(sstr[::10, ::10])
            clim = f.clim_from_histogram(clim='hist',
                                         ratio=0.8,
                                         mask_array=mask[::10, ::10],
                                         mask_lut=self._MASK_LUT)
            f.process(cmin=clim[0], cmax=clim[1])
            f.save(sstqName)

        # percentage of valid (clear) pixels
        clear = int(100. * np.count_nonzero(mask == self._MASK_VALID) / mask.shape[0] / mask.shape[1])
        self.set_metadata('clear',      str(clear))

        # zoomed SST map (modifies sstr in place)
        if not os.path.exists(sstzName):
            self._write_zoomed_sst(sstzName, sstr, mask, dateStr)

        return 0

    def _write_zoomed_sst(self, fileName, sstr, mask, dateStr):
        '''Write SST map cropped to the bounding box of valid smoothed SST

        sstr is the reprojected SST array and is modified in place. If no
        finite pixel remains, undo() and resize(0.1) are applied and a map
        with fixed colour limits [30, 35] is written instead.
        '''
        sstr[mask < self._MASK_VALID] = np.nan
        # gaussian_filter propagates NaN to nearby pixels, so only pixels
        # some distance away from invalid data stay finite
        sstf = gaussian_filter(sstr, 2)
        rowIdx, colIdx = np.nonzero(np.isfinite(sstf))
        if rowIdx.size == 0:
            self.undo()
            self.resize(0.1, eResampleAlg=0)
            self.write_figure(fileName, 'SST', clim=[30, 35])
            return

        # bounding box of finite pixels (inclusive of the max indices)
        xOff = colIdx.min()
        yOff = rowIdx.min()
        xSize = colIdx.max() - xOff + 1
        ySize = rowIdx.max() - yOff + 1
        if xSize > 100 and ySize > 100:
            fontSize = int(xSize * 40. / 2000)
            self.crop(xOff, yOff, xSize, ySize)
            self.write_figure(fileName, 'SST',
                              clim='hist',
                              ratio=0.9,
                              mask_array=self['mask'],
                              mask_lut=self._MASK_LUT,
                              legend=True, fontSize=fontSize, caption='SST, %s' % dateStr)
            # revert the crop
            self.undo()

    def process_crop_sst(self, opts=None):
        '''Crop to a lon/lat box and write cropped SST map and its quicklook

        Parameters
        ----------
        opts : dict or None
            'oDir' - output prefix (concatenated with the file name as is),
            'lonlim', 'latlim' - crop limits passed to crop(). Missing keys
            take the default values set below.

        Returns
        -------
        int
            always 0
        '''
        options = {'oDir': './',
                   'lonlim': [0, 30],
                   'latlim': [65, 77],
                   }
        if opts is not None:
            options.update(opts)
        opts = options

        oBaseFileName = self.get_metadata('name').strip('"').strip("'")
        sstcName = opts['oDir'] + oBaseFileName + '_sstc.png'
        sstcqName = opts['oDir'] + oBaseFileName + '_sstcq.png'

        status = self.crop(lonlim=opts['lonlim'], latlim=opts['latlim'])

        # cropped SST and mask; a 100x100 zeros array when crop() returns a
        # positive status (presumably failure)
        if status > 0:
            mask = sst = np.zeros((100, 100))
        else:
            sst = self['SST']
            mask = self['mask']
        dateStr = self.get_time()[0].strftime('%Y-%m-%d')

        writeFull = not os.path.exists(sstcName)
        writeQuick = not os.path.exists(sstcqName)
        if not (writeFull or writeQuick):
            return 0

        # colour limits from the full-resolution image, shared by both maps
        f = Figure(sst)
        clim = f.clim_from_histogram(ratio=0.99,
                                     mask_array=mask,
                                     mask_lut=self._CROP_MASK_LUT)

        # cropped SST map
        if writeFull:
            ySize, xSize = sst.shape
            fontSize = int(xSize * 40. / 2000)
            f.process(cmin=clim[0], cmax=clim[1],
                      legend=True, fontSize=fontSize,
                      caption='SST, %s' % dateStr,
                      cmapName='ak01')
            f.save(sstcName)

        # quicklook: every 10th row and column
        if writeQuick:
            fq = Figure(sst[::10, ::10])
            fq.process(cmin=clim[0], cmax=clim[1],
                       mask_array=mask[::10, ::10],
                       mask_lut=self._CROP_MASK_LUT,
                       cmapName='ak01')
            fq.save(sstcqName)

        return 0

    def add_mask(self):
        '''Add band 'mask' and return the mask array

        Mask codes (int8): 64 - valid; 1 - qual_sst > 1 or column resolution
        > 2500 (units of get_resolution, presumably metres); 2 - land
        (bit 1 of 'flags'). Land overrides the other codes.
        '''
        # low quality SST
        qual_sst = self['qual_sst']
        mask = np.full(qual_sst.shape, self._MASK_VALID, dtype=np.int8)
        mask[qual_sst > 1] = 1

        # low resolution columns (get_resolution gives one value per column)
        reso = self.get_resolution()
        mask[:, reso > 2500] = 1

        # land: bit 1 of the flags band
        flags = self['flags']
        mask[np.bitwise_and(flags, 1 << 1) != 0] = 2

        # NOTE: masking of high solar zenith (self.solz > 80) as cloud is
        # currently disabled

        self.add_band(array=mask, parameters={'name': 'mask'})
        return mask

    def get_resolution(self):
        '''Return pixel size for each image column (1-D array, len = cols)

        Images wider than 1300 columns use an empirical polynomial of the
        column index. Otherwise the size is computed from the geolocation
        grids, projected to a stereographic projection centred on the mean
        lon/lat, as the distance between horizontally adjacent pixels
        averaged over rows.
        '''
        p = [-9.28990979e-19,
              4.81260192e-14,
              -1.85894689e-10,
              2.98320505e-07,
              -2.47788519e-04,
              1.14659973e-01,
              -2.99890467e+01,
              4.82779436e+03]
        nCols = self.shape()[1]
        if nCols > 1300:
            return np.polyval(p, range(nCols))

        lon, lat = self.get_geolocation_grids()
        # discard out-of-range coordinates
        invalid = (lon < -180) | (lon > 180) | (lat < -90) | (lat > 90)
        lon[invalid] = np.nan
        lat[invalid] = np.nan
        rows, cols = lon.shape

        clon = np.nanmean(lon)
        clat = np.nanmean(lat)
        srsString = '+proj=stere +datum=WGS84 +ellps=WGS84 +lon_0=%f +lat_0=%f +no_defs' % (clon, clat)

        dstSRS = osr.SpatialReference()
        dstSRS.ImportFromProj4(srsString)
        transformer = osr.CoordinateTransformation(NSR(), dstSRS)

        # (N, 2) array of lon/lat points -> projected (x, y, z)
        lonlat = np.column_stack((lon.ravel(), lat.ravel()))
        xyz = np.array(transformer.TransformPoints(lonlat))
        x = xyz[:, 0].reshape(rows, cols)
        y = xyz[:, 1].reshape(rows, cols)

        # distance between horizontally adjacent pixels, averaged over rows
        dist = np.hypot(np.diff(x, axis=1), np.diff(y, axis=1))
        colDist = np.nanmean(dist, axis=0)

        # np.diff gives cols-1 values: repeat the last one so the result has
        # one value per column, like the polynomial branch
        if colDist.size == 0:
            return np.full(cols, np.nan)
        return np.append(colDist, colDist[-1])

    def add_weight(self):
        '''Add band 'weight' = 1 / exp(resolution / 2000), same for all rows'''
        resolution = self.get_resolution()
        weight = 1. / np.exp(resolution / 2000.)

        rows = self.shape()[0]
        weightArray = np.repeat(weight[np.newaxis, :], rows, axis=0)
        self.add_band(array=weightArray, parameters={'name': 'weight'})

    def get_solz(self):
        '''Get solar zenith angle from aux NPY file or from the HDF file

        Reads '<fileName>_solz.npy' if it exists. Otherwise it reads the
        'csol_z' SDS with pyhdf and caches it to that NPY file (failure to
        save is only logged). Returns None if pyhdf is not installed.
        '''
        npyFileName = self.fileName + '_solz.npy'

        # read solz from aux NPY file
        if os.path.exists(npyFileName):
            return np.load(npyFileName)

        # pyhdf is optional
        try:
            from pyhdf import SD
        except ImportError:
            return None

        # read solz from the HDF file, always releasing the HDF handles
        sd = SD.SD(self.fileName)
        try:
            sds = sd.select('csol_z')
            try:
                solz = sds[:]
            finally:
                sds.endaccess()
        finally:
            sd.end()

        # cache solz to aux NPY file (best effort)
        try:
            np.save(npyFileName, solz)
        except (IOError, OSError):
            self.logger.error('Cannot save %s' % npyFileName)

        return solz
