"""
Module TXStatisticalTest
"""

__version__ = "$Revision$"
__revision__ = "$Id$"

import copy, datetime
import numpy as np

from StatisticalTest import StatisticalTest
from TStatisticalTest import TStatisticalTest
from EvaluationTest import EvaluationTest

import CSEPLogging, CSEP, CSEPFile, MatlabLogical

 


#-------------------------------------------------------------------------------
#
# T statistical test for evaluation of rate-based forecasts across multiple
# forecast classes. 
#
# This class represents the T statistical evaluation test for forecast models
# across multiple forecast classes.
#
class TXStatisticalTest (TStatisticalTest):
    """ T statistical test for rate-based forecasts across multiple forecast
        groups.

        Besides the forecast group the test is constructed with, extra
        groups are read from <group> elements of the configuration file
        (see parseConfigFile). Once the last forecast of the primary group
        has been processed, the data of all groups is combined into single
        matrices and the T-test is invoked for every participating model."""

    # Static data
    class XML (object):

        # Directory for extra forecast group to be part of the test
        DirAttribute = 'dir'

        # Flag to specify if compression is enabled for the group
        CompressionAttribute = 'enableForecastCompression'

        # Name of forecast group element for the test
        GroupElement = 'group'


    # Keyword identifying the class
    Type = "TX"

    # Logger shared by all instances; created by the first constructor call
    __logger = None

    #---------------------------------------------------------------------------
    #
    # Initialization.
    #
    # Input:
    #        group - ForecastGroup object. This object identifies forecast
    #                models to be evaluated.
    #        args - Optional input arguments for the test. Default is None.
    #
    def __init__ (self,
                  group,
                  args = None):
        """ Initialization for TXStatisticalTest class."""

        if TXStatisticalTest.__logger is None:
            TXStatisticalTest.__logger = CSEPLogging.CSEPLogging.getLogger(TXStatisticalTest.__name__)

        TStatisticalTest.__init__(self,
                                  group,
                                  args)

        # T-test objects for the extra groups (populated by parseConfigFile)
        self.__groups = []

        # Compression flag for each extra group, parallel to self.__groups
        self.__groupsCompression = []

        # Remember the global compression setting in effect at construction
        # time: it is temporarily overwritten while extra groups are processed
        # and has to be restored afterwards
        self.__compression = CSEP.Forecast.UseCompression


    #---------------------------------------------------------------------------
    #
    # Returns keyword identifying the test. Implemented by derived classes.
    #
    # Input: None
    #
    # Output: String representation of the test type.
    #
    def type (self):
        """ Returns test type."""

        return TXStatisticalTest.Type


    #---------------------------------------------------------------------------
    #
    # Returns forecasts of all groups participating in the test.
    #
    # Input: None
    #
    # Output: List of forecasts: models of the primary group followed by the
    #         forecasts files of each extra group (in configuration order).
    #
    def allFiles (self):
        """ Returns list of files for the test."""

        # Use deep copy, otherwise original list for the group gets extended
        forecasts = copy.deepcopy(StatisticalTest.rates(self.forecasts).all_models)

        for each_group in self.__groups:
            forecasts.extend(each_group.forecasts.files())

        # Return complete list of forecasts for all groups
        return forecasts


    #----------------------------------------------------------------------------
    #
    # Parse test related information from the input configuration file.
    #
    # Input:
    # test_elem - Object that represents test element within input configuration
    #             file
    # init_file - Input configuration file in XML format for forecast group.
    #
    # Output: None
    #
    def parseConfigFile (self,
                         test_elem,
                         init_file):
        """ Parse test related information from the input configuration file,
            and create evaluation test objects that corresponds to each group
            participating in the test.

            Raises RuntimeError if no <group> elements are provided for the
            test. Any exception raised while creating the groups is
            propagated to the caller after already created groups have been
            archived."""

        # Imported locally as in the original code (presumably to avoid a
        # circular import between the modules - TODO confirm)
        from ForecastGroup import ForecastGroup

        # Initialize T-test for other groups
        groups_elems = init_file.children(test_elem,
                                          TXStatisticalTest.XML.GroupElement)
        if len(groups_elems) == 0:
            error_msg = "Forecasts groups should be provided as <%s> elements within %s file for %s test" \
                        %(TXStatisticalTest.XML.GroupElement,
                          init_file.name,
                          self.Type)
            TXStatisticalTest.__logger.error(error_msg)
            raise RuntimeError(error_msg)

        try:
            for each_group in groups_elems:
                group_dir = each_group.attrib[TXStatisticalTest.XML.DirAttribute]

                # Compression is disabled for the group unless specified
                group_compression = False
                if TXStatisticalTest.XML.CompressionAttribute in each_group.attrib:
                    group_compression = each_group.attrib[TXStatisticalTest.XML.CompressionAttribute]
                    group_compression = bool(int(MatlabLogical.Boolean[group_compression]))

                group_inputs = each_group.text

                # Set compression for the group
                CSEP.Forecast.UseCompression = group_compression
                group_test = TStatisticalTest(ForecastGroup(group_dir),
                                              group_inputs)

                # Register test and its flag together to keep lists in sync
                self.__groups.append(group_test)
                self.__groupsCompression.append(group_compression)

        except BaseException:
            # In case of an exception, archive forecasts groups to release
            # their locks, then let the caller know about the failure
            for each_group in self.__groups:
                each_group.forecasts.archive(datetime.datetime.now().date())
            raise

        finally:
            # Reset compression flag as it was set for processing
            CSEP.Forecast.UseCompression = self.__compression


    #---------------------------------------------------------------------------
    #
    # Invoke evaluation test for the forecast
    #
    # Input:
    #        forecast_name - Forecast model to test
    #
    # Output: None
    #
    def evaluate (self,
                  forecast_name):
        """ Invoke evaluation test for the forecast.

            Extra groups are prepared on every call; the combined T-test is
            invoked only when 'forecast_name' is the last model of the primary
            group. Forecasts of the extra groups are archived in any case."""

        # On very first forecast, prepare the data
        test_name = StatisticalTest.evaluate(self, forecast_name)

        # Evaluation test should not be invoked (observation catalog is invalid)
        if test_name is None:
            return

        try:
            self.__prepareGroupTests()

            # This is last forecast within the current group, invoke evaluation test
            # for the extra groups and combine results into single matrix for
            # evaluation
            group_models = StatisticalTest.rates(self.forecasts).all_models
            num = len(group_models)
            if group_models.index(forecast_name) == (num - 1):
                self.__evaluateAllGroups(test_name, num)

        finally:

            #  "Archive" forecasts groups to unlock their 'forecasts' directories
            for each_group in self.__groups:
                each_group.forecasts.archive(self.testDate)


    #---------------------------------------------------------------------------
    #
    # Prepare forecasts data of the extra groups for evaluation.
    #
    # Input: None
    #
    # Output: None
    #
    def __prepareGroupTests (self):
        """ Re-initialize (if required) and prepare forecasts of each extra
            group using observed events of the primary group."""

        try:
            for group_i, each_group_test in enumerate(self.__groups):
                # If number of forecasts files have changed, re-initialize internal
                # data to correspond to current number of forecasts files for
                # evaluation

                # Set group compression for scanning for the files
                CSEP.Forecast.UseCompression = self.__groupsCompression[group_i]

                num_models = len(each_group_test.files())
                group_rates = StatisticalTest.rates(each_group_test.forecasts)

                if len(group_rates.all_models) != num_models or \
                   not each_group_test.isInitialized():

                    # Number of forecasts files has changed since object creation time
                    each_group_test.initializeData(each_group_test.files())

                # Create event information based on observed events for the group
                each_group_test.cumulativeCatalogFile.intermediateObj = self.eventsInfo()
                each_group_test.testDate = self.testDate
                each_group_test.prepareForecasts()

        finally:
            # Restore global compression flag even if a group fails to prepare
            CSEP.Forecast.UseCompression = self.__compression


    #---------------------------------------------------------------------------
    #
    # Combine data of all groups, invoke T-test for each model and store results.
    #
    # Input:
    #        test_name - Name of the test as returned by
    #                    StatisticalTest.evaluate()
    #        num - Number of models within primary forecast group
    #
    # Output: None
    #
    def __evaluateAllGroups (self,
                             test_name,
                             num):
        """ Combine rates of all groups into single matrices, invoke the T-test
            for each participating model and write results to the file."""

        # Allocate result matrices to represent all participating forecasts
        all_files = self.allFiles()
        num_models = len(all_files)
        matrix_shape = (num_models, num_models)

        # Builtin 'float'/'object' are used instead of 'np.float'/'np.object'
        # aliases, which don't exist in recent NumPy releases
        self.np_mnmat = np.zeros(matrix_shape, dtype = float)
        self.np_lower = np.zeros(matrix_shape, dtype = float)
        self.np_upper = np.zeros(matrix_shape, dtype = float)
        self.np_Nmat = np.zeros(matrix_shape, dtype = float)

        # Combine all forecasts information into single matrix for evaluation
        np_sum_rates = np.zeros(matrix_shape, dtype = float)

        # Copy group's values
        group_rates = StatisticalTest.rates(self.forecasts)
        mask_rows, = group_rates.np_masks[0].shape
        np_masks = np.zeros((num_models, mask_rows),
                            dtype = float)
        np_sum_rates_per_bin = np.zeros((num_models, mask_rows),
                                        dtype = float)

        np_sum_rates[0:num, 0:num] = group_rates.np_sum_rates
        np_masks[0:num, :] = group_rates.np_masks
        np_sum_rates_per_bin[0:num, :] = group_rates.np_sum_rates_per_bin

        # Combine all forecasts information for observed events into
        # single matrix for evaluation
        np_event_rates = np.zeros((num_models,),
                                  dtype = object)

        for i in range(0, num):
            np_event_rates[i] = group_rates.np_event_rates[i]

        # Copy the rest of the groups data:
        for each_group_test in self.__groups:
            rates = StatisticalTest.rates(each_group_test.forecasts)
            n = len(rates.all_models)

            for i in range(0, n):
                # Position of the model within combined list of forecasts.
                # NOTE(review): code below mixes 'i_index' and 'num + i'
                # offsets, which agree only if model names are unique across
                # groups and group's models are in the same order as its
                # files - confirm
                i_index = all_files.index(rates.all_models[i])
                np_masks[i_index] = rates.np_masks[i]
                np_sum_rates_per_bin[i_index] = rates.np_sum_rates_per_bin[i]

                # Append event rates to the group's ones
                np_event_rates[num + i] = rates.np_event_rates[i]

                # Iterate through all models that are prior to the current
                # model, and populate cumulative rates based on common
                # masking bit
                for j in range(0, i_index):
                    selection = (np_masks[i_index] * np_masks[j]) > 0
                    np_sum_rates[j, i_index] = np_sum_rates_per_bin[j][selection].sum()
                    np_sum_rates[i_index, j] = np_sum_rates_per_bin[i_index][selection].sum()

            # Copy next group rates: overwrites within-group entries that
            # were computed by the loop above
            num_end = num + n
            np_sum_rates[num:num_end, num:num_end] = rates.np_sum_rates
            num += n

        TXStatisticalTest.__logger.info("TX-Test: np_sum_rates: %s" %np_sum_rates)
        TXStatisticalTest.__logger.info("TX-Test: np_event_rates: %s" %np_event_rates)

        # Rates object of the primary group now represents all groups
        group_rates.all_models = all_files
        group_rates.np_sum_rates = np_sum_rates
        group_rates.np_event_rates = np_event_rates

        # Invoke T-test for each models
        for each_forecast in group_rates.all_models:
            self._invoke(each_forecast)

        # Write data to the file
        test_results = {TStatisticalTest._meanInformationGain: self.np_mnmat,
                        TStatisticalTest._lowerConfidenceLimits: self.np_lower,
                        TStatisticalTest._upperConfidenceLimits: self.np_upper,
                        TStatisticalTest._numberEvents: self.np_Nmat}

        results = StatisticalTest.Result(test_results)
        results.writeXML(test_name,
                         tuple(group_rates.all_models),
                         self.testDir,
                         self.filePrefix())
        
