import os, os.path
import time
from datetime import datetime, timedelta
import numpy as np
from xml.etree.ElementTree import XML, ElementTree, tostring

import matplotlib.pyplot as plt

import gdal
#from pyhdf import SD

from nansat import Domain, Nansat, Figure

#from boreali import Boreali


class ModisL2Image(Nansat):
    '''
    Read a MODIS (OBPG) Level-2 file and prepare it for the image catalog

    On construction the file is checked (day image, both dimensions larger
    than MIN_SIZE pixels), opened with the given Nansat mapper, catalog
    metadata is set, a cloud/land/water band 'mask' is added and the GCPs
    are reprojected to a stereographic projection centred on the GCPs.
    '''

    # minimum number of pixels along each axis for an image to be accepted
    MIN_SIZE = 200

    # L2 flag bits (1-based) treated as cloud for NRT processing.
    # Alternative sets used for other regions/studies:
    #   amazon plume: [4, 5, 6, 10, 15, 20, 23, 30]
    #   yellow sea:   [1, 4, 5, 6, 9, 10, 15, 20, 29]
    #   china:        [1, 5, 6, 9, 10, 13, 15, 20, 23, 28, 29, 30]
    NRT_CLOUD_BITS = (1, 6, 10, 16, 29)

    def __init__(self, fileName, logLevel=30, mapperName='obpg_l2'):
        '''
        Open a MODIS L2 file and prepare it for the catalog

        Parameters
        ----------
        fileName : str
            path to the L2 file, e.g.
            A2012104034000.L2_LAC.NorthNorwegianSeas.hdf
        logLevel : int
            logging level passed to Nansat
        mapperName : str
            name of the Nansat mapper used to open the file

        Raises
        ------
        Exception
            'Night image!' if the first subdataset is not a day image,
            'Too little image!' if either dimension is <= MIN_SIZE,
            'No GCPs ...' if the VRT has no GCPs to reproject
        '''
        # check day/night flag and size on the first subdataset before the
        # full Nansat initialisation
        ds = gdal.Open(fileName)
        subds = gdal.Open(ds.GetSubDatasets()[0][0])
        isday = subds.GetMetadata()['Day or Night'] == 'Day'
        isbig = (subds.RasterXSize > self.MIN_SIZE and
                 subds.RasterYSize > self.MIN_SIZE)
        # release the GDAL handles; Nansat opens the file itself below
        subds = None
        ds = None

        if not isday:
            print('Night image')
            # short pause kept from the original implementation
            # (presumably so the message is seen before the traceback)
            time.sleep(0.5)
            raise Exception('Night image!')

        if not isbig:
            print('Too little image!')
            time.sleep(0.5)
            raise Exception('Too little image!')

        Nansat.__init__(self, fileName, logLevel=logLevel,
                        mapperName=mapperName)

        # metadata used by ImagesCatalog; string values are stored already
        # wrapped in single quotes (presumably for direct SQL use - TODO confirm)
        sensstart = self.get_time()[0].strftime('%Y-%m-%d %H:%M:%S.%f')
        self.set_metadata('name', "'%s'" % self.name)
        self.set_metadata('path', "'%s'" % self.path)
        self.set_metadata('sensor', "'modis'")
        self.set_metadata('sensstart', "'%s'" % sensstart)
        self.set_metadata('border', self.get_border_postgis())
        self.set_metadata('daily', '0')
        self.set_metadata('weekly', '0')
        self.set_metadata('monthly', '0')

        self.add_mask(self.NRT_CLOUD_BITS)

        self.vrt.remove_geolocationArray()

        # use TPS for reprojection
        self.vrt.tps = True

        self._reproject_gcps_to_stere()

    def _reproject_gcps_to_stere(self):
        '''Reproject the GCPs to a stereographic projection centred on the
        mean longitude/latitude of the GCPs

        Raises
        ------
        Exception
            if the VRT has no GCPs (the centre would be undefined)
        '''
        gcps = self.vrt.dataset.GetGCPs()
        if not gcps:
            raise Exception('No GCPs to reproject in %s' % self.name)

        nGcps = float(len(gcps))
        clon = sum(gcp.GCPX for gcp in gcps) / nGcps
        clat = sum(gcp.GCPY for gcp in gcps) / nGcps
        srs = ('+proj=stere +datum=WGS84 +ellps=WGS84 '
               '+lon_0=%f +lat_0=%f +no_defs' % (clon, clat))
        print(srs)
        self.vrt.reproject_GCPs(srs)

    def _catalog_name(self):
        '''Return the 'name' metadata with the surrounding quotes removed'''
        return self.get_metadata('name').strip('"').strip("'")

    def add_mask(self, cloudBits, landBits=None, tmpProdName='latitude',
                 invalidTmp=-999):
        '''Compute the l2c mask (see l2c_mask) and add it as band 'mask'

        Parameters
        ----------
        cloudBits : sequence of int
            1-based L2 flag bits treated as cloud
        landBits : sequence of int or None
            1-based L2 flag bits treated as land; None means [2] (LAND)
        tmpProdName : str
            band used to detect out-of-swath / invalid pixels
        invalidTmp : number
            value of tmpProdName that marks invalid pixels
        '''
        mask = self.l2c_mask(cloudBits, landBits, tmpProdName, invalidTmp)
        self.add_band(array=mask,
                      parameters={'wkv': 'quality_flags', 'name': 'mask'})

    def _read_l2_flags(self):
        '''Read the 'flags' band with WorkingDataType temporarily set to Int32

        If the VRT is warped, its WorkingDataType is switched to Int32 while
        the flags are read and switched back to Float32 afterwards, even if
        reading fails. Failures to modify the VRT (e.g. it is not warped)
        are logged and otherwise ignored.
        '''
        try:
            self._modify_warpedVRT2('GDALWarpOptions/WorkingDataType', 'Int32')
        except Exception:
            self.logger.error('Unable to modify WorkingDataType in VRT!')

        try:
            return self['flags']
        finally:
            try:
                self._modify_warpedVRT2('GDALWarpOptions/WorkingDataType',
                                        'Float32')
            except Exception:
                self.logger.error('Unable to modify WorkingDataType in VRT!')

    @staticmethod
    def _any_bit_set(l2_flags, bits):
        '''Return a boolean array: True where any of the 1-based *bits* is
        set in the integer array *l2_flags*'''
        anySet = np.zeros(l2_flags.shape, bool)
        for bit in bits:
            bitValue = np.power(np.uint32(2), bit - 1)
            anySet |= np.bitwise_and(l2_flags, bitValue) > 0
        return anySet

    def l2c_mask(self, cloudBits, landBits=None, tmpProdName='latitude',
                 invalidTmp=-999):
        '''Create the l2c mask from the L2 flags

        Mask values (uint8):
            0  - out of swath or invalid (tmpProdName == 0 or == invalidTmp)
            1  - cloud (any of cloudBits set)
            2  - land  (any of landBits set; overrides cloud)
            64 - water (default)

        Parameters
        ----------
        cloudBits : sequence of int
            1-based L2 flag bits treated as cloud
        landBits : sequence of int or None
            1-based L2 flag bits treated as land; None means [2] (LAND)
        tmpProdName : str
            band used to detect out-of-swath / invalid pixels
        invalidTmp : number
            value of tmpProdName that marks invalid pixels

        L2 BITS (trailing '#' / '.' marks are legacy annotations):
            f01_name=ATMFAIL            #
            f02_name=LAND
            f03_name=PRODWARN
            f04_name=HIGLINT            #
            f05_name=HILT               #
            f06_name=HISATZEN           #
            f07_name=COASTZ
            f08_name=SPARE
            f09_name=STRAYLIGHT         .#
            f10_name=CLDICE             #
            f11_name=COCCOLITH          .
            f12_name=TURBIDW            .
            f13_name=HISOLZEN           #
            f14_name=SPARE
            f15_name=LOWLW              #
            f16_name=CHLFAIL
            f17_name=NAVWARN
            f18_name=ABSAER
            f19_name=SPARE
            f20_name=MAXAERITER         #
            f21_name=MODGLINT           .#
            f22_name=CHLWARN
            f23_name=ATMWARN            #
            f24_name=SPARE
            f25_name=SEAICE
            f26_name=NAVFAIL
            f27_name=FILTER
            f28_name=SSTWARN            .
            f29_name=SSTFAIL            #
            f30_name=HIPOL              #
            f31_name=PRODFAIL           .
            f32_name=SPARE

        Returns
        -------
        numpy.ndarray (uint8), same shape as the 'flags' band
        '''
        if landBits is None:
            landBits = [2]

        l2_flags = self._read_l2_flags()
        tmpVar = self[tmpProdName]

        maskCloud = self._any_bit_set(l2_flags, cloudBits)
        maskLand = self._any_bit_set(l2_flags, landBits)

        # water everywhere by default; cloud overrides water, land overrides
        # cloud and invalid pixels override everything
        mask = 64 * np.ones(l2_flags.shape, np.uint8)
        mask[maskCloud] = 1
        mask[maskLand] = 2
        # erase data out of swath
        mask[tmpVar == 0] = 0
        # erase pixels where the test product (latitude by default) is invalid
        mask[tmpVar == invalidTmp] = 0

        return mask

    def l2c_flags(self, cloudBits):
        '''Create an array holding, per pixel, the number of a set L2 flag bit

        Bits are applied in the order of *cloudBits*; a later bit overwrites
        an earlier one, so for an ascending list the highest set bit wins.
        Pixels with none of the bits set are 0. See l2c_mask for the list
        of L2 bits.

        Parameters
        ----------
        cloudBits : sequence of int
            1-based L2 flag bits to check

        Returns
        -------
        numpy.ndarray (uint8), same shape as the 'flags' band
        '''
        l2_flags = self._read_l2_flags()

        flagsOut = np.zeros(l2_flags.shape, np.uint8)
        for bit in cloudBits:
            isSet = np.bitwise_and(l2_flags, np.power(np.uint32(2), bit - 1)) > 0
            flagsOut[isSet] = bit

        return flagsOut

    def _modify_warpedVRT2(self, key, value):
        '''Set the text of one element of the VRT XML and write it back

        Parameters
        ----------
        key : str
            ElementTree path of the element relative to the VRT root,
            e.g. 'GDALWarpOptions/WorkingDataType'
        value : str
            new text of the element, e.g. 'Int32'

        Raises
        ------
        ValueError
            if no element matches *key* (e.g. the VRT is not warped)

        Modifies
        --------
            the XML of self.vrt
        '''
        root = XML(self.vrt.read_xml())
        elem = root.find(key)
        if elem is None:
            raise ValueError('Element %s not found in VRT' % key)
        elem.text = value
        self.vrt.write_xml(tostring(root))

    def process_std(self, opts):
        '''Standard L2-processing: only preview generation

        Writes RGB, SST and chlor_a quicklooks (JPG) into opts['oDir'],
        skipping files that already exist. Missing SST / chlor_a bands are
        logged, not raised.

        Returns
        -------
        0
        '''
        oBaseFileName = self._catalog_name()
        rgbName = opts['oDir'] + oBaseFileName + '_rgb.jpg'
        sstName = opts['oDir'] + oBaseFileName + '_sst.jpg'
        chlName = opts['oDir'] + oBaseFileName + '_chl.jpg'

        self.resize(width=300, eResampleAlg=0)

        mask = self['mask']
        # 0 - no data (dark grey), 1 - cloud (white), 2 - land (grey)
        maskLut = {0: [50, 50, 50], 1: [255, 255, 255], 2: [128, 128, 128]}

        # generate RGB quicklook
        if not os.path.exists(rgbName):
            self.write_figure(rgbName, ['Rrs_678', 'Rrs_555', 'Rrs_443'],
                              clim=[[0, 0, 0], [0.02, 0.025, 0.016]],
                              mask_array=mask,
                              mask_lut={1: [255, 255, 255],
                                        2: [128, 128, 128]})

        # generate SST quicklook
        if not os.path.exists(sstName):
            try:
                self.write_figure(sstName, 'SST',
                                  clim=[-5, 20],
                                  mask_array=mask,
                                  mask_lut=maskLut)
            except Exception:
                self.logger.error('No SST in %s' % self.name)

        # generate chlorophyll quicklook
        if not os.path.exists(chlName):
            try:
                self.write_figure(chlName, 'chlor_a',
                                  clim=[0, 5],
                                  logarithm=True,
                                  mask_array=mask,
                                  mask_lut=maskLut)
            except Exception:
                self.logger.error('No chlor_a in %s' % self.name)

        return 0

    def process_boreali(self, opts):
        '''Advanced processing of MODIS images:
        retrieve chl, tsm, doc with boreali and generate images

        Parameters
        ----------
        opts : dict
            'srs', 'ext' - Domain definition to reproject onto;
            'oDir' - output directory prefix;
            'prods' - product names (keys of pnDefaults);
            'start' - start values passed to Boreali.process_lm

        Returns
        -------
        0 on success or if the NC-file already exists,
        1 if reprojection fails or 'Rrsw_412' is unavailable
        '''
        # product name -> [min, max, logarithmic scale]
        pnDefaults = {
            'lmchl': [0, 5, False],
            'lmtsm': [0, 3, False],
            'lmdoc': [0, 2, False],
            'lmmse': [1e-8, 1e-5, True]}

        # min/max limits for chl, tsm, doc passed to Boreali
        borMinMax = [pnDefaults[pn][:2] for pn in ('lmchl', 'lmtsm', 'lmdoc')]

        dtsDomain = Domain(opts['srs'], opts['ext'])

        fileName = self.get_metadata('name')
        oBaseFileName = self._catalog_name()
        ncName = opts['oDir'] + oBaseFileName + '.nc'
        print(ncName)
        prodFileNames = dict(
            (pn, '%s/%s.%s.png' % (opts['oDir'], oBaseFileName, pn))
            for pn in opts['prods'])

        if os.path.exists(ncName):
            print('%s already exist!' % ncName)
            return 0

        try:
            self.reproject(dtsDomain)
        except Exception:
            print('Cannot reproject %s. Skipping' % fileName)
            return 1

        if self['Rrsw_412'] is None:
            return 1

        # process input with BOREALI
        # NOTE(review): the 'from boreali import Boreali' import is commented
        # out at the top of the file, so this raises NameError unless
        # Boreali is provided elsewhere.
        b = Boreali(model='northsea', zone='northsea')
        cImg = b.process_lm(self, wavelen=[412, 443, 488, 531, 555, 667],
                            start=opts['start'],
                            minmax=borMinMax)

        # generate Nansat with results
        img2 = Nansat(domain=self)
        for i, pn in enumerate(opts['prods']):
            img2.add_band(array=cImg[i, :, :], parameters={'name': pn})
        img2.add_band(array=self['mask'], parameters={'name': 'mask'})

        # export results into NC-file
        img2.export(ncName)

        # write images with concentrations
        for pn in opts['prods']:
            pnd = pnDefaults[pn]
            img2.write_figure(prodFileNames[pn], pn, clim=[pnd[0], pnd[1]],
                              legend=True, logarithm=pnd[2])

        return 0
