"""
RegionInfo module

Provides the RegionInfo class, which stores settings of a CSEP testing region
(collection and testing areas, map grid steps, map border width) and renders
forecast maps of the region with Basemap and matplotlib.
"""

# Version-control keyword strings, presumably expanded by the VCS on
# checkout -- confirm against the repository keyword settings.
__version__ = "$Revision$"
__revision__ = "$Id$"

import os, re
import numpy as np
from mpl_toolkits.basemap import Basemap
from matplotlib import colors 
import matplotlib.mpl as mpl
#import matplotlib.mlab
import matplotlib.pyplot as plt

import CSEPLogging, CSEPFile, CSEP
from Environment import *
from RateForecastHandler import RateForecastHandler

# Identifiers of CSEP testing regions (each constant equals its own name)
California, SWPacific, NWPacific, Global = (
    'California', 'SWPacific', 'NWPacific', 'Global')

#--------------------------------------------------------------------------------
#
# Structure-like class to store variables specific to the CSEP testing region.
# 
class RegionInfo (object):
    """Settings of a CSEP testing region plus helpers to draw its maps.

       Stores the collection and testing areas of the region, the steps of
       the map grid, the width of the map border and whether longitudes are
       translated to the [0; 360] range for display. Computes area extents,
       area borders and grid lines, and renders forecast maps with Basemap."""

    # Column indices of an area file that lists rectangular areas as
    # min/max latitude and min/max longitude (like for WesternPacific)
    __areaMinLatIndex = 0
    __areaMaxLatIndex = 1
    __areaMinLongIndex = 2
    __areaMaxLongIndex = 3

    # Column indices of an area file that lists grid cell coordinates
    # (like for California)
    __lonIndex = 0
    __latIndex = 1

    def __init__(self,
                 collection_area=None,
                 testing_area=None,
                 lon_delta=5.0,
                 lat_delta=2.0,
                 map_border=2.0,
                 apply_longitude_conversion=True):
        """Initialize region settings.

           Inputs:
             collection_area - Collection area of the region.
             testing_area - Testing area file of the region (its ASCII
                            variant, CSEPFile.Name.ascii(), is read); None if
                            the region has no testing area.
             lon_delta - Longitude step of the map grid.
             lat_delta - Latitude step of the map grid.
             map_border - Degrees added around the area on the map.
             apply_longitude_conversion - Translate longitudes to [0; 360]
                            for display when both negative and positive
                            longitudes are present.

           Numeric settings given as strings are converted to float.
        """

        def as_number(value):
            # Settings may arrive as strings (presumably from configuration)
            return float(value) if isinstance(value, str) else value

        self.collectionArea = collection_area
        self.testArea = testing_area

        self.lonDt = as_number(lon_delta)
        self.latDt = as_number(lat_delta)
        self.mapBorderDt = as_number(map_border)

        # Flag if conversion to [0; 360] degrees should be applied to
        # longitude when positive and negative longitude values exist
        self.applyLongitudeConversion = apply_longitude_conversion

    #--------------------------------------------------------------------------
    @staticmethod
    def _toMapLongitudes(values):
        """Return longitudes prepared for map display.

           If `values` contains both negative and positive longitudes, the
           negative ones are shifted by +360 degrees so the area is shown
           continuously in the [0; 360] range (presumably for regions that
           cross the 180th meridian, such as the SW Pacific region).
           Otherwise the values are returned unchanged.

           Inputs:
             values - Sequence or array of longitudes.

           Output: Float numpy array (may be `values` itself when no shift is
                   needed); `values` is never modified in place.
        """
        lons = np.asarray(values, dtype=float)
        if (lons < 0.0).any() and (lons > 0.0).any():
            lons = np.where(lons < 0.0, lons + 360.0, lons)
        return lons

    #--------------------------------------------------------------------------
    def areaCoordinates(self,
                        area_file=None,
                        is_for_map=False):
        """Returns tuple of minimum latitude, maximum latitude, minimum
           longitude, and maximum longitude of an area of the region.

           Inputs:
             area_file - Area file, default is None meaning that the ASCII
                         variant of the test area of the region is used.
             is_for_map - Flag if area coordinates are requested for mapping
                          purposes. Default is False.

           Output:
             - ([], [], [], []) if no area file is available.
             - Scalars if the area file lists grid cell coordinates
               (2 columns).
             - Lists with one entry per rectangle if the area file lists
               min/max values, or scalars (overall extent) when is_for_map
               is True. For maps, longitudes may be shifted to [0; 360]
               (see applyLongitudeConversion).
        """
        min_lat, max_lat, min_lon, max_lon = [], [], [], []

        # Use test area of the region by default
        area = area_file
        if area is None and self.testArea is not None:
            area = CSEPFile.Name.ascii(self.testArea)

        # Some regions have testing area set to None
        if area is None:
            return (min_lat, max_lat, min_lon, max_lon)

        area_entries = CSEPFile.read(area)
        _num_rows, num_cols = area_entries.shape

        # Grid coordinates are provided (like for California)
        if num_cols == 2:
            lat_values = area_entries[:, RegionInfo.__latIndex]
            lon_values = area_entries[:, RegionInfo.__lonIndex]
            return (lat_values.min(), lat_values.max(),
                    lon_values.min(), lon_values.max())

        # min/max values for lon/lat are provided (like for WesternPacific)
        min_lat = area_entries[:, RegionInfo.__areaMinLatIndex].tolist()
        max_lat = area_entries[:, RegionInfo.__areaMaxLatIndex].tolist()
        min_lon = area_entries[:, RegionInfo.__areaMinLongIndex].tolist()
        max_lon = area_entries[:, RegionInfo.__areaMaxLongIndex].tolist()

        if not is_for_map:
            return (min_lat, max_lat, min_lon, max_lon)

        # Map generation needs a single extent for the whole area
        if self.applyLongitudeConversion:
            min_lon = self._toMapLongitudes(min_lon).tolist()
            max_lon = self._toMapLongitudes(max_lon).tolist()

        return (min(min_lat), max(max_lat), min(min_lon), max(max_lon))

    #--------------------------------------------------------------------------
    def areaBorders(self, area_file):
        """Returns tuple of x and y coordinates that represent longitude and
           latitude coordinates of the border of an area of the region.

           Inputs:
             area_file - Area file; its ASCII variant (CSEPFile.Name.ascii())
                         is read. None means no area: empty lists are
                         returned.
        """
        if area_file is None:
            return [], []

        # Load area coordinates for the testing region
        area = CSEPFile.read(CSEPFile.Name.ascii(area_file))
        _num_rows, num_cols = area.shape

        if num_cols == 2:
            # Grid coordinates are provided (like for California)
            return RegionInfo._gridAreaBorder(area)

        # min/max values for lon/lat are provided (like for WesternPacific)
        return RegionInfo._rectangleAreaBorders(area)

    @staticmethod
    def _gridAreaBorder(area):
        """Trace a closed border polygon of an area given as grid cell
           coordinates (columns: longitude, latitude).

           Each latitude row contributes a left point where the area starts
           and a right point where it ends (including when it ends at the
           last longitude column). The right side is then followed by the
           reversed left side, and the first point is repeated to close the
           curve. Points use the cell coordinates listed in the area file.

           Output: (border_x, border_y) lists; both empty for an empty area.
        """
        area_x = area[:, RegionInfo.__lonIndex].tolist()
        area_y = area[:, RegionInfo.__latIndex].tolist()

        # Sorted distinct coordinates and their positions in the mask
        unique_area_x = sorted(set(area_x))
        unique_area_y = sorted(set(area_y))
        x_position = {val: pos for pos, val in enumerate(unique_area_x)}
        y_position = {val: pos for pos, val in enumerate(unique_area_y)}
        xnum, ynum = len(unique_area_x), len(unique_area_y)

        # Mask of the cells that belong to the area
        test_area = np.zeros((ynum, xnum), dtype=bool)
        for each_lon, each_lat in zip(area_x, area_y):
            test_area[y_position[each_lat], x_position[each_lon]] = True

        left_border = []
        right_border = []

        for lat_pos in range(ynum):
            lat_val = unique_area_y[lat_pos]
            prev_inside = False

            for lon_pos in range(xnum):
                inside = test_area[lat_pos, lon_pos]

                if inside and not prev_inside:
                    # Entering the area: this cell starts a segment
                    left_border.append([unique_area_x[lon_pos], lat_val])
                elif prev_inside and not inside:
                    # Leaving the area: the previous cell ends the segment
                    right_border.append([unique_area_x[lon_pos - 1], lat_val])

                prev_inside = inside

            # A segment that reaches the last column ends there
            if prev_inside:
                right_border.append([unique_area_x[xnum - 1], lat_val])

        # To have continuous border, need to reverse left side of the border
        # before appending it to the right side
        left_border.reverse()
        right_border.extend(left_border)

        # Make it a closed curve
        if right_border:
            right_border.append(right_border[0])

        border_x = [point[0] for point in right_border]
        border_y = [point[1] for point in right_border]
        return border_x, border_y

    @staticmethod
    def _rectangleAreaBorders(area):
        """Concatenate closed rectangle outlines for an area given as rows of
           min latitude, max latitude, min longitude, max longitude.

           Each rectangle contributes five points: lower-left, upper-left,
           upper-right, lower-right and lower-left again (closing the curve).

           Output: (border_x, border_y) lists.
        """
        border_x, border_y = [], []

        for min_lat, max_lat, min_lon, max_lon in zip(
                area[:, RegionInfo.__areaMinLatIndex].tolist(),
                area[:, RegionInfo.__areaMaxLatIndex].tolist(),
                area[:, RegionInfo.__areaMinLongIndex].tolist(),
                area[:, RegionInfo.__areaMaxLongIndex].tolist()):

            border_x.extend([min_lon, min_lon, max_lon, max_lon, min_lon])
            border_y.extend([min_lat, max_lat, max_lat, min_lat, min_lat])

        return border_x, border_y

    #--------------------------------------------------------------------------
    def gridDelta(self):
        """Returns tuple of grid steps (longitude delta, latitude delta) for
           the map display."""
        return (self.lonDt, self.latDt)

    #--------------------------------------------------------------------------
    def grid(self,
             area_file=None,
             coords=None,
             lon_dt=5.0,
             lat_dt=2.0):
        """Get testing region grid for the map display.

           Inputs:
             area_file - Area file used to find the extent when coords is
                         None; default is the ASCII variant of the test area
                         of the region.
             coords - Optional (min_lat, max_lat, min_lon, max_lon) extent;
                      takes precedence over area_file.
             lon_dt - Step between meridians.
             lat_dt - Step between parallels.

           Output: Tuple of meridians and parallels of the grid. For a step
                   >= 1 these are lists of integers in [floor(min); ceil(max))
                   divisible by the step. For a step < 1 they are evenly
                   spaced numpy arrays from floor(min) to ceil(max)
                   inclusive.
        """
        if coords is not None:
            min_lat, max_lat, min_lon, max_lon = coords
        else:
            # Use test area of the region by default
            area = area_file
            if area is None and self.testArea is not None:
                area = CSEPFile.Name.ascii(self.testArea)

            # Extract region min/max longitude and latitude
            min_lat, max_lat, min_lon, max_lon = self.areaCoordinates(
                area, is_for_map=True)

        grid_meridians = RegionInfo._gridLines(min_lon, max_lon, lon_dt)
        grid_parallels = RegionInfo._gridLines(min_lat, max_lat, lat_dt)

        return (grid_meridians, grid_parallels)

    @staticmethod
    def _gridLines(min_val, max_val, step):
        """Return grid line positions covering [min_val; max_val] (see
           grid() for the exact rules)."""
        start = int(np.floor(min_val))
        stop = int(np.ceil(max_val))

        if step < 1.0:
            # np.linspace requires an integer number of samples; round to
            # absorb floating-point error in the step ratio
            num = int(round((stop - start) / float(step))) + 1
            return np.linspace(start, stop, num=num)

        return [val for val in range(start, stop) if val % step == 0]

    #--------------------------------------------------------------------------
    def populateGrid(self, lon, lat, rates):
        """Populate rectangular grid that bounds testing region with
           forecast's rates.

           Inputs (assumed to be 1-D numpy arrays of equal length, at least
           two entries):
             lon, lat - Cell coordinates of the forecast.
             rates - Forecast rates; zero rates are left out of the grid.

           The cell size is learned from the spacing of the first two
           forecast entries (the larger of the longitude and latitude
           spacing) and is used for both axes.

           Output: (xi, yi, grid) where xi/yi are np.meshgrid coordinate
                   arrays and grid holds the rate of each node, NaN where no
                   non-zero rate was given.
        """
        # Learn cell dimension from forecast data
        lon_dt = np.abs(lon[1] - lon[0])
        lat_dt = np.abs(lat[1] - lat[0])
        cell_delta = max(lon_dt, lat_dt)

        # Make coordinate arrays
        lon_grid = np.arange(lon.min(), lon.max() + cell_delta, cell_delta)
        lat_grid = np.arange(lat.min(), lat.max() + cell_delta, cell_delta)
        xi, yi = np.meshgrid(lon_grid, lat_grid)

        # Empty grid: NaN marks nodes without a rate
        grid = np.empty(xi.shape)
        grid.fill(np.nan)

        # Ignore zero rates
        keep = (rates != 0.0)
        rates, lon, lat = rates[keep], lon[keep], lat[keep]

        # Nearest grid node of every forecast cell (O(N) instead of scanning
        # the whole grid axis per cell); clip guards against float noise
        lon_index = np.rint((lon - lon_grid[0]) / cell_delta).astype(int)
        lat_index = np.rint((lat - lat_grid[0]) / cell_delta).astype(int)
        lon_index = np.clip(lon_index, 0, lon_grid.size - 1)
        lat_index = np.clip(lat_index, 0, lat_grid.size - 1)

        # Fill in the grid; sequential assignment keeps "last value wins"
        # when several forecast cells map to the same node
        for row, col, value in zip(lat_index, lon_index, rates):
            grid[row, col] = value

        return xi, yi, grid

    #--------------------------------------------------------------------------
    @staticmethod
    def _catalogMarkerSizes(magnitudes):
        """Return scatter marker sizes for event magnitudes.

           Magnitudes below 5 are scaled by 2, in [5; 6) by 4, and 6 or above
           by 6. Every class is selected from the original magnitudes, so an
           event is scaled exactly once.
        """
        magnitudes = np.asarray(magnitudes, dtype=float)
        scale = np.where(magnitudes < 5.0, 2.0,
                         np.where(magnitudes < 6.0, 4.0, 6.0))
        return magnitudes * scale

    #--------------------------------------------------------------------------
    def createMap(self,
                  image_file,
                  forecast_data_file,
                  catalog_data_file,
                  min_rate_limit=None,
                  max_rate_limit=None):
        """Create a map of the forecast model based on selected geographical
           region.

           Inputs:
             image_file - Path of the image file to save the map to.
             forecast_data_file - Forecast in "map-ready" text format; the
                            rate, longitude and latitude columns are selected
                            by RateForecastHandler.MapReadyFormat.
             catalog_data_file - Optional text file of observed events with
                            longitude, latitude and magnitude in the first
                            three columns; events are drawn over the forecast
                            when the file is not empty.
             min_rate_limit, max_rate_limit - Optional colorbar range bounds;
                            when not given (or zero), the minimum/maximum
                            non-zero forecast rate is used.

           Output: image_file
        """
        import copy  # only needed to copy the registered colormap

        min_lat, max_lat, min_lon, max_lon = self.areaCoordinates(is_for_map=True)

        lon_delta, lat_delta = self.gridDelta()
        grid_meridians, grid_parallels = self.grid(
            coords=(min_lat, max_lat, min_lon, max_lon),
            lon_dt=lon_delta,
            lat_dt=lat_delta)

        fig = plt.figure(figsize=(10, 9))

        basemap = Basemap(llcrnrlon=min_lon - self.mapBorderDt,
                          urcrnrlon=max_lon + self.mapBorderDt,
                          llcrnrlat=min_lat - self.mapBorderDt,
                          urcrnrlat=max_lat + self.mapBorderDt,
                          resolution='f')

        basemap.fillcontinents(color='gray',
                               zorder=10)
        basemap.drawrivers(zorder=20,
                           color='b',
                           linewidth=0.5)
        basemap.drawstates(zorder=30,
                           linewidth=1.0)
        basemap.drawcoastlines(zorder=40,
                               linewidth=2.0)
        basemap.drawmapboundary(linewidth=2.0,
                                zorder=45)

        # linewidth=0.0 hides meridian/parallel lines but keeps their labels
        basemap.drawmeridians(grid_meridians,
                              labelstyle='+/-',
                              labels=[1, 1, 0, 1],
                              linewidth=0.0,
                              zorder=50)
        basemap.drawparallels(grid_parallels,
                              labelstyle='+/-',
                              labels=[1, 1, 1, 1],
                              linewidth=0.0,
                              zorder=60)

        np_forecast = np.loadtxt(forecast_data_file, dtype=float)
        rates = np_forecast[:, RateForecastHandler.MapReadyFormat.RATE]
        lon = np_forecast[:, RateForecastHandler.MapReadyFormat.LONGITUTE]
        lat = np_forecast[:, RateForecastHandler.MapReadyFormat.LATITUDE]

        if self.applyLongitudeConversion:
            lon = self._toMapLongitudes(lon)

        # Forecasts should not provide zero's for rates, which raise an
        # exception with LogNorm normalization for colorbar
        rates[rates == 0.0] = np.nan

        # Colorbar range from the valid (non-NaN) rates unless overridden
        valid_rates = rates[~np.isnan(rates)]
        rates_min = valid_rates.min()
        rates_max = valid_rates.max()

        if min_rate_limit:
            rates_min = min_rate_limit

        if max_rate_limit:
            rates_max = max_rate_limit

        grid_lon, grid_lat, grid_rates = self.populateGrid(lon, lat, rates)

        # Project cell coordinates onto the map
        map_x, map_y = basemap(grid_lon, grid_lat)

        # Copy the registered colormap: set_under()/set_bad() would otherwise
        # modify the shared 'jet' colormap used by every other plot
        color_map = copy.copy(mpl.cm.get_cmap('jet'))

        cm = mpl.cm.ScalarMappable(norm=colors.LogNorm(vmin=rates_min,
                                                       vmax=rates_max),
                                   cmap=color_map)

        # Make "bad" values transparent
        cm.cmap.set_under('w', alpha=0.0)
        cm.cmap.set_bad('w', alpha=0.0)

        cmesh = basemap.pcolormesh(map_x, map_y,
                                   grid_rates,
                                   cmap=cm.cmap,
                                   norm=cm.norm,
                                   alpha=0.6,
                                   zorder=75)

        colorbar = plt.colorbar(cmesh,
                                orientation='horizontal',
                                drawedges=False,
                                shrink=0.75,
                                pad=0.05)
        colorbar.set_label(r"Rate ($\lambda$)")

        # If rates range is within same log scale, introduce custom ticks
        # to the colorbar
        if np.abs(np.log10(rates_max) - np.log10(rates_min)) < 1.0:
            cb_labels = np.linspace(rates_min, rates_max, 5, endpoint=True)
            colorbar.set_ticks(cb_labels)
            colorbar.set_ticklabels(cb_labels)

        # Add observed events if any
        if catalog_data_file:
            # np.fromfile() reads raw bytes as float64 and serves only as an
            # emptiness check, since np.loadtxt() may fail on an empty file
            if np.fromfile(catalog_data_file).size != 0:
                # atleast_2d: a single event is loaded as a 1-D row
                np_catalog = np.atleast_2d(np.loadtxt(catalog_data_file,
                                                      dtype=float))

                event_lon = np_catalog[:, 0]
                if self.applyLongitudeConversion:
                    event_lon = self._toMapLongitudes(event_lon)
                event_lat = np_catalog[:, 1]
                marker_sizes = self._catalogMarkerSizes(np_catalog[:, 2])

                x, y = basemap(event_lon, event_lat)
                basemap.scatter(x, y, c='r',
                                marker='o',
                                s=marker_sizes,
                                zorder=80)

        # Get rid of postfix and prefix in forecast name
        title_name = re.sub(CSEP.Forecast.FromXMLPostfix,
                            '',
                            os.path.basename(forecast_data_file))
        title_name = re.sub(CSEP.Forecast.MapReadyPrefix,
                            '',
                            title_name)

        plt.title(CSEPFile.Name.extension(title_name))
        plt.savefig(image_file)

        # Release this figure; calling plt.clf() after closing would create
        # a new empty figure that is never closed
        plt.close(fig)

        return image_file

