"""
Module WXStatisticalTest
"""

__version__ = "$Revision$"
__revision__ = "$Id$"

import numpy as np
import os, glob, re

from StatisticalTest import StatisticalTest
from WStatisticalTest import WStatisticalTest
from TXStatisticalTest import TXStatisticalTest
from EvaluationTest import EvaluationTest

import CSEPLogging, CSEP, CSEPFile


#-------------------------------------------------------------------------------
#
# W statistical test for evaluation of rate-based forecasts across multiple
# forecast classes.
#
# This class represents the W statistical evaluation test for forecast models.
#
class WXStatisticalTest (WStatisticalTest):
    """ W statistical test for rate-based forecasts across multiple forecast
    classes.

    The W-test depends on the T-test: it reuses the combined rates that the
    T-test stored in StatisticalTest.rates(self.forecasts), so it is expected
    to run after the TX-test for the same ForecastGroup object.
    """

    # Static data

    # Keyword identifying the class
    Type = "WX"

    # Shared class logger; created lazily by __getLogger()
    __logger = None


    #---------------------------------------------------------------------------
    #
    # Return the class logger, creating it on first use.
    #
    # plot() is a classmethod and may be called before any instance has been
    # constructed, so logging must not depend on __init__ having run.
    #
    # Input: None
    #
    # Output: Logger object returned by CSEPLogging.CSEPLogging.getLogger().
    #
    @staticmethod
    def __getLogger ():
        """ Return the shared class logger, creating it on first use."""

        if WXStatisticalTest.__logger is None:
            WXStatisticalTest.__logger = CSEPLogging.CSEPLogging.getLogger(WXStatisticalTest.__name__)

        return WXStatisticalTest.__logger


    #---------------------------------------------------------------------------
    #
    # Initialization.
    #
    # Input: 
    #        group - ForecastGroup object. This object identifies forecast
    #                models to be evaluated.
    #        args - Optional test arguments, passed through unchanged to
    #               WStatisticalTest.__init__. Default is None.
    # 
    def __init__ (self,
                  group,
                  args = None):
        """ Initialization for WXStatisticalTest class."""

        # Make sure the class logger exists
        WXStatisticalTest.__getLogger()

        WStatisticalTest.__init__(self,
                                  group,
                                  args)


    #---------------------------------------------------------------------------
    #
    # Returns keyword identifying the test. Implemented by derived classes.
    #
    # Input: None
    #
    # Output: String representation of the test type.
    #
    def type (self):
        """ Returns test type."""

        return WXStatisticalTest.Type


    #---------------------------------------------------------------------------
    #
    # Invoke evaluation test for the forecast
    #
    # Input: 
    #        forecast_name - Forecast model to test
    #
    # Output: None. Results are written through StatisticalTest.Result.writeXML
    #         into self.testDir.
    #
    def evaluate (self, 
                  forecast_name):
        """ Invoke evaluation test for the forecast.

        The W-test is cumulative over all models in the forecast group, so the
        work is done only when forecast_name is the first entry of
        StatisticalTest.rates(self.forecasts).all_models; calls for any other
        model return without doing anything.
        """

        # Evaluation test should not be invoked (observation catalog is invalid)
        if self.prepareCatalog() is False:
            return

        # Since W-test is dependent on T-test, it should always be invoked after
        # the T-test. Current forecast group will already have all other groups,
        # involved into the evaluation, accounted for in StatisticalTest.rates(self.forecasts)
        # since both T and W tests use the same ForecastGroup object.
        # Therefore there is no need to combine sum and observed event rates into
        # single matrix, it has been done by the T-test and stored within
        # StatisticalTest.rates(self.forecasts) static dictionary of the class. 
        # Should invoke test on very first forecast within the group as a result.
        all_files = StatisticalTest.rates(self.forecasts).all_models
        num_models = len(all_files)

        # list.index() raises ValueError if forecast_name is not in the group
        if all_files.index(forecast_name) != 0:
            return

        # Allocate result matrices to represent all participating forecasts.
        # These arrays are shared by each model evaluation since W-Test is
        # a cumulative test for all models included into forecast group.
        # Builtin float is used: the np.float alias was removed in NumPy 1.24.
        self.np_wmat = np.zeros((num_models, num_models),
                                dtype = float)
        self.np_wsigmat = np.zeros((num_models, num_models),
                                   dtype = float)

        # Invoke W-test for each model; self._invoke is defined elsewhere
        # (presumably WStatisticalTest) and is expected to fill the matrices
        # above - TODO confirm
        for each_forecast in all_files:
            self._invoke(each_forecast)

        # Write results to the file: only the significance matrix is stored
        test_results = {WStatisticalTest._WilcoxonSignificance: self.np_wsigmat}

        results = StatisticalTest.Result(test_results)
        results.writeXML('%s-%s' %(self.type(),
                                   EvaluationTest.FilePrefix),
                         tuple(StatisticalTest.rates(self.forecasts).all_models),
                         self.testDir,
                         self.filePrefix())


    #----------------------------------------------------------------------------
    #
    # Plot test results.
    #
    # Input: 
    #        result_file - Path to the result file in XML format
    #        output_dir - Directory to place plot file to. Default is None.
    #
    # Output: 
    #        List of plots filenames.
    #
    # Raises:
    #        RuntimeError if no TX-test result file is found next to
    #        result_file.
    #
    @classmethod
    def plot (cls, result_file, output_dir = None):
        """ Plot test results.

        W-test significance values are plotted on top of the TX-test results,
        so a TX-test result file must exist in the same directory as
        result_file.
        """

        # Directory that holds the W-test result file
        path_str = os.path.dirname(result_file)

        # Check for existence of TX-test result file
        file_pattern = "%s_%s-%s%s" %(cls.filePrefix(),
                                      TXStatisticalTest.Type,
                                      EvaluationTest.FilePrefix,
                                      CSEPFile.Extension.XML)

        t_test_result_file = glob.glob(os.path.join(path_str,
                                                    file_pattern))

        if not t_test_result_file:
            # Fall back to names that contain the pattern and end in a
            # non-zero digit (e.g. suffixed copies - TODO confirm intent)
            t_test_result_file = glob.glob(os.path.join(path_str,
                                                        "*%s*[1-9]" %file_pattern))

        if not t_test_result_file:
            error_msg = "%s-test result file is required to generate %s-test plot" %(TXStatisticalTest.Type,
                                                                                     WXStatisticalTest.Type) 
            # Logger is created on demand: plot() may run without an instance
            WXStatisticalTest.__getLogger().error(error_msg)
            raise RuntimeError(error_msg)

        # Get DOM object for the result
        doc = StatisticalTest.plot(result_file)

        # Get rid of internal CSEP 'fromXML' keyword from all model names
        models = [re.sub(CSEP.Forecast.FromXMLPostfix,
                         '',
                         each_name) for each_name in doc.elementValue(EvaluationTest.Result.Name).split()]

        num_models = len(models)

        # Re-store numpy array of W-test significance values as a square
        # (num_models x num_models) matrix
        w_sign_str = doc.elementValue(WStatisticalTest._WilcoxonSignificance)
        w_sign = [float(each_val) for each_val in w_sign_str.split()]
        np_w_sign = np.array(w_sign).reshape((num_models, num_models))

        # Information gain plot: TX-test plot with W-test significance overlaid
        plots = TXStatisticalTest.plot(t_test_result_file[0],
                                       output_dir,
                                       np_w_sign,
                                       result_file)

        return plots
        