import os.path
import os

from nansat import Nansat, Domain
from openwind import SARWind

import numpy as np
from datetime import datetime
import dateutil.parser

# Candidate directories with NCEP GFS model files, in order of preference.
NCEP_wind_dir = [
    '/Data/sat/auxdata/model/ncep/gfs/',
]

class AsarImage(Nansat):
    '''
    Envisat ASAR image opened with the nansat 'asar' mapper.

    On construction, catalogue attributes are stored as metadata. The class
    provides near-real-time processing steps:
        process_web          - write a map of the image
        process_wind         - compute wind with openwind and save a wind map
        process_wind_thredds - compute wind, reproject it and export it with
                               export2thredds
    The process_* methods return 0 on success and 1 on failure.
    '''

    # Format used to parse 'sar_start_date' / 'sar_stop_date', which
    # get_wind_only copies from SPH_FIRST_LINE_TIME / SPH_LAST_LINE_TIME.
    _SAR_DATE_FORMAT = '%d-%b-%Y %H:%M:%S.%f'

    def __init__(self, fileName, logLevel=10):
        '''
        Open an ASAR file and store catalogue attributes as metadata.

        Parameters
        ----------
        fileName : str or path-like
            ASAR file; converted with str() and opened by Nansat with
            mapperName='asar'.
        logLevel : int
            Passed through to Nansat (10 is logging.DEBUG).

        The metadata 'name', 'path', 'sensor' and 'sensstart' are stored
        wrapped in single quotes (presumably so the images catalogue can
        insert them verbatim - confirm with the consumer); 'border' is
        stored as returned by get_border_postgis().
        '''
        # init nansat object (open GDAL dataset, get mapper, etc.)
        Nansat.__init__(self, str(fileName), mapperName='asar',
                        logLevel=logLevel)

        # set values to the attribute data where it is available for the
        # ImagesCatalog
        self.set_metadata('name', "'%s'" % self.name)
        self.set_metadata('path', "'%s'" % self.path)
        self.set_metadata('sensor', "'asar'")
        self.set_metadata('sensstart', "'%s'" % self.get_time()[0].strftime(
            '%Y-%m-%d %H:%M:%S.%f'))
        self.set_metadata('border', self.get_border_postgis())

    def _get_unquoted_metadata(self, key):
        '''Return metadata item `key` with surrounding quotes stripped.'''
        return self.get_metadata(key).strip('"').strip("'")

    @classmethod
    def _get_sar_datetime(cls, data, key):
        '''Parse the SAR date-time string stored in metadata `key` of `data`.'''
        return datetime.strptime(data.get_metadata(key), cls._SAR_DATE_FORMAT)

    def process_web(self, opts=None):
        '''
        Write a map of the image unless the map file already exists.

        Input:
        ------
        opts: dictionary with processing options; opts['mapDir'] is the
            output prefix. It is concatenated directly with the file name,
            so it presumably must end with a path separator.

        Returns 0.
        '''
        oBaseFileName = self._get_unquoted_metadata('name')
        mapName = opts['mapDir'] + oBaseFileName + '_.png'

        # NOTE: small quicklook generation is currently disabled:
        # qlName = opts['mapDir'] + '/' + oBaseFileName + '_.jpg'
        # if not os.path.exists(qlName):
        #     self.resize(width=300)
        #     try:
        #         self.write_figure(qlName, clim='hist', ratio=0.9,
        #                           cmapName='gray')
        #     except:
        #         self.logger.error('Cannot write preview of %s'
        #                           % self.fileName)
        #     self.resize()

        if not os.path.exists(mapName):
            self.write_map(mapName)

        return 0

    def process_wind(self, opts=None):
        '''
        Compute wind speed and direction using openwind and save a wind map.

        Nothing is done if the map file already exists.

        Input:
        ------
        opts: dictionary with processing options; opts['mapDir'] is the
            output prefix (concatenated directly with the file name).

        Returns 0 if the map exists or was written; 1 if wind could not be
        computed or has no 'winddirection'/'windspeed' bands.
        '''
        oBaseFileName = self._get_unquoted_metadata('name')
        mapName = opts['mapDir'] + oBaseFileName + '_.png'

        if os.path.exists(mapName):
            return 0

        sar2wind = self.get_wind()
        if sar2wind is None:
            return 1
        # the map can only be created if direction and speed were calculated
        if not (sar2wind.has_band('winddirection') and
                sar2wind.has_band('windspeed')):
            return 1

        title = ('Envisat ASAR Wind Speed ' +
                 sar2wind.get_metadata()['MPH_SENSING_START'].split('.')[0])
        sar2wind.save_wind_map_image(mapName,
                                     landmask=True,
                                     drawgrid=False,
                                     cbar_fontsize=8, vmin=0, vmax=30,
                                     tight=True, title=title,
                                     edgecolor=0.0390625,
                                     quiver_width=0.002,
                                     quiver_X=0.8,
                                     quiver_Y=0.1,
                                     quiver_U=10,
                                     quiver_label='10 m/sec',
                                     quiver_fontproperties={'size': 8},
                                     quiverScaleCriteria={'%s<=10.0': 200,
                                                          '10.0<%s<=20.0': 300,
                                                          '20.0<%s': 400}
                                     )

        # NOTE: GeoTIFF export of wind speed and "to" direction is currently
        # disabled:
        # resultName = opts['mapDir'] + oBaseFileName + '_.tif'
        # n = Nansat(domain=sar2wind)
        # n.add_band(sar2wind['windspeed'],
        #            parameters=sar2wind.get_metadata(bandID='windspeed'))
        # # change wind direction "from" to "to"
        # winddirection = np.mod(sar2wind['winddirection'] + 180, 360)
        # winddirParms = sar2wind.get_metadata(bandID='winddirection')
        # winddirParms['long_name'] = 'Wind direction (to)'
        # winddirParms['standard_name'] = 'wind_to_direction'
        # winddirParms['wkv'] = 'wind_to_direction'
        # n.add_band(winddirection, parameters=winddirParms)
        # n.export(resultName, driver='GTiff')

        return 0

    def process_wind_thredds(self, opts=None):
        '''
        Compute wind, reproject it onto a domain and export it for THREDDS.

        Input:
        ------
        opts: dictionary with processing options:
            'threddsDir'         - output file name pattern with two %s
                                   placeholders, filled with the start and
                                   stop times formatted as YYYYmmddTHHMMSS
            'domain'             - dict with 'srs' and 'ext' of the target
                                   nansat Domain
            'metadata'           - global metadata for export2thredds
                                   (updated in place when additional
                                   metadata is requested; {} if absent)
            'additionalMetadata' - names of computed metadata, see
                                   set_additionalMetadata
            'timeName'           - band metadata item read by get_sar_time
            'maskName', 'rmMetadata', 'bands' - passed to export2thredds

        Returns 0 on success or if the output file already exists; 1 if wind
        could not be computed or the export failed.
        '''
        sar2wind = self.get_wind()
        if sar2wind is None:
            return 1

        # keep only the U, V and mask bands (also sets start/stop metadata)
        sar2wind.vrt.tps = False
        self.get_wind_only(sar2wind)

        # create the output file name from the start and stop times
        dateStrings = []
        for iKey in ('sar_start_date', 'sar_stop_date'):
            dt = self._get_sar_datetime(sar2wind, iKey)
            dateStrings.append('%4d%02d%02dT%02d%02d%02d'
                               % (dt.year, dt.month, dt.day,
                                  dt.hour, dt.minute, dt.second))
        oFileName = opts['threddsDir'] % tuple(dateStrings)

        # if the file already exists there is nothing to do
        if os.path.exists(oFileName):
            return 0

        # reproject sar2wind onto the given domain
        domain = Domain(srs=opts['domain']['srs'],
                        ext=opts['domain']['ext'])
        sar2wind.reproject(domain, use_geolocationArray=False)

        # metadata must be bound even when no additional metadata is
        # requested, otherwise the export below cannot be called
        metadata = opts.get('metadata', {})
        createdTime = None
        if opts['additionalMetadata']:
            metadata, createdTime = self.set_additionalMetadata(
                sar2wind, metadata, opts['additionalMetadata'])

        try:
            # drop bands without time (same as filter(None, ...))
            times = [t for t in self.get_sar_time(sar2wind,
                                                  timeName=opts['timeName'])
                     if t]
            sar2wind.export2thredds(oFileName,
                                    maskName=opts['maskName'],
                                    rmMetadata=opts['rmMetadata'],
                                    bands=opts['bands'],
                                    metadata=metadata,
                                    time=times,
                                    createdTime=createdTime)
        except Exception:
            self.logger.error('Cannot export wind to %s' % oFileName)
            return 1

        return 0

    def get_wind(self, lonlim=(10, 76), latlim=(63, 85)):
        '''
        Compute wind from this image with openwind.SARWind.

        The image and the matching NCEP GFS file (see get_ncep) are both
        cropped to `lonlim` / `latlim`, and the NCEP field is used as wind
        direction. On success a 'mask' band is added (add_mask), and the GCPs
        are reprojected to a stereographic projection centred on their mean
        position.

        Parameters
        ----------
        lonlim, latlim : sequence of two numbers
            Crop limits in degrees (default: lon 10..76, lat 63..85).

        Returns
        -------
        SARWind or None
            None if no NCEP file is found or it cannot be opened, if SARWind
            fails, or if the result has no 'U' and 'V' bands.

        Raises
        ------
        Exception
            Propagated from get_ncep if no NCEP directory exists.
        '''
        sar = Nansat(self.fileName)
        # Set vrt.tps for accuracy of reprojection
        sar.vrt.tps = True
        sar.crop(lonlim=list(lonlim), latlim=list(latlim))

        ncepFileName = self.get_ncep(self.fileName)
        if ncepFileName is None:
            return None
        try:
            ncep = Nansat(ncepFileName)
        except Exception:
            return None
        ncep.crop(lonlim=list(lonlim), latlim=list(latlim))
        # Set vrt.tps for accuracy of reprojection
        ncep.vrt.tps = True

        try:
            sar2wind = SARWind(sar, wind_direction=ncep)
        except Exception:
            return None

        if not (sar2wind.has_band('U') and sar2wind.has_band('V')):
            return None

        self.add_mask(sar2wind)

        # reproject GCPs onto a stereographic projection centred on the
        # mean GCP position (skipped when there are no GCPs, which would
        # otherwise divide by zero)
        gcps = sar2wind.vrt.dataset.GetGCPs()
        if gcps:
            nGcps = float(len(gcps))
            clon = sum(gcp.GCPX for gcp in gcps) / nGcps
            clat = sum(gcp.GCPY for gcp in gcps) / nGcps
            srs = ('+proj=stere +datum=WGS84 +ellps=WGS84 '
                   '+lon_0=%f +lat_0=%f +no_defs' % (clon, clat))
            sar2wind.vrt.reproject_GCPs(srs)

        return sar2wind

    def get_ncep(self, asarFileName):
        '''
        Find the NCEP GFS file matching the acquisition time of an ASAR file.

        The time is read from the VV sigma0 band (NOTE(review): images
        without a VV band will fail here). A GFS run hour B in
        {0, 6, 12, 18} is chosen so that B - 1.5 <= hour < B + 4.5 (capped at
        18), and forecast hour 3 is used when the image is more than 1.5 h
        after the run, otherwise 0. The file is looked up as
        <dir>/gfsYYYYMMDD/gfs.tBBz.master.grbfFF in the first existing
        directory of the module-level NCEP_wind_dir.

        Returns
        -------
        str or None
            Path of the NCEP file, or None if it does not exist.

        Raises
        ------
        Exception
            If none of the directories in NCEP_wind_dir exists.
        '''
        n = Nansat(asarFileName)
        sigma0_bandNo = n._get_band_number({
            'standard_name':
                'surface_backwards_scattering_coefficient_of_radar_wave',
            'polarization': 'VV'})
        SAR_image_time = n.get_time(sigma0_bandNo).replace(tzinfo=None)

        # use the first NCEP directory that is connected
        for ncepDir in NCEP_wind_dir:
            if os.path.exists(ncepDir):
                break
        else:
            raise Exception('Make sure one of these directories are '
                            'connected:\n\n' +
                            ''.join(d + '\n' for d in NCEP_wind_dir))

        hourOfDay = SAR_image_time.hour + SAR_image_time.minute / 60.0
        basehour = int(min(18, np.floor((hourOfDay + 1.5) / 6.0) * 6.0))
        forecasthour = 3 if hourOfDay - basehour > 1.5 else 0

        nfile = os.path.join(ncepDir,
                             'gfs%4d%02d%02d/' % (SAR_image_time.year,
                                                  SAR_image_time.month,
                                                  SAR_image_time.day),
                             'gfs.t%.2dz.master.grbf%.2d' % (basehour,
                                                             forecasthour))

        if os.path.exists(nfile):
            print('Found ' + nfile + ' file in NERSC')
            return nfile
        print('Could not find ' + nfile + ' file in NERSC')
        return None

    def add_mask(self, data):
        '''
        Add an int8 band named 'mask' to `data`.

        Values: 64 by default, 2 where band 1 of data.watermask(tps=True)
        equals 2 (land), 0 where band 1 of `data` equals 0.0 (no data). The
        no-data value is applied last, so it overrides land.
        '''
        mask = 64 * np.ones(data.shape()).astype('int8')
        # land mask
        watermask = data.watermask(tps=True)
        mask[watermask[1] == 2] = 2
        # no data mask
        mask[data[1] == 0.0] = 0
        data.add_band(array=mask, parameters={'wkv': 'land_mask',
                                              'name': 'mask',
                                              'long_name': 'L2-mask',
                                              'standard_name': 'mask'})

    def get_wind_only(self, data):
        '''
        Delete all bands of `data` except U, V and mask; set time metadata.

        'sar_start_date' / 'sar_stop_date' (copied from SPH_FIRST_LINE_TIME /
        SPH_LAST_LINE_TIME) and FillValue '-10000' are set on the global
        metadata and on the U and V bands.
        '''
        delBandList = list(range(1, data.vrt.dataset.RasterCount + 1))
        delBandList.remove(data._get_band_number('U'))
        delBandList.remove(data._get_band_number('V'))
        delBandList.remove(data._get_band_number('mask'))
        data.vrt.delete_bands(delBandList)

        # set start and stop date-time to global and band metadata
        dtMetadata = {'sar_start_date': data.get_metadata('SPH_FIRST_LINE_TIME'),
                      'sar_stop_date': data.get_metadata('SPH_LAST_LINE_TIME'),
                      'FillValue': '-10000'
                      }
        data.set_metadata(dtMetadata)
        data.set_metadata(dtMetadata, bandID='U')
        data.set_metadata(dtMetadata, bandID='V')

    def get_sar_time(self, data, bandID=None, timeName='time'):
        '''
        Parse the per-band time metadata of `data` with dateutil.

        Parameters
        ----------
        data : Nansat
        bandID : optional
            Anything accepted by data._get_band_number; if given, only the
            time of that band is returned.
        timeName : str
            Band metadata item holding the time string.

        Returns
        -------
        list or datetime or None
            One entry per band (None where the item is missing or cannot be
            parsed), or the single entry for `bandID`.
        '''
        times = []
        for iBand in range(1, data.vrt.dataset.RasterCount + 1):
            band = data.get_GDALRasterBand(iBand)
            try:
                times.append(dateutil.parser.parse(
                    band.GetMetadataItem(timeName)))
            except Exception:
                data.logger.debug('Band ' + str(iBand) + ' has no time')
                times.append(None)

        if bandID is None:
            return times
        return times[data._get_band_number(bandID) - 1]

    def set_additionalMetadata(self, data, metadata, additionalMetadata):
        '''
        Fill computed values into `metadata` (the dict is modified in place).

        Recognised names in `additionalMetadata`:
            'sar_start_date', 'time_coverage_start' - start time of `data`
            'sar_stop_date', 'time_coverage_end'    - stop time of `data`
                (both formatted as 'YYYY-MM-DD HH:MM:SS UTC')
            'geospatial_lat_min/max', 'geospatial_lon_min/max' - from the
                corners of `data`
            'date_created', 'date_modified' - current UTC time
        Other names are ignored.

        Returns
        -------
        (dict, str)
            `metadata` and the current UTC time as 'YYYY-MM-DD HH:MM:SS UTC'.
        '''
        # get corners of data
        lonCrn, latCrn = data.get_corners()
        # get current time (creation time)
        createdTime = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')

        for iMeta in additionalMetadata:
            if iMeta in ('sar_start_date', 'time_coverage_start',
                         'sar_stop_date', 'time_coverage_end'):
                # start keys contain 'start'; the others are stop keys
                if 'start' in iMeta:
                    keyName = 'sar_start_date'
                else:
                    keyName = 'sar_stop_date'
                dt = self._get_sar_datetime(data, keyName)
                metadata[iMeta] = ('%4d-%02d-%02d %02d:%02d:%02d UTC'
                                   % (dt.year, dt.month, dt.day,
                                      dt.hour, dt.minute, dt.second))
            elif iMeta == 'geospatial_lat_min':
                metadata[iMeta] = float(min(latCrn))
            elif iMeta == 'geospatial_lat_max':
                metadata[iMeta] = float(max(latCrn))
            elif iMeta == 'geospatial_lon_min':
                metadata[iMeta] = float(min(lonCrn))
            elif iMeta == 'geospatial_lon_max':
                metadata[iMeta] = float(max(lonCrn))
            elif iMeta in ('date_created', 'date_modified'):
                metadata[iMeta] = createdTime

        return metadata, createdTime


