"""
Classes for reading (TODO: and writing) STARDB files.
"""

from __future__ import division
from __future__ import print_function
from __future__ import with_statement

import numpy as np
import os, sys, gzip, bz2
import utils 
from byte2human import byte2human
from utils import prod, xz_file_size
import isotope
import re
from logged import Logged

from lzma import LZMAFile


class StarDB(Logged):
    """
    Class for reading STARDB binary files.

    Compressed files with .gz will be automatically uncompressed.
    This may fail, however, for files larger than 4 GB because gzip
    stores the uncompressed size only modulo 2**32.

    Compressed files with .bz2 will be automatically uncompressed.
    This is currently rather inefficient because the entire stream
    needs to be read first just to determine the file size.

    Compressed files with .xz are supported as well.

    FIELD_TYPES:
      0 UNDEFINED Undefined
      1 BYTE Byte
      2 INT Integer
      3 LONG Longword integer
      4 FLOAT Floating point
      5 DOUBLE Double-precision floating
      6 COMPLEX Complex floating
      7 STRING String
      8 STRUCT Structure [NOT ALLOWED]
      9 DCOMPLEX Double-precision complex
     10 POINTER Pointer
     11 OBJREF Object reference
     12 UINT Unsigned Integer
     13 ULONG Unsigned Longword Integer
     14 LONG64 64-bit Integer
     15 ULONG64 Unsigned 64-bit Integer 
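
    Example (a minimal usage sketch; 'mydata.stardb' is a hypothetical
    file name; compressed '.gz', '.bz2', or '.xz' files can be passed
    the same way):

        db = StarDB('mydata.stardb')
        db.nstar, db.nfield, db.nabu   # numbers of stars, fields, abundances
        db.fieldnames                  # per-star field names
        db.abu_data                    # (nabu, nstar) abundance array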
    """
    
    sys_is_le = sys.byteorder == "little"
    native_byteorder = "<" if sys_is_le else ">"

    abundance_type_names = ('isomer','isotope','element','mass number (isobar)','neutron number (isotone)')
    abundance_class_names = ('all (raw)','rad (raw + decays)','dec (stable subset of radiso)')
    abundance_unit_names = ('mass (g)','mass (solar)','mol','mass fraction (X)','mol fraction (YPS)','log epsilon','[ ]','production factor (10^[])','log [Y/Si] + 6')
    abundance_total_names = ('initial (total, not normalized, fallback is missing mass)','ejecta')
    abundance_data_names = ('all ejecta (SN ejecta + wind)','SN ejecta, no wind','wind only','ejecta including fallback and wind')
    abundance_sum_names = ('mass fraction','number fraction')

    flagnames = ('parameter','property')

    typenames = (
        'UNDEFINED', 
        'BYTE', 
        'INT',
        'LONG',
        'FLOAT',
        'DOUBLE',
        'COMPLEX',
        'STRING',
        'STRUCT', 
        'DCOMPLEX',
        'POINTER',
        'OBJREF',
        'UINT',
        'ULONG',
        'LONG64',
        'ULONG64')


    class SignatureError(Exception):
        """
        Exception raised when signature could not be read.
        """
   
        def __init__(self, filename):
            """
            Store file name that causes error.
            """
            self.filename = filename 
   
        def __str__(self):
            """
            Return error message.
            """
            return "Error reading signature from file {:s}."\
                   .format(self.filename)

    class VersionError(Exception):
        """
        Exception raised on a version mismatch.
        """
   
        def __init__(self):
            """
            Just set up.
            """
   
        def __str__(self):
            """
            Return error message.
            """
            return "Version Error.".format()

    class IntegrityError(Exception):
        """
        Exception raised when file integrity seems broken.
        """
   
        def __init__(self):
            """
            Just set up.
            """
   
        def __str__(self):
            """
            Return error message.
            """
            return "File Integrity Error.".format()

    class DataError(Exception):
        """
        Exception raised when data seems faulty.
        """
   
        def __init__(self):
            """
            Just set up.
            """
   
        def __str__(self):
            """
            Return error message.
            """
            return "Data seems faulty.".format()

    def __init__(self,
                 filename,
                 extension = None):
        """
        Initialize data fields and open file.

        The byte order is determined automatically from the file
        signature; both little- and big-endian files are supported.
        """
        self.setup_logger(silent = False)
        self.extension = extension
        self.open(filename)

        self.logger.info('Loading {:s}'.format(
            self.filename))
        s = 'File size: {}'.format(byte2human(self.filesize))
        if self.compressed:
            s += ' (compressed)'
        self.logger.info(s + '.')

        self.byteorder = self._check_signature()
        self.swapbyteorder = self.byteorder != self.native_byteorder
        if self.swapbyteorder:
            self.logger.info('Swapping endian.')

        self.dtype_i8 = np.dtype(np.int64).newbyteorder(self.byteorder)
        self.dtype_u8 = np.dtype(np.uint64).newbyteorder(self.byteorder)
        self.dtype_f8 = np.dtype(np.float64).newbyteorder(self.byteorder)

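        # map STARDB field type codes (see class docstring) to numpy dtypes;
        # only the 8-byte types 5 (DOUBLE), 14 (LONG64), and 15 (ULONG64)
        # are supported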
        self.dtypes = np.zeros(len(self.typenames), dtype = object)
        self.dtypes[[5,14,15]] = [self.dtype_f8,
                                  self.dtype_i8,
                                  self.dtype_u8]

        self._load()

        self.close()
        self.close_logger(timing = 'Data loaded in')

    def open(self, filename):
        """
        Open the file.
        """
        self.filename = os.path.expandvars(os.path.expanduser(filename))
        if not os.path.exists(self.filename):
            fn = self.filename + '.gz'
            if os.path.exists(fn):
                self.filename = fn
            else:
                fn = self.filename + '.bz2'
                if os.path.exists(fn):
                    self.filename = fn
                else:
                    raise IOError("File not found.")
        if self.filename.endswith('.gz'):
            self.compressed = True
            self.compress_mode = 'gz'
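            # gzip stores the uncompressed size modulo 2**32 in the last
            # four bytes of the file (the ISIZE trailer), so it can be read
            # from the underlying file object without decompressing the
            # stream; this is why sizes of files >= 4 GB come out wrong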
            self.file = gzip.open(self.filename,'rb')
            pos = self.file.myfileobj.tell()
            self.file.myfileobj.seek(-4, os.SEEK_END)
            self.filesize = np.ndarray(1, dtype = "<u4", buffer = self.file.myfileobj.read(4))[0]
            self.file.myfileobj.seek(pos, os.SEEK_SET)
        elif self.filename.endswith('.bz2'):
            self.compressed = True
            self.compress_mode = 'bz2'
            self.file = bz2.BZ2File(self.filename,'rb',2**16)            
            pos = self.file.tell()
            self.file.seek(0, os.SEEK_END)
            self.filesize = self.file.tell()
            self.file.seek(pos, os.SEEK_SET)
        elif self.filename.endswith('.xz'):
            self.compressed = True
            self.compress_mode = 'xz'
            self.filesize = xz_file_size(self.filename)
            self.file = LZMAFile(self.filename,'rb')            
        else:
            self.file = open(self.filename,'rb',-1)
            self.stat = os.fstat(self.file.fileno())
            self.filesize = self.stat.st_size
            self.compressed = False

    def _check_signature(self):
        """
        Check file signature and return byte order.
        """
        le_byte_order = "<"
        be_byte_order = ">"
        signature_dtype = np.dtype(le_byte_order + "u8")
        self.signature = np.array(0xADBADBADBABDADBA,
                         dtype = signature_dtype)
        u8 = np.frombuffer(self.file.read(8), dtype = signature_dtype)
        if self.signature == u8[0]:
            return le_byte_order
        if self.signature == u8.view(u8.dtype.newbyteorder(be_byte_order))[0]:
            return be_byte_order
        raise self.SignatureError(self.filename)

    def read_uin(self, dim = ()):
        """
        Read numpy uint64 array.
        """
        value = np.ndarray(dim, 
                           buffer = self.file.read(8*int(prod(dim))),
                           dtype = self.dtype_u8,
                           order = 'F')
        if self.swapbyteorder:
            value = value.byteswap().newbyteorder()
        return value

    def read_int(self, dim = ()):
        """
        Read numpy int64 array.
        """
        value = np.ndarray(dim, 
                           buffer = self.file.read(8*int(prod(dim))),
                           dtype = self.dtype_i8,
                           order = 'F')
        if self.swapbyteorder:
            value = value.byteswap().newbyteorder()
        return value

    def read_dbl(self, dim = ()):
        """
        Read numpy float64 array.
        """
        value = np.ndarray(dim, 
                           buffer = self.file.read(8*int(prod(dim))),
                           dtype = self.dtype_f8,
                           order = 'F')
        if self.swapbyteorder:
            value = value.byteswap().newbyteorder()
        return value

    def read_str(self, dim = ()):
        """
        Read numpy string array.

        Unfortunately, this is somewhat inefficient in numpy as the
        maximum string size is allocated for all array entries.
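
        On-disk layout assumed here: first one uint64 length per string,
        then each string's bytes padded up to a multiple of 8; for
        example, a string of length 3 occupies 3 + 7 - ((3 + 7) mod 8)
        = 8 bytes.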
        """

        n_elements = int(prod(dim))
        strlen = np.ndarray(dim, 
                            buffer = self.file.read(8*n_elements),
                            dtype = self.dtype_u8,
                            order = 'F')
        if self.swapbyteorder:
            strlen = strlen.byteswap().newbyteorder()
        loadlen = strlen + 7 - np.mod(strlen+7,8)
        maxlen = int(strlen.max())
        if maxlen == 0:
            return np.zeros(dim, dtype = '|S1')
        buf = self.file.read(int(loadlen.sum()))
        value = np.ndarray(dim, 
                           dtype = '|S{:d}'.format(maxlen))
        # need check byteswap?
        first = np.ndarray(n_elements, dtype='u8')
        first[0] = 0
        first[1:] = loadlen.reshape(-1).cumsum()[:-1]
        last = first + strlen.reshape(-1)
        flat_value = value.reshape(-1)
        for i,ifirst,ilast in zip(xrange(n_elements),first,last):
            flat_value[i] = buf[ifirst:ilast]
        return value

    def read_stu(self, 
                 dim = (),
                 fieldnames = None,
                 fieldtypes = None):
        """
        Read numpy record array.

        Currently only supports 8 byte types float64, int64, uint64.
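
        Example (hypothetical names/types): fieldnames = ['mass', 'energy']
        with fieldtypes = [5, 5] (both DOUBLE) yields the record dtype
        dtype([('mass', '<f8'), ('energy', '<f8')]) for a little-endian
        file, i.e. 16 bytes are read per element of dim.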
        """
        nfields = len(fieldnames)
        dtypes = np.choose(np.array(fieldtypes,dtype=np.int64),
                           tuple(self.dtypes))
        dtype = np.dtype({'names': fieldnames,
                          'formats': dtypes})
        value = np.ndarray(dim, 
                           buffer = self.file.read(8*int(nfields)*int(prod(dim))),
                           dtype = dtype,
                           order = 'F')
        if self.swapbyteorder:
            value = value.byteswap().newbyteorder()
        return value

    def close(self):
        """Close the file."""
        self.file.close()

    def read_tail(self):
        """
        Read the stored file size and check file integrity.
        """
        self.savedsize = self.read_uin()
        if self.filesize != self.savedsize:
            self.logger.error('file integrity seems broken')
            raise self.IntegrityError()
        self.logger.info('file integrity seems OK')

    def _load(self):
        """
        Load the data from file.       
        """
        self.version = self.read_uin()
        self.logger.info('Data version: {:6d}'.format(int(self.version)))

        self.name = self.read_str()
        self.logger.info('data set name: {:s}'.format(self.name))

        self.ncomment = self.read_uin()
        self.comments = self.read_str(self.ncomment)
        self.logger.info(''.ljust(58,"="))
        for comment in self.comments:
            self.logger.info('COMMENT: {:s}'.format(comment))
        self.logger.info(''.ljust(58,"="))

        self.nstar  = int(self.read_uin())
        self.nfield = int(self.read_uin())
        self.nabu   = int(self.read_uin())

        if self.version < 10100:
            iabutype = self.read_uin()
            if iabutype != 1:
                self.logger.error('currently only supporting element data (type 1)')
                raise self.VersionError()
            self.abundance_type  = 2
            self.abundance_class = 2
            self.abundance_unit  = 7
            self.abundance_total = 0
            self.abundance_norm  = 'Lod03'
            self.abundance_data  = 0
            self.abundance_sum   = 1
        else:
            self.abundance_type  = int(self.read_uin())
            self.abundance_class = int(self.read_uin())
            self.abundance_unit  = int(self.read_uin())
            self.abundance_total = int(self.read_uin())
            self.abundance_norm  = self.read_str()
            self.abundance_data  = int(self.read_uin())
            self.abundance_sum   = int(self.read_uin())

        self.logger.info('data sets:      {:6d}'.format(int(self.nstar)))
        self.logger.info('abundance sets: {:6d}'.format(int(self.nabu)))
        self.logger.info(''.ljust(58,"-"))
        self.logger.info('abundance type:  {:1d} - {:s}'.format(int(self.abundance_type ),self.abundance_type_names[ self.abundance_type]))
        self.logger.info('abundance class: {:1d} - {:s}'.format(int(self.abundance_class),self.abundance_class_names[self.abundance_class]))
        self.logger.info('abundance unit:  {:1d} - {:s}'.format(int(self.abundance_unit ),self.abundance_unit_names[ self.abundance_unit]))
        self.logger.info('abundance total: {:1d} - {:s}'.format(int(self.abundance_total),self.abundance_total_names[self.abundance_total]))
        s = self.abundance_norm
        if s == '': 
            s = '(NONE)'
        self.logger.info('abundance norm:      {:s}'.format(s))
        self.logger.info('abundance data:  {:1d} - {:s}'.format(int(self.abundance_data ),self.abundance_data_names[self.abundance_data]))
        self.logger.info('abundance sum:   {:1d} - {:s}'.format(int(self.abundance_sum  ),self.abundance_sum_names[ self.abundance_sum]))

        self.fieldnames   = self.read_str(self.nfield)
        self.fieldunits   = self.read_str(self.nfield)
        self.fieldtypes   = self.read_uin(self.nfield)
        self.fieldformats = self.read_str(self.nfield)
        self.fieldflags   = self.read_uin(self.nfield)
        
        for i in xrange(self.nfield):
            self.fieldnames[i]   = self.fieldnames[i].strip()
            self.fieldunits[i]   = self.fieldunits[i].strip()
            self.fieldformats[i] = self.fieldformats[i].strip()

        self.logger.info(''.ljust(58,"-"))            
        self.logger.info('{:d} data fields: '.format(self.nfield))

        l1 = max(len(x) for x in self.fieldnames)
        l2 = max(len(x) for x in self.fieldunits)
        l3 = max(len(x) for x in self.typenames)

        format = "{{:{:d}s}} {{:{:d}s}} {{:{:d}s}} {{:s}}".format(l1,l2+2,l3+2)
        for ifield in xrange(self.nfield):
            self.logger.info(format.format(
                    self.fieldnames[ifield],
                    '['+self.fieldunits[ifield]+']',
                    '('+self.typenames[self.fieldtypes[ifield]]+')',
                    '<'+self.flagnames[self.fieldflags[ifield]]+'>'))
            if len(np.argwhere(self.typenames[self.fieldtypes[ifield]] 
                               == np.array(['DOUBLE','LONG64','ULONG64']))) == 0:
                self.logger.error('data type not yet supported')
                self.logger.error('only supporting 8-byte scalar data types')
                raise self.VersionError()

        self.abu_Z = self.read_uin(self.nabu)
        if self.version < 10100:
            self.abu_A = np.zeros(self.nabu, dtype=np.uint64)
            self.abu_E = np.zeros(self.nabu, dtype=np.uint64)
        else:
            self.abu_A = self.read_uin(self.nabu)
            self.abu_E = self.read_uin(self.nabu)

        if self.abundance_type == 0:
            self.ions = np.array([isotope.Ion(Z=int(self.abu_Z[i]),
                                              A=int(self.abu_A[i]),
                                              E=int(self.abu_E[i])) 
                                  for i in xrange(self.nabu)])
        elif self.abundance_type == 1:
            self.ions = np.array([isotope.Ion(Z=int(self.abu_Z[i]),
                                              A=int(self.abu_A[i])) 
                                  for i in xrange(self.nabu)])
        elif self.abundance_type == 2:
            self.ions = np.array([isotope.Ion(Z=int(self.abu_Z[i])) 
                                  for i in xrange(self.nabu)])
        elif self.abundance_type == 3:
            self.ions = np.array([isotope.Ion(A=int(self.abu_A[i])) 
                                  for i in xrange(self.nabu)])
        elif self.abundance_type == 4:
            self.ions = np.array([isotope.Ion(N=int(self.abu_A[i])) 
                                  for i in xrange(self.nabu)])
        else:
            self.logger.error('abundance type not defined.')
            raise self.DataError()
            
        self.field_data = self.read_stu(self.nstar, 
                                        self.fieldnames, 
                                        self.fieldtypes)

        l1 = max(len(x) for x in self.fieldnames) 
        l2 = 0
        l3 = 0
        nvalues = np.zeros(self.nfield, dtype=np.uint64)
        values = np.ndarray((self.nfield, self.nstar), dtype = np.float64)
        ivalues = np.ndarray((self.nfield, self.nstar), dtype=np.uint64)
        re_len = re.compile('^[A-Z]([0-9]+)',flags=re.I)
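        # for each field: collect its sorted unique values and, per star,
        # the index of that star's value among them, e.g. field values
        # [3., 1., 3., 2.] give unique values [1., 2., 3.] and indices
        # [2, 0, 2, 1]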
        for ifield in xrange(self.nfield):
            # values
            v = self.field_data[self.fieldnames[ifield]]
            vv, vx = np.unique(v, return_inverse = True)
            nv = len(vv)
            values[ifield,0:nv] = vv
            nvalues[ifield] = nv
            ivalues[ifield,:] = vx

            # output formatting
            flen = int(re_len.findall(self.fieldformats[ifield])[0])
            l2 = max(l2,flen)
            l3 = max(l3,len('{:d}'.format(nv)))
        nvalues_max = max(nvalues)
        values = values[:,0:nvalues_max]

        # convert to python formats
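        # e.g. an IDL-style format 'F7.3' becomes the Python format spec
        # '7.3F', and 'I5' becomes '5d'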
        self.field_formats = self.fieldformats.copy()
        for i in xrange(self.nfield):
            self.field_formats[i] = self.fieldformats[i][1:]
            if self.fieldformats[i][0] == 'F':
                self.field_formats[i] += 'F'
            elif self.fieldformats[i][0] == 'I':
                self.field_formats[i] += 'd'
            else:
                self.logger.error('Format type not supported.')
                raise AttributeError()
            
        xpar = np.argwhere(self.fieldflags == 0)
        if len(xpar) > 0:
            self.logger.info(''.ljust(58,"-"))            
            self.logger.info('PARAMETER RANGES:')
        for ip in xpar.flat:
            fmax = max(values[ip,0:nvalues[ip]])
            fmin = min(values[ip,0:nvalues[ip]])
            line=(self.fieldnames[ip]+': ',
                  ("{:"+self.field_formats[ip]+"}").format(fmin),
                  ("{:"+self.field_formats[ip]+"}").format(fmax),
                  "{:d}".format(int(nvalues[ip])))
            format="{{:<{:d}s}} {{:>{:d}s}} ... {{:>{:d}s}} ({{:>{:d}s}} values)".format(l1+2,l2,l2,l3)
            self.logger.info(format.format(*line))

        xprop = np.argwhere(self.fieldflags != 0)
        if len(xprop) > 0:
            self.logger.info(''.ljust(58,"-"))            
            self.logger.info('PROPERTY RANGES:')
        for ip in xprop.flat:
            fmax = max(values[ip,0:nvalues[ip]])
            fmin = min(values[ip,0:nvalues[ip]])
            line=(self.fieldnames[ip]+': ',
                  ("{:"+self.field_formats[ip]+"}").format(fmin),
                  ("{:"+self.field_formats[ip]+"}").format(fmax),
                  "{:d}".format(int(nvalues[ip])))
            format="{{:<{:d}s}} {{:>{:d}s}} ... {{:>{:d}s}} ({{:>{:d}s}} values)".format(l1+2,l2,l2,l3)
            self.logger.info(format.format(*line))


        if len(xpar) > 0:
            self.logger.info(''.ljust(58,"-"))            
            self.logger.info('PARAMETER VALUES:')
        for ip in xpar.flat:
            self.logger.info(self.fieldnames[ip]+':')
            s = ''
            f = " {:"+self.field_formats[ip]+"}"
            for id in xrange(nvalues[ip]):
                if len(s) >= 50:
                    self.logger.info(s)
                    s = ''
                s += f.format(values[ip,id])
            self.logger.info(s)

        maxpropvalues = 100
        if len(xprop) > 0:
            self.logger.info(''.ljust(58,"-"))            
            self.logger.info('PROPERTY VALUES:')
        for ip in xprop.flat:
            self.logger.info(self.fieldnames[ip]+':')
            if nvalues[ip] > maxpropvalues:
                self.logger.info('(more than {:d} values)'.format(maxpropvalues))
            else:
                s = ''
                f = " {:"+self.field_formats[ip]+"}"
                for id in xrange(nvalues[ip]):
                    if len(s) >= 50:
                        self.logger.info(s)
                        s = ''
                    s += f.format(values[ip,id])
                self.logger.info(s)
        self.logger.info(''.ljust(58,"-"))            

        self.abu_data = self.read_dbl((self.nabu, self.nstar))
        self.nvalues = nvalues
        self.values  =  values
        self.indices = ivalues 

        if len(np.nonzero(self.abu_data.sum(0)==0)[0]) > 0:
            self.logger.error('found data set(s) with zero total abundance.')
            raise self.DataError()

        self.read_tail()
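

# Illustrative sketch (comments only) of how the loaded arrays relate,
# assuming 'db' is a StarDB instance:
#
#     db.field_data[name][i]            value of field 'name' for star i
#     db.values[j, db.indices[j, i]]    the same value via the unique-value
#                                       table of field j
#     db.abu_data[:, i]                 abundances of star i, one per db.ions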