## This file is part of MLPY.
## Resampling Methods.

## Resampling methods return lists of training/test indexes.
    
## This code is written by Davide Albanese, <albanese@fbk.eu>.
## (C) 2007 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.

## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

## Public API: the names exported by 'from <module> import *'.
## The helper functions defined below (flattenlist, splitlist, pncl,
## allcomb) are deliberately not listed here.
__all__ = ['kfold', 'kfoldS', 'leaveoneout', 'montecarlo', 'montecarloS',
           'allcombinations', 'manresampling', 'resamplingfile']

import random
import csv
import numpy


def flattenlist(l):
    """
    Flatten one level of nesting in a list.

    l - list whose elements are either lists or plain elements

    Elements that are lists are spliced into the result; every other
    element is copied as-is. Lists nested deeper than one level are
    left untouched. A new list is always returned.
    """

    flat = []
    for item in l:
        if not isinstance(item, list):
            flat.append(item)
            continue
        flat.extend(item)
    return flat


def splitlist(l, n):
    """
    Cut a list into n consecutive slices of roughly equal size.

    l - list
    n - number of pieces

    Returns a list of n slices of l, in order. Raises ValueError when
    n is greater than the length of l.
    """

    total = len(l)
    if n > total:
        raise ValueError("'n' must be smaller than 'l' length")

    # (possibly fractional) ideal length of one piece
    step = 1.0 / n * total

    pieces = []
    for k in range(n):
        # piece boundaries are the rounded multiples of the ideal length
        start = int(round(k * step))
        stop = int(round((k + 1) * step))
        pieces.append(l[start:stop])
    return pieces


def pncl(cl):
    """
    Locate the samples of each class in a two-class label vector.

    cl - class labels (numpy array 1D integer)

    Returns (positive indexes, negative indexes) as two lists, where
    the "positive" class is the larger of the two distinct labels and
    the "negative" class the smaller one. Raises ValueError unless cl
    holds exactly two distinct labels.
    """

    labels = numpy.unique(cl)
    if len(labels) != 2:
        raise ValueError("pncl() works only for two-classes")

    # numpy.unique returns the labels sorted in increasing order
    negative, positive = labels[0], labels[1]

    values = numpy.array(cl)
    pos_idx = numpy.nonzero(values == positive)[0]
    neg_idx = numpy.nonzero(values == negative)[0]

    return pos_idx.tolist(), neg_idx.tolist()

   
def allcomb(items, k):
    """
    Generator, returns all combinations of items in lists of length k.

    items - sequence supporting len(), indexing and slicing
    k     - length of each combination (non-negative integer)

    Combinations keep the relative order of the elements in items and
    are produced in the order of the position of their first element.
    k == 0 yields a single empty list; k > len(items) yields nothing.
    Raises ValueError (when the generator is first advanced) if k is
    negative.
    """

    if k < 0:
        # the recursion below only terminates when it reaches k == 0
        raise ValueError("'k' must be non-negative")

    if k == 0:
        yield []
    else:
        # 'range' instead of 'xrange': the latter exists only on Python 2
        for i in range(len(items) - k + 1):
            # fix items[i] as head, combine it with every (k-1)-subset
            # of the elements that follow it
            for c in allcomb(items[i+1:], k-1):
                yield [items[i]] + c
  

def kfold(nsamples, sets, rseed = 0, indexes = None):
    """K-fold Resampling Method.

    The (shuffled) indexes are cut into *sets* consecutive subsets; each
    subset is used once as test set while the remaining ones, joined,
    form the corresponding training set.

    Input
    
      * *nsamples* - [integer] number of samples (used only when
        *indexes* is None)
      * *sets*     - [integer] number of subsets (= number of tr/ts pairs)
      * *rseed*    - [integer] random seed; the module-level ``random``
        generator is re-seeded with it
      * *indexes*  - [list integer] source indexes (None for [0, nsamples-1]).
        The sequence passed in is not modified.

    Output
    
      * *idx* - list of *sets* tuples: ([training indexes], [test indexes])

    Raises ValueError if *sets* is greater than the number of indexes.
    """

    random.seed(rseed)

    # Always shuffle a private list: this keeps the caller's 'indexes'
    # untouched and also works where range() is not a list (Python 3)
    # or where 'indexes' is an array or tuple.
    if indexes is None:
        pool = list(range(nsamples))
    else:
        pool = list(indexes)

    random.shuffle(pool)

    try:
        subs = splitlist(pool, sets)
    except ValueError:
        raise ValueError("'sets' must be smaller than 'nsamples'")

    res = []
    for i in range(sets):
        # every subset but the i-th goes into training
        tr = flattenlist(subs[:i] + subs[i+1:])
        ts = subs[i]
        res.append((tr, ts))

    return res
        

