"""
Module NSHMBGOneYearModel
"""

__version__ = "$Revision$"
__revision__ = "$Id$"

import os, datetime, ast
import numpy as np
import matplotlib.nxutils as nx

import CSEPFile, CSEPUtils, CSEPGeneric, CSEPLogging
from Forecast import Forecast
from OneYearForecast import OneYearForecast
from RELMCatalog import RELMCatalog
from OneDayModelDeclusInputPostProcess import OneDayModelDeclusInputPostProcess
from DataSourceFactory import DataSourceFactory
from GeoNetNZDataSource import GeoNetNZDataSource
from CSEPInputParams import CSEPInputParams


#-------------------------------------------------------------------------------
#
# NSHMBGOneYear forecast model.
#
# This class implements a one-year National Seismic Hazard Map Background
# forecast model.
# It loads the input catalog, computes the forecast on the forecast template,
# and places the forecast file under the user-specified directory.
#
class NSHMBGOneYearModel (OneYearForecast):
    """ One-year National Seismic Hazard Map Background (NSHMBG) forecast.

        The model smooths the input catalog with a non-adaptive Gaussian
        kernel onto the space cells of the one-year forecast template,
        distributes each cell's rate over the template magnitude bins with a
        Gutenberg-Richter law, and scales the total forecast rate to the
        number of observed target events scaled to the forecast period.
    """

    # Static data of the class

    # Keyword identifying type of the class
    Type = "NSHMBG" + OneYearForecast.Type


    # Logger object for the module
    Logger = CSEPLogging.CSEPLogging.getLogger(__name__)

    # Name of the optional input argument that provides the Gaussian kernel
    # smoothing distance (km)
    __smoothingDistanceOption = "smoothingDistance"

    # Name of the optional input argument that provides the polygon filename
    # if b-value should be looked up based on polygon; if it is not given,
    # a constant b-value of 1.0 is used
    __bValueLookupOption = "bValue"


    #---------------------------------------------------------------------------
    #
    # Return midpoint of [lower, upper]; works for scalars and numpy arrays.
    # Used to compute centers of forecast cells.
    #
    @staticmethod
    def _midpoint (lower, upper):
        """ Return the midpoint between lower and upper bounds."""

        return lower + np.abs((upper - lower)/2.0)


    #---------------------------------------------------------------------------
    #
    # Helper class to handle b-value lookup if polygon file provided to the 
    # model
    #
    class BValue (object):
        """ Look up b-value for a forecast cell: by polygon if a polygon file
            is provided, otherwise a constant value."""

        # Constant value to use if polygon file is not provided, or if the
        # cell is not covered by a polygon with a tabulated b-value
        __constValue = 1.0

        # b-values keyed by Matlab-style (1-based) polygon index, i.e. the
        # position of the polygon in the polygon file plus one
        __polygonBValues = {1: 1.11, 2: 1.11, 4: 1.11, 13: 1.11, 15: 1.11,
                            3: 1.07,
                            5: 1.08,
                            8: 1.01, 9: 1.01, 10: 1.01, 12: 1.01, 14: 1.01,
                            6: 0.96, 7: 0.96, 11: 0.96, 16: 0.96}

        def __init__ (self, lookup_option):
            """ Initialize lookup.

                lookup_option - Polygon filename (str), or None to use the
                                constant b-value. Each non-blank line of the
                                file is parsed with ast.literal_eval() into
                                one polygon's vertices.
            """

            self.allVerts = None
            self.__file = lookup_option

            if isinstance(self.__file, str):

                self.allVerts = []

                # Update list of polygons
                with CSEPFile.openFile(self.__file) as f:
                    for each_row in f:
                        # literal_eval() raises SyntaxError on blank lines
                        # (e.g. a trailing newline at the end of the file)
                        if not each_row.strip():
                            continue
                        polygon_verts = ast.literal_eval(each_row)
                        self.allVerts.append(np.array(polygon_verts))

                info_msg = "Read polygon file %s: %s" %(self.__file,
                                                        self.allVerts)
                NSHMBGOneYearModel.Logger.info(info_msg)


        #-----------------------------------------------------------------------
        #
        # Return sub-type keyword identifying the b-value lookup: based on the
        # polygon file used by the model
        #
        # Input: None.
        #
        # Output:
        #           String identifying the sub-type
        #
        def type (self):
            """ Returns keyword identifying the b-value lookup sub-type."""

            # Capture b-value lookup into forecast filename: the constant
            # value, or the extension of the polygon filename (path stripped)
            b_value = NSHMBGOneYearModel.BValue.__constValue
            if isinstance(self.__file, str):
                b_value = CSEPFile.Name.extension(os.path.basename(self.__file))

            return 'BValue%s' %(b_value)


        #-----------------------------------------------------------------------
        #
        # Look up b-value based on polygon information if provided
        #
        # Input:
        #        cell - One row of the forecast template (space cell).
        #
        # Output:
        #        b-value for the cell.
        #
        def lookup (self, cell):
            """ Return b-value for the cell."""

            const_value = NSHMBGOneYearModel.BValue.__constValue

            if self.allVerts is None:
                return const_value

            # Cell center coordinates
            center_lon = NSHMBGOneYearModel._midpoint(cell[Forecast.Format.MinLongitude],
                                                      cell[Forecast.Format.MaxLongitude])
            center_lat = NSHMBGOneYearModel._midpoint(cell[Forecast.Format.MinLatitude],
                                                      cell[Forecast.Format.MaxLatitude])

            # Pass coordinates of cell center to find the first polygon that
            # contains it.
            # NOTE(review): pnpoly() is given (lat, lon), so polygon vertices
            # are presumably stored as (lat, lon) pairs - confirm against the
            # polygon file.
            for each_index, each_polygon in enumerate(self.allVerts):
                if nx.pnpoly(center_lat, center_lon, each_polygon):
                    # Use matlab indices (python_index + 1) to look up correct
                    # b-value; untabulated polygons fall back to the constant
                    return self.__polygonBValues.get(each_index + 1,
                                                     const_value)

            # Generate a warning that default const b-value is used
            warning_msg = "Could not find polygon from %s file corresponding to cell [lat=%s,lon=%s], using default b-value=%s" \
                          %(self.__file,
                            center_lat,
                            center_lon,
                            const_value)
            NSHMBGOneYearModel.Logger.warning(warning_msg)

            return const_value


    # Smoothing kernel information
    class KernelParams (object):
        """ Gaussian smoothing kernel parameters."""

        # From the original Matlab code:
        #   kernel.corrDist = 50; % gaussian kernel distance in km.
        #   kernel.maxDist = 800; % max distance to be considered in km (gaussian...)
        def __init__ (self, 
                      smoothing_distance,
                      max_distance=800.0):

            # Gaussian kernel distance (km); can be provided through input
            # configuration file
            self.SmoothingDistance = float(smoothing_distance)

            # Events farther than this distance (km) from a cell center are
            # ignored
            self.MaxDistance = max_distance


    class ForecastInfo (object):
        """ Forecast template information used by the model."""

        def __init__ (self, forecast_file):
            """ Collect forecast template information used by the model"""

            self.np_obj = Forecast.load(forecast_file)

            # Get rid of extra columns in loaded template - automatically added
            # by Forecast.load()
            min_index = Forecast.Format.MinLongitude
            max_index = Forecast.Format.MaskBit + 1
            self.np_obj = self.np_obj[:, min_index:max_index]

            min_magnitudes = self.np_obj[:, Forecast.Format.MinMagnitude]

            # forecast.magMin = min(mForeTemp(:,7));
            self.magnitudeMin = min_magnitudes.min()

            # Narrowest magnitude bin width of the template
            # (Matlab: forecast.dM = min(mForeTemp(:,8)) - min(mForeTemp(:,7));)
            self.dM = (self.np_obj[:, Forecast.Format.MaxMagnitude] - 
                       min_magnitudes).min()

            # forecast.magMax = max(mForeTemp(:,7));
            self.magnitudeMax = min_magnitudes.max()

            # numMags = length(forecast.magMin:forecast.dM:forecast.magMax);
            # Matlab's colon range includes magMax, whereas np.arange()
            # excludes its stop value and its length depends on float
            # round-off in dM - count the bins explicitly instead.
            self.numMagnitudes = int(np.round((self.magnitudeMax - 
                                               self.magnitudeMin)/self.dM)) + 1
            self.magnitudeBins = self.magnitudeMin + \
                                 self.dM*np.arange(self.numMagnitudes)

            # Unique space cells for the template: the template lists
            # numMagnitudes consecutive rows per space cell
            self.spaceCells = self.np_obj[::self.numMagnitudes, :].copy()


    # This data is static for the class
    __defaultArgs = {__smoothingDistanceOption : "50.0",
                     __bValueLookupOption : None}


    #--------------------------------------------------------------------
    #
    # Initialization.
    #
    # Input: 
    #        dir_path - Directory to store forecast file to.
    #        args - Optional arguments for the model. Default is None.
    # 
    def __init__ (self, dir_path, args=None):
        """ Initialization for NSHMBGOneYearModel class"""

        # NOTE(review): third argument presumably selects the input catalog
        # post-processing for the model - confirm against OneYearForecast
        OneYearForecast.__init__(self, 
                                 dir_path,
                                 OneDayModelDeclusInputPostProcess.Type)

        # Input arguments for the model were provided:
        self.__args = CSEPInputParams.parse(NSHMBGOneYearModel.__defaultArgs, 
                                            args)

        # b-value polygons if defined for the model
        self.__bValue = NSHMBGOneYearModel.BValue(self.__args[NSHMBGOneYearModel.__bValueLookupOption])

        # Initialize kernel parameters
        self.__kernel = NSHMBGOneYearModel.KernelParams(self.__args[NSHMBGOneYearModel.__smoothingDistanceOption])

        # Collect information about forecast template for the model
        self.__forecast = NSHMBGOneYearModel.ForecastInfo(CSEPFile.Name.ascii(OneYearForecast.TemplateFile))


    #--------------------------------------------------------------------
    #
    # Return keyword identifying the model.
    #
    # Input: None.
    #
    # Output:
    #           String identifying the type
    #
    def type (self):
        """ Returns keyword identifying the forecast model type."""

        return self.Type


    #---------------------------------------------------------------------------
    #
    # Return sub-type keyword identifying the model: based on smoothing
    # distance and b-value lookup used by the model
    #
    # Input: None.
    #
    # Output:
    #           String identifying the sub-type
    #
    def subtype (self):
        """ Returns keyword identifying the forecast model sub-type."""

        # Capture kernel.smoothing_distance and b-value lookup into forecast
        # filename
        return 'SD%s%s' %(self.__kernel.SmoothingDistance,
                          self.__bValue.type())


    #---------------------------------------------------------------------------
    #
    # Write input parameter file for the model.
    #
    # Input: None.
    #        
    def writeParameterFile (self,
                            filename = None):
        """ Format input parameter file for the model."""

        # There is no input parameter file for the model
        pass


    #---------------------------------------------------------------------------
    #
    # Invoke the model.
    #
    # Input: None
    #        
    def run (self):
        """ Run NSHMBGOneYearModel forecast and save it to self.filename()."""

        # Load input catalog prepared for the model
        np_catalog = RELMCatalog.load(os.path.join(self.catalogDir,
                                                   self.inputCatalogFilename()))

        data_source = DataSourceFactory().object(GeoNetNZDataSource.Type,
                                                 isObjReference = True)

        # Ratio of forecast period length to learning period length; the
        # learning period runs from the data source start date to the
        # forecast start date
        start_year = CSEPUtils.decimalYear(self.start_date)
        forecast_years = CSEPUtils.decimalYear(self.end_date) - start_year
        learning_years = start_year - CSEPUtils.decimalYear(data_source.StartDate)
        self.timeScale = forecast_years / learning_years

        # Convert events to earth-centered, earth-fixed Cartesian coordinates
        (catalog_x,
         catalog_y,
         catalog_z) = self.__lla2ecef(np_catalog[:, CSEPGeneric.Catalog.ZMAPFormat.Latitude],
                                      np_catalog[:, CSEPGeneric.Catalog.ZMAPFormat.Longitude],
                                      np_catalog[:, CSEPGeneric.Catalog.ZMAPFormat.Depth])

        # Centers of spacial cells for forecast template
        cells = self.__forecast.spaceCells
        cell_center_lon = self._midpoint(cells[:, Forecast.Format.MinLongitude],
                                         cells[:, Forecast.Format.MaxLongitude])
        cell_center_lat = self._midpoint(cells[:, Forecast.Format.MinLatitude],
                                         cells[:, Forecast.Format.MaxLatitude])
        cell_center_depth = self._midpoint(cells[:, Forecast.Format.DepthTop],
                                           cells[:, Forecast.Format.DepthBottom])

        # Convert spacial cells to Cartesian coordinates
        (cell_x,
         cell_y,
         cell_z) = self.__lla2ecef(cell_center_lat,
                                   cell_center_lon,
                                   cell_center_depth)

        # Get the smoothed rate for the learning period for all spacial cells
        # (fills in the Rate column of the forecast template)
        self.__gaussNonAdaptiveSmooth(cell_x,
                                      cell_y,
                                      cell_z,
                                      catalog_x,
                                      catalog_y,
                                      catalog_z)

        # Get number of target events from input catalog
        # isForMag = mCat(:,6) >= forecast.magMin & mCat(:,6) <= forecast.magMax;
        # [numNondecEvents] = sum(isForMag);
        magnitudes = np_catalog[:, CSEPGeneric.Catalog.ZMAPFormat.Magnitude]
        numEvents = ((magnitudes >= self.__forecast.magnitudeMin) &
                     (magnitudes <= self.__forecast.magnitudeMax)).sum()

        # Normalise the total number so it is consistent with the learning period
        # obsScaledToForecastPeriod = numNondecEvents * ((forecast.end-forecast.start)/(learning.end-learning.start));
        obsScaledToForecastPeriod = numEvents * self.timeScale
        totForecastEvents = self.__forecast.np_obj[:, Forecast.Format.Rate].sum()

        if totForecastEvents > 0:
            self.__forecast.np_obj[:, Forecast.Format.Rate] *= obsScaledToForecastPeriod/totForecastEvents
        else:
            # Scaling zero rates would produce 0/0 = NaN in every cell
            NSHMBGOneYearModel.Logger.warning("Smoothed forecast rate is zero, skipping normalization to %s events"
                                              %obsScaledToForecastPeriod)

        NSHMBGOneYearModel.Logger.info("Total forecast rate=%s" 
                                       %self.__forecast.np_obj[:, Forecast.Format.Rate].sum())

        # Save forecast to file
        np.savetxt(self.filename(),
                   self.__forecast.np_obj)


    #---------------------------------------------------------------------------
    #
    # Get the smoothed rate for the learning period for each space cell and
    # store its magnitude distribution in the Rate column of the forecast
    # template.
    #
    # Input:
    #        cells_x, cells_y, cells_z - ECEF coordinates (m) of cell centers.
    #        catalog_x, catalog_y, catalog_z - ECEF coordinates (m) of events.
    #        
    def __gaussNonAdaptiveSmooth (self,
                                  cells_x,
                                  cells_y,
                                  cells_z,
                                  catalog_x,
                                  catalog_y,
                                  catalog_z):
        """ Get the smoothed rate for the learning period for each cell"""

        #function[totPeriodRate] = GaussNonAdaptiveSmooth(nodeLoc,cartLoc,kernel)

        gaussDenom = 1.0

        # Loop-invariant kernel and magnitude-distribution parameters
        two_sigma_sq = 2.0*self.__kernel.SmoothingDistance**2
        max_distance = self.__kernel.MaxDistance
        num_mags = self.__forecast.numMagnitudes

        data_source = DataSourceFactory().object(GeoNetNZDataSource.Type,
                                                 isObjReference = True)
        min_magnitude = float(data_source.MinMagnitude)

        space_cells = self.__forecast.spaceCells

        # For each space cell of the template forecast
        for cell_index in range(cells_x.size):

            # Distance (km) from this cell center to all events
            #eventDist = (((repmat(nodeLoc(1),numEvents,1) - cartLoc(:,1)).^2+(repmat(nodeLoc(2),numEvents,1) - cartLoc(:,2)).^2+(repmat(nodeLoc(3),numEvents,1) - cartLoc(:,3)).^2).^.5)/1000;
            dist_value = np.sqrt((cells_x[cell_index] - catalog_x)**2 + 
                                 (cells_y[cell_index] - catalog_y)**2 + 
                                 (cells_z[cell_index] - catalog_z)**2)/1000.0

            # Events within max distance to consider
            events_in_distance = dist_value[dist_value <= max_distance]

            # Sum of Gaussian kernel contributions of all events in reach
            #dcExp = exp((-(abs(eventsInDist(dLoop))).^2)/(2*kernel.corrDist.^2));
            cell_rate = (gaussDenom*np.exp(-(events_in_distance**2)/two_sigma_sq)).sum()

            # Scale the rates to the area of a grid node relative to the kernel
            cell = space_cells[cell_index]
            cell_area = CSEPGeneric.GeoUtils.areaOfRectangularRegion(cell[Forecast.Format.MinLatitude],
                                                                     cell[Forecast.Format.MinLongitude],
                                                                     cell[Forecast.Format.MaxLatitude],
                                                                     cell[Forecast.Format.MaxLongitude])
            cell_rate *= cell_area
            cell_rate *= self.timeScale

            b_value = self.__bValue.lookup(cell)

            magn_rates = self.__magnitudeDistribution(cell_rate,
                                                      b_value,
                                                      min_magnitude)

            # Put the mag bins into the correct spot in the forecast matrix
            # mForecast((bLoop-1)*numMags+1:(bLoop)*numMags,9) = DistEvents(1:numMags)';
            bin_start = cell_index * num_mags
            self.__forecast.np_obj[bin_start:bin_start + num_mags,
                                   Forecast.Format.Rate] = magn_rates

        return


    #---------------------------------------------------------------------------
    #
    # Convert latitude, longitude, and altitude to earth-centered, 
    # earth-fixed (ECEF) cartesian
    #
    # From original Matlab code
    #% USAGE:
    #% [x,y,z] = lla2ecef(lat,lon,alt)
    #% 
    #% x = ECEF X-coordinate (m)
    #% y = ECEF Y-coordinate (m)
    #% z = ECEF Z-coordinate (m)
    #% lat = geodetic latitude (radians)
    #% lon = longitude (radians)
    #% alt = height above WGS84 ellipsoid (m)
    #% 
    #% Notes: This function assumes the WGS84 model.
    #%        Latitude is customary geodetic (not geocentric).
    #% 
    #% Source: "Department of Defense World Geodetic System 1984"
    #%         Page 4-4
    #%         National Imagery and Mapping Agency
    #%         Last updated June, 2004
    #%         NIMA TR8350.2
    #% 
    #% Michael Kleder, July 2005
    #
    # Here lat_deg/lon_deg are in degrees (converted with 'scale'), and
    # alt_km is a depth that is negated into a height.
    #        
    def __lla2ecef (self,
                    lat_deg,
                    lon_deg,
                    alt_km,
                    scale = np.pi/180.0):
        """ Convert latitude, longitude, and altitude to earth-centered, 
            earth-fixed (ECEF) cartesian; returns (x, y, z)."""

        # WGS84 ellipsoid constants:
        a = 6378137.0
        e = 8.1819190842622e-2

        # Apply conversion scale
        lat = lat_deg*scale
        lon = lon_deg*scale

        # Depth -> height above ellipsoid.
        # NOTE(review): 'a' and N are in metres, but alt_km is only negated,
        # not multiplied by 1000, so a depth of D km enters as D metres.
        # Confirm against the original Matlab caller before changing it, since
        # it alters the smoothed rates.
        alt = alt_km*-1.0

        # intermediate calculation
        # (prime vertical radius of curvature)
        N = a/np.sqrt(1.0 - e**2 * np.sin(lat)**2)

        # results:
        x = (N+alt) * np.cos(lat) * np.cos(lon)
        y = (N+alt) * np.cos(lat) * np.sin(lon)
        z = ((1.0-e**2) * N + alt) * np.sin(lat)

        return (x, y, z)


    #---------------------------------------------------------------------------
    #
    # Get magnitude distribution within the cell
    #
    # Input:
    #        rate - Smoothed rate of the cell.
    #        b_value - Gutenberg-Richter b-value for the cell.
    #        min_magnitude - Minimum magnitude of the data source; looked up
    #                        from the GeoNet NZ data source if None.
    #
    # Output:
    #        Array of rates, one per template magnitude bin.
    #        
    def __magnitudeDistribution (self,
                                 rate,
                                 b_value,
                                 min_magnitude = None):
        """ Get magnitude distribution within the cell"""

        # function[RateEvents] = DistributeGRMagBinRates(forecast,learning,bvalue,NumEvents)
        mDist = self.__forecast.magnitudeBins

        # A zero rate (no events within kernel reach) yields zero in every
        # bin; return it directly instead of evaluating log10(0)
        if rate <= 0:
            return np.zeros(mDist.shape)

        if min_magnitude is None:
            data_source = DataSourceFactory().object(GeoNetNZDataSource.Type,
                                                     isObjReference = True)
            min_magnitude = float(data_source.MinMagnitude)

        forecastGRa = np.log10(rate) + (b_value*min_magnitude)

        # R1 = 10.^(forecastGRa-bvalue*mDist);
        # R2 = 10.^(forecastGRa-bvalue*(mDist+forecast.dM));
        R1 = 10.0**(forecastGRa - mDist*b_value)
        R2 = 10.0**(forecastGRa - b_value*(mDist+self.__forecast.dM))

        return (R1 - R2)
    
        