from scipy.io import loadmat
from os.path import exists, basename, splitext, dirname
import os
import datetime
from datetime import date
import glob

import calendar

import numpy as np
import matplotlib.pyplot as plt
import gdal, osr

from nansat import Nansat, Domain

class OceanColor:
    """Monthly ocean-colour product stored in a MATLAB (.mat) file.

    The constructor derives the acquisition month, the source grid
    (``self.domain``) and a small catalogue-metadata dictionary from the
    file name alone; the file itself is only read by ``process_thredds``.

    File-name conventions relied upon by the code:

    * names starting with ``npp`` end with ``_<month>_<year>.<ext>`` and
      carry the ocean-colour algorithm and the NPP algorithm as the third
      and fourth ``_``-separated tokens;
    * for all other names the third ``_``-separated token holds the year at
      characters 3-6 and the day of year at characters 7-9;
    * the name must contain either ``modis`` or ``seawifs``.
    """

    # Lookup tables keyed by the short product code passed as opts['wkv'].
    _LONG_NAMES = {
        'chl': 'Chlorophyll-a',
        'tsm': 'Total suspended matter',
        'doc': 'Dissolved organic carbon',
        'npp': 'Net Primary Productivity',
    }
    _STANDARD_NAMES = {
        'chl': 'mass_concentration_of_chlorophyll_a_in_sea_water',
        'tsm': 'mass_concentration_of_suspended_matter_in_sea_water',
        'doc': 'mass_concentration_of_dissolved_organic_carbon_in_sea_water',
        'npp': 'net_primary_productivity_of_carbon',
    }
    _UNITS = {
        'doc': 'Kg C m-3',
        'tsm': 'Kg C m-3',
        'chl': 'g m-3',
        'npp': 'mg C m-3 day-1',
    }
    _VALID_MAX = {'chl': '20', 'tsm': '10', 'doc': '10', 'npp': '100000'}
    _KEYWORDS = {
        'chl': 'Oceans > Ocean Chemistry > Marine Geochemistry, Oceans > Ocean Chemistry > Chlorophyll, Oceans > Ocean Chemistry > Pigments > Chlorophyll',
        'tsm': 'Oceans > Ocean Chemistry > Marine Geochemistry, Oceans > Ocean Chemistry > Inorganic Matter',
        'doc': 'Oceans > Ocean Chemistry > Marine Geochemistry, Oceans > Ocean Chemistry > Carbon, Oceans > Ocean Chemistry > Organic Carbon',
        'npp': 'Biosphere > Ecological Dynamics > Ecosystem Functions > Primary Production',
    }
    # NPP algorithm token in the file name -> suffix used in output names.
    _NPP_ALGORITHMS = {'Bhr2005': 'bhr05', 'Bhr97': 'bhr97', 'Marra': 'marra'}
    # Geographic bounds written to the geospatial_* metadata:
    # (lon_min, lat_min, lon_max, lat_max).
    _WGS84_BOUNDS = (-180., 60., 180., 90.)

    def __init__(self, fileName, logLevel=10):
        """Parse ``fileName`` and prepare the catalogue metadata.

        Parameters
        ----------
        fileName : str
            Path of the input file; only its name is inspected here.
        logLevel : int
            Accepted for interface compatibility; not used by this class.

        Raises
        ------
        Exception
            If the file name contains neither 'modis' nor 'seawifs'.
        """
        self.name = basename(fileName)
        self.path = dirname(fileName)

        self.fileName = fileName
        self.resolution = 4

        # Filled in later by set_attributes().
        self.varName = None
        self.wkt = None
        self.wkv = None

        self.oFileName = None
        self.iFileName = None
        self.longName = None
        self.description = None

        # set self.date
        self.set_date()

        self.metadata = {}
        # Set the values that can be derived from the file name. They are
        # stored wrapped in single quotes, presumably for direct use in SQL
        # statements of the image catalogue (verify against the caller).
        self.set_metadata('name', "'%s'" % self.name)
        self.set_metadata('path', "'%s'" % self.path)
        self.set_metadata('sensstart',
                          "'%s'" % self.date.strftime('%Y-%m-%d %H:%M:%S.%f'))

        # set metadata 'border' from the grid the source data is stored on
        if self.name.startswith('npp'):
            srs = '+proj=laea +lon_0=0 +lat_0=90 +datum=WGS84 +ellps=WGS84  +no_defs'
            ext = '-te -2331947 -2331947 2331947 2331947 -ts 933 933'
        else:
            srs = '+proj=latlong +datum=WGS84 +ellps=WGS84 +no_defs'
            ext = '-lle -180 60 180 90 -ts 4500 750'

        self.domain = Domain(srs=srs, ext=ext)
        self.set_metadata('border', self.domain.get_border_postgis())

        # set metadata 'sensor'
        if 'modis' in self.name:
            self.set_metadata('sensor', "'modis'")
        elif 'seawifs' in self.name:
            self.set_metadata('sensor', "'seawifs'")
        else:
            raise Exception('the file name is not proper!!')

    def set_metadata(self, key='', value=''):
        """Store ``value`` under ``key`` in the metadata dictionary."""
        self.metadata[key] = value

    def get_metadata(self, key=None):
        """Return the whole metadata dictionary, or one entry of it.

        Parameters
        ----------
        key : hashable or None
            When None the dictionary itself is returned (not a copy).

        Raises
        ------
        KeyError
            If ``key`` is given but not present.
        """
        if key is None:
            return self.metadata
        return self.metadata[key]

    def process_thredds(self, opts=None):
        """Reproject the product and export it as a THREDDS netCDF file.

        Parameters
        ----------
        opts : dict
            Must provide 'wkv', 'threddsDir', 'ofileName' (consumed by
            ``set_attributes``), 'dstDomain' (dict with 'srs' and 'ext'),
            'metadata', 'additionalMetadata', 'bandMetadata' (with
            'valid_range' and 'units' entries) and 'rmMetadata'.
            The dictionaries in ``opts`` are not modified.

        Returns
        -------
        int
            1 if the export raised an exception, otherwise 0 (also when the
            input file is missing or the output file already exists).
        """
        # set attributes
        self.set_attributes(opts)

        # get dstDomain
        dstDomain = Domain(srs=opts['dstDomain']['srs'],
                           ext=opts['dstDomain']['ext'])

        # Nothing to do without input, and an existing product is never
        # overwritten, so skip the expensive load/reprojection in that case.
        if not exists(self.iFileName) or exists(self.oFileName):
            return 0

        m = loadmat(self.iFileName)
        # All products except NPP are stored under an 'lm'-prefixed name.
        matKey = self.wkv if self.wkv == 'npp' else 'lm' + self.wkv
        # np.float was removed from NumPy; np.float64 is the same dtype.
        array = m[matKey].astype(np.float64)

        # create temporary nansat object for reprojection only
        tmpN = Nansat(domain=self.domain, array=array)
        tmpN.reproject(dstDomain)

        # get array and replace non-positive values with nan
        array = tmpN[1]
        array[array <= 0] = np.nan

        parameters = {'name': self.varName,
                      'description': self.description,
                      'long_name': self.longName,
                      'wkt': self.wkt}

        # create Nansat from that array
        n = Nansat(domain=dstDomain, array=array, parameters=parameters)

        # Work on copies: the same opts is typically reused for many files
        # and filling the defaults in place would leak the units and valid
        # range of the first product into all following ones.
        globalMetadata = dict(opts['metadata'])
        if opts['additionalMetadata']:
            globalMetadata = self.set_additionalMetadata(
                n, globalMetadata, opts['additionalMetadata'])

        bandMetadata = dict(opts['bandMetadata'])
        if bandMetadata['valid_range'] is None:
            bandMetadata['valid_range'] = [float(globalMetadata['valid_min']),
                                           float(globalMetadata['valid_max'])]
        if bandMetadata['units'] is None:
            bandMetadata['units'] = self._UNITS.get(self.wkv)

        try:
            n.export2thredds(self.oFileName,
                             rmMetadata=opts['rmMetadata'],
                             bands={self.varName: bandMetadata},
                             metadata=globalMetadata,
                             time=[self.date])
        except Exception:
            # Best effort: report the failure through the return code but
            # let KeyboardInterrupt / SystemExit propagate.
            return 1
        return 0

    def set_additionalMetadata(self, data, metadata, additionalMetadata):
        """Fill the requested global attributes into ``metadata``.

        Parameters
        ----------
        data : Nansat
            Only accessed (``data.vrt.dataset`` and ``data.shape()``) when a
            'geospatial_lon_resolution' or 'geospatial_lat_resolution' entry
            is requested.
        metadata : dict
            Updated in place and returned.
        additionalMetadata : iterable of str
            Names of the attributes to generate; unknown names are ignored.

        Returns
        -------
        dict
            The same ``metadata`` object.
        """
        # get current date
        today = datetime.datetime.now().strftime("%Y-%m-%d")

        year = int(self.date.year)
        month = int(self.date.month)
        lastDay = calendar.monthrange(year, month)[1]
        lonMin, latMin, lonMax, latMax = self._WGS84_BOUNDS

        # Every attribute that can be produced without touching the data.
        values = {
            'date_created': today,
            'date_modified': today,
            'date_issued': today,
            'history': 'NERSC (Korosov A.) ' + today,
            # zero-padded so that the time stamps are valid ISO 8601 dates
            'time_coverage_start': '%04d-%02d-01Z' % (year, month),
            'time_coverage_end': '%04d-%02d-%02dZ' % (year, month, lastDay),
            'time_coverage_resolution': 'P1M',
            'geospatial_lon_min': str(lonMin),
            'geospatial_lat_min': str(latMin),
            'geospatial_lon_max': str(lonMax),
            'geospatial_lat_max': str(latMax),
            'valid_min': str(0),
            'valid_max': self._VALID_MAX.get(self.wkv),
            'keywords': self._KEYWORDS.get(self.wkv),
        }
        # Attribute name -> index into the (dLon, dLat) resolution tuple.
        resolutionIndex = {'geospatial_lon_resolution': 0,
                           'geospatial_lat_resolution': 1}

        resolution = None
        for iMeta in additionalMetadata:
            if iMeta in resolutionIndex:
                # The GDAL transformation is only run when actually needed.
                if resolution is None:
                    resolution = self.get_resolutions(
                        data.vrt.dataset,
                        data.vrt.dataset.GetProjection(),
                        data.shape())
                metadata[iMeta] = str(resolution[resolutionIndex[iMeta]])
            elif iMeta in values:
                metadata[iMeta] = values[iMeta]

        return metadata

    def get_resolutions(self, dataset, srcWKT='', shape=None):
        """Estimate the pixel size of ``dataset`` in degrees.

        Three neighbouring pixels in the middle of the last row are
        transformed to geographic coordinates; the longitude step is taken
        between horizontal neighbours and the latitude step between vertical
        neighbours.

        Parameters
        ----------
        dataset : gdal.Dataset
        srcWKT : str
            Source projection; read from ``dataset`` when empty.
        shape : tuple or None
            (rows, columns); read from ``dataset`` when None.

        Returns
        -------
        tuple
            (dLon, dLat), both non-negative.
        """
        if srcWKT == '':
            srcWKT = dataset.GetProjection()
        if shape is None:
            shape = (dataset.RasterYSize, dataset.RasterXSize)

        latlongSRS = osr.SpatialReference()
        latlongSRS.ImportFromProj4(
            "+proj=latlong +ellps=WGS84 +datum=WGS84 +no_defs")
        transformer = gdal.Transformer(
            dataset, None,
            ['SRC_SRS=' + srcWKT, 'DST_SRS=' + latlongSRS.ExportToWkt()])

        midCol = int(shape[1] / 2)
        lastRow = int(shape[0] - 1)
        # reference pixel, its right-hand neighbour, and the pixel above it
        pixels = [(midCol, lastRow),
                  (midCol + 1, lastRow),
                  (midCol, lastRow - 1)]

        # TransformPoint returns (success flag, transformed point).
        ref, right, above = [transformer.TransformPoint(0, col, row)[1]
                             for col, row in pixels]
        dLon = abs(ref[0] - right[0])
        dLat = abs(ref[1] - above[1])
        return (dLon, dLat)

    def set_attributes(self, opts):
        """Derive variable naming and input/output paths from ``opts``.

        Sets ``wkv``, ``wkt``, ``varName``, ``description``, ``longName``,
        ``iFileName`` and ``oFileName`` and prints both file names.

        Parameters
        ----------
        opts : dict
            'wkv' is the product code, 'threddsDir' the output root and
            'ofileName' a %-format template receiving (resolution, varName,
            year, month, year, month, last day of month).
        """
        sensor = self.get_metadata('sensor').replace("'", "")

        self.wkv = opts['wkv']
        self.wkt = self._STANDARD_NAMES.get(self.wkv)

        # Any algorithm token other than the three known ones means BOREALI.
        ocAlg = self.name.split('_')[2]
        if ocAlg not in ('bhr', 'gsm', 'oc3'):
            ocAlg = 'boreali'

        oDir = opts['threddsDir'] + '/' + 'arctic4km_' + sensor + '_' + ocAlg
        self.varName = self.wkv + '_' + sensor + '_' + ocAlg
        self.description = ('Satellite sensor: %s, Ocean Color Algorithm: %s'
                            % (sensor, ocAlg))
        self.longName = (self._LONG_NAMES.get(self.wkv)
                         + '(%s/%s)' % (sensor, ocAlg))

        if self.wkv == 'npp':
            # NPP products additionally encode the productivity algorithm.
            nppAlg = self.name.split('_')[3]
            nppSuffix = '_' + self._NPP_ALGORITHMS.get(nppAlg)
            oDir = oDir + nppSuffix
            self.varName = self.varName + nppSuffix
            self.description = (self.description
                                + ', NPP algorithm: %s' % nppAlg)
            self.longName = self.longName.replace(')', '/%s)' % nppAlg)

        lastDay = calendar.monthrange(self.date.year, self.date.month)[1]
        self.iFileName = self.path + '/' + self.name
        self.oFileName = (oDir + '_' + self.wkv + '/'
                          + opts['ofileName'] % (
                              self.resolution, self.varName,
                              self.date.year, self.date.month,
                              self.date.year, self.date.month,
                              lastDay))

        # Single-argument print calls behave identically on Python 2 and 3;
        # the double space reproduces the output of the former statements.
        print('iFileName:  %s' % self.iFileName)
        print('oFileName:  %s' % self.oFileName)

    def set_date(self):
        """Set ``self.date`` to the first day of the month of the product.

        For 'npp' files year and month are the last two ``_``-separated
        tokens of the name without extension. For other files the year and
        day of year are cut out of the third token; a day of year that does
        not fall inside that year maps to December.

        Raises
        ------
        ValueError, IndexError
            If the file name does not follow the expected pattern.
        """
        if self.name.startswith('npp'):
            tokens = splitext(self.name)[0].split('_')
            year = int(tokens[-1])
            month = int(tokens[-2])
        else:
            token = self.name.split('_')[2]
            year = int(token[3:7])
            dayOfYear = int(token[7:10])

            month = 12
            if 1 <= dayOfYear <= 366:
                day = (datetime.datetime(year, 1, 1)
                       + datetime.timedelta(days=dayOfYear - 1))
                # day 366 of a non-leap year spills into the next year
                if day.year == year:
                    month = day.month

        self.date = datetime.datetime(year, month, 1)