def kfoldS(cl, sets, rseed = 0, indexes = None):
    """Stratified K-fold Resampling Method.

    Positive and negative samples are shuffled and cut into *sets*
    subsets separately, so that the i-th test set is made of the i-th
    positive subset followed by the i-th negative subset.
    
    Input
    
      * *cl*      - [list (1 or -1)] class label
      * *sets*    - [integer] number of subsets (= number of tr/ts pairs)
      * *rseed*   - [integer] random seed; the module-level ``random``
        generator is re-seeded with it
      * *indexes* - [list integer] source indexes (None for [0, nsamples-1]);
        when given, position i of *cl* is reported as indexes[i]

    Output
    
      * *idx* - list of *sets* tuples: ([training indexes], [test indexes])

    Raises ValueError if *cl* does not hold exactly two classes or if
    *sets* is greater than the number of samples of either class.
    """

    random.seed(rseed)
    
    pindexes, nindexes = pncl(cl)

    # identity test: '!= None' is evaluated element-wise (and is then
    # ambiguous as a condition) when 'indexes' is a numpy array
    if indexes is not None:
        pindexes = [indexes[i] for i in pindexes]
        nindexes = [indexes[i] for i in nindexes]

    random.shuffle(pindexes)
    random.shuffle(nindexes)

    try:
        psubs = splitlist(pindexes, sets)
        nsubs = splitlist(nindexes, sets)
    except ValueError:
        raise ValueError("'sets' must be smaller than number of positive samples (%s) and "
                         "than number of negative samples (%s)" % (len(pindexes), len(nindexes)))

    res = []
    for i in range(sets):
        # training: all positive and negative subsets except the i-th ones
        tr = flattenlist(psubs[:i] + psubs[i+1:] + nsubs[:i] + nsubs[i+1:])
        ts = flattenlist(psubs[i] + nsubs[i])
        res.append((tr, ts))
        
    return res
        

def leaveoneout(nsamples, indexes = None):
    """Leave-one-out Resampling Method.

    Each index is used once as a single-element test set, all the
    other indexes (in their original order) being the training set.
    
    Input
    
      * *nsamples* - [integer] number of samples (used only when
        *indexes* is None)
      * *indexes*  - [list integer] source indexes (None for [0, nsamples-1])

    Output
    
      * *idx* - list of tuples, one per index:
        ([training indexes], [test indexes])
    """

    # Work on a real list: slices of a Python 3 range cannot be
    # concatenated, and '+' on numpy arrays would add element-wise.
    if indexes is None:
        pool = list(range(nsamples))
    else:
        pool = list(indexes)

    return [(pool[:i] + pool[i+1:], [held_out])
            for i, held_out in enumerate(pool)]


def montecarlo(nsamples, pairs, sets, rseed = 0, indexes = None):
    """Monte Carlo Resampling Method.

    For each pair the indexes are shuffled again and cut into *sets*
    consecutive subsets: the last subset is the test set, the others,
    joined, are the training set.
    
    Input
    
      * *nsamples* - [integer] number of samples (used only when
        *indexes* is None)
      * *pairs*    - [integer] number of tr/ts pairs
      * *sets*     - [integer] 1/(fraction of data in test sets)
      * *rseed*    - [integer] random seed; the module-level ``random``
        generator is re-seeded with it
      * *indexes*  - [list integer] source indexes (None for [0, nsamples-1]).
        The sequence passed in is not modified.

    Output
    
      * *idx* - list of *pairs* tuples: ([training indexes], [test indexes])

    Raises ValueError if *sets* is greater than the number of indexes.
    """

    random.seed(rseed)

    # Always shuffle a private list: this keeps the caller's 'indexes'
    # untouched and also works where range() is not a list (Python 3)
    # or where 'indexes' is an array or tuple.
    if indexes is None:
        pool = list(range(nsamples))
    else:
        pool = list(indexes)

    res = []
    for _ in range(pairs):
        # the pool is not reset: each shuffle starts from the previous order
        random.shuffle(pool)

        try:
            subs = splitlist(pool, sets)
        except ValueError:
            raise ValueError("'sets' must be smaller than number of 'nsamples'")

        tr = flattenlist(subs[:-1])
        ts = subs[-1]
        res.append((tr, ts))

    return res


def montecarloS(cl, pairs, sets, rseed = 0, indexes = None):
    """Stratified Monte Carlo Resampling Method.

    For each pair, positive and negative samples are shuffled again and
    cut into *sets* subsets separately; the test set is the last
    positive subset followed by the last negative subset, the training
    set is made of all the remaining subsets.

    Input
    
      * *cl*      - [list (1 or -1)] class label
      * *pairs*   - [integer] number of tr/ts pairs
      * *sets*    - [integer] 1/(fraction of data in test sets)
      * *rseed*   - [integer] random seed; the module-level ``random``
        generator is re-seeded with it
      * *indexes* - [list integer] source indexes  (None for [0, nsamples-1]);
        when given, position i of *cl* is reported as indexes[i]

    Output
    
      * *idx* - list of *pairs* tuples: ([training indexes], [test indexes])

    Raises ValueError if *cl* does not hold exactly two classes or if
    *sets* is greater than the number of samples of either class.
    """

    random.seed(rseed)

    pindexes, nindexes = pncl(cl)

    # identity test: '!= None' is evaluated element-wise (and is then
    # ambiguous as a condition) when 'indexes' is a numpy array
    if indexes is not None:
        pindexes = [indexes[i] for i in pindexes]
        nindexes = [indexes[i] for i in nindexes]

    res = []
    for _ in range(pairs):
        # the lists are not reset: each shuffle starts from the previous order
        random.shuffle(pindexes)
        random.shuffle(nindexes)

        try:
            psubs = splitlist(pindexes, sets)
            nsubs = splitlist(nindexes, sets)
        except ValueError:
            raise ValueError("'sets' must be smaller than number of positive samples (%s) and "
                             "than number of negative samples (%s)" % (len(pindexes), len(nindexes)))

        tr = flattenlist(psubs[:-1] + nsubs[:-1])
        ts = flattenlist(psubs[-1] + nsubs[-1])
        res.append((tr, ts))

    return res


def allcombinations(cl, sets, indexes = None):
    """All Combinations Resampling Method.

    Every test set is made of len(positives) // sets positive samples
    followed by len(negatives) // sets negative samples; all the
    possible choices of such samples are enumerated. The training set
    of a pair holds the source indexes that are not in its test set.

    Input
    
      * *cl*      - [list (1 or -1)] class label
      * *sets*    - [integer] number of subset
      * *indexes* - [list integer] source indexes  (None for [0, nsamples-1]);
        when given, position i of *cl* is reported as indexes[i]

    Output
    
      * *idx* - list of tuples: ([training indexes], [test indexes])

    Raises ValueError if *cl* does not hold exactly two classes or if
    *sets* is greater than the number of samples of either class.
    """

    pindexes, nindexes = pncl(cl)

    if indexes is not None:
        pindexes = [indexes[i] for i in pindexes]
        nindexes = [indexes[i] for i in nindexes]
        # private list copy: supports .remove() below for any sequence type
        indexes = list(indexes)
    else:
        # explicit list: a Python 3 range has no .remove()
        indexes = list(range(len(cl)))

    # floor division, so the sizes stay integers on Python 3 as well
    pn, nn = len(pindexes) // sets, len(nindexes) // sets
    
    if pn < 1 or nn < 1:
        raise ValueError("'sets' must be smaller than number of positive samples (%s) and "
                         "than number of negative samples (%s)" % (len(pindexes), len(nindexes)))

    res = []
    for pts in allcomb(pindexes, pn):
        for nts in allcomb(nindexes, nn):
            ts = pts + nts
            # training = a copy of all the source indexes minus the test ones
            tr = indexes[:]
            for x in ts:
                tr.remove(x)
            res.append((tr, ts))

    return res


def manresampling(cl, pairs, trp, trn, tsp, tsn, rseed = 0):
    """Manual Resampling.

    For each pair, positive and negative samples are shuffled again;
    the first *trp* positives and *trn* negatives form the training
    set, the next *tsp* positives and *tsn* negatives the test set.
    
    Input
    
      * *cl*    - [list (1 or -1)] class label
      * *pairs* - [integer] number of tr/ts pairs
      * *trp*   - [integer] number of positive samples in training
      * *trn*   - [integer] number of negative samples in training
      * *tsp*   - [integer] number of positive samples in test
      * *tsn*   - [integer] number of negative samples in test
      * *rseed* - [integer] random seed; the module-level ``random``
        generator is re-seeded with it

    Output
    
      * *idx* - list of *pairs* tuples: ([training indexes], [test indexes])

    Raises ValueError if *trp* + *tsp* (*trn* + *tsn*) is greater than
    the number of positive (negative) samples.
    """

    random.seed(rseed)
    positives, negatives = pncl(cl)

    n_pos = len(positives)
    if n_pos < (trp + tsp):
        raise ValueError("'trp' + 'tsp' must be smaller than number of positive samples (%s)" % n_pos)

    n_neg = len(negatives)
    if n_neg < (trn + tsn):
        raise ValueError("'trn' + 'tsn' must be smaller than number of negative samples (%s)" % n_neg)

    res = []
    for _ in range(pairs):
        random.shuffle(positives)
        random.shuffle(negatives)

        # head of each shuffled class -> training, following chunk -> test
        training = positives[:trp] + negatives[:trn]
        test = positives[trp:trp + tsp] + negatives[trn:trn + tsn]
        res.append((training, test))

    return res


def resamplingfile(nsamples, file, sep = '\t'):
    """Resampling file from file.

    Returns a list of tuples:
    ([training indexes],[test indexes])
    
    Read a file in the form::
    
      [test indexes 'sep'-separated for the first  replicate]
      [test indexes 'sep'-separated for the second replicate]
                              .
                              .
                              .
      [test indexes 'sep'-separated for the last   replicate]

    where indexes must be integers in [0, nsamples-1].

    For each replicate the test indexes are sorted and repeated values
    are dropped (a warning is printed); the training indexes are all
    the values in [0, nsamples-1] which are not test indexes.

    Input
    
      * *nsamples* - [integer] number of samples
      * *file*     - [string] test indexes file
      * *sep*      - [string] field separator

    Output
    
      * *idx* - list of tuples: ([training indexes],[test indexes])

    Raises ValueError if a field is not an integer, if an index is
    outside [0, nsamples-1], or if a replicate has an empty test set
    or an empty training set.
    """

    res = []

    # try/finally guarantees the file is closed, also when a replicate
    # is rejected with an exception
    handle = open(file, "r")
    try:
        reader = csv.reader(handle, delimiter=sep, lineterminator='\n')

        for row in reader:

            # read test indexes (sorted, without repetitions)
            ts_tmp = [int(s) for s in row]
            ts = sorted(set(ts_tmp))
            if len(ts_tmp) != len(ts):
                # single parenthesised argument: same output on Python 2 and 3
                print("Warning: replicate %s: double values. Fixed" % len(res))
            if len(ts) == 0:
                raise ValueError("Replicate %s: no samples in test" % len(res))

            # every test index must be an existing sample
            for idx in ts:
                if not 0 <= idx < nsamples:
                    raise ValueError("Replicate %s: sample %s does not exist" % (len(res), idx))

            # build training: one pass over the samples with a set lookup
            held_out = set(ts)
            tr = [i for i in range(nsamples) if i not in held_out]
            if len(tr) == 0:
                raise ValueError("Replicate %s: no samples in training" % len(res))

            res.append((tr, ts))
    finally:
        handle.close()

    return res


    
    
    
    
