"""
Module to make fission data for KEPLER

Currently, the full ReacLib machinery does not seem to be needed.
"""

from collections import OrderedDict

import xml.etree.ElementTree as ET

from pathlib import Path

import numpy as np

import isotope
import utils

from logged import Logged
from reaction import Reaction
from ionmap import DecayRate
from utils import CachedAttribute, cachedmethod

from .base import CDat, CDatRec, CDatNuc

class EmptyRecord(Exception):
    """Raised by ``ReacLibRecord`` when the input file has no further record to read."""

# Both ReacLibData and ReacLibRecord should be derived from general base
# classes that provide writing.

# add binary writing
# add reading
# general new format, interface with BDat class (bdat.py)


class ReacLibData(CDat):
    """
    ReacLib rate records (and optionally nuclear data) for KEPLER.

    Attributes set by the constructor:
      rates   -- list of combined ``ReacLibRecord`` objects, or ``None``
      nucdata -- list of ``ReacLibNuc`` records, or ``None``
      comment -- description string
    """

    def __init__(
            self,
            reac = True,
            mode = None,
            nuc = None,
            nuc_mode = None,
            ):
        """
        Parameters
        ----------
        reac : bool, str, Path, tuple or list
            ``True`` loads the default rate files; a file name or a
            sequence of file names loads those instead; any other value
            skips rate loading and sets ``rates`` to ``None``.
        mode : str, optional
            rate file layout passed on to ``load_rate_data``.
        nuc : bool, str or Path, optional
            ``True`` loads the default webnucleo file; a file name loads
            that file; any other value sets ``nucdata`` to ``None``.
        nuc_mode : str, optional
            format passed on to ``WebNucleo``.
        """
        if reac == True or isinstance(reac, (str, tuple, list, Path)):
            kwargs = {}
            if not isinstance(reac, bool):
                kwargs['filename'] = reac
            if mode is not None:
                kwargs['mode'] = mode
            self.load_rate_data(**kwargs)
            self.combine_records()
        else:
            self.rates = None

        if nuc == True or isinstance(nuc, (str, Path)):
            kwargs = {}
            if isinstance(nuc, (str, Path)):
                kwargs['filename'] = nuc
            if nuc_mode is not None:
                kwargs['mode'] = nuc_mode
            # Store the ReacLibNuc records (as ``load_nuclear_data`` does),
            # not ``WebNucleo.nucdata``, which is only the list of ion keys.
            self.nucdata = list(WebNucleo(**kwargs).data.values())
        else:
            self.nucdata = None

        self.comment = 'ReacLib rates'

    def load_rate_data(
            self,
            filename = (
                '/home/alex/kepler/fission/netsu_nfis_Roberts2010rates',
                '/home/alex/kepler/fission/netsu_sfis_Roberts2010rates',
                # '/home/alex/kepler/fission/netsu_panov_symmetric_0neut',
                # '/home/alex/kepler/fission/netsu_panov_symmetric_2neut',
                # '/home/alex/kepler/fission/netsu_panov_symmetric_4neut',
            ),
            mode = 'reaclib2',
            ):
        """
        Read ReacLib rate records from one or more files into ``self.rates``.

        Parameters
        ----------
        filename : str, Path, or iterable of those
            file(s) to read; ``~`` is expanded.
        mode : {'reaclib', 'reaclib2'}
            'reaclib'  -- the file starts with one chapter number followed
                          by two lines that are skipped; every record in
                          the file belongs to that chapter.
            'reaclib2' -- every record starts with its own chapter line.
        """
        self.setup_logger(silent = False)
        rates = []
        assert mode in ('reaclib', 'reaclib2')
        for fn in utils.iterable(filename):
            fn = Path(fn).expanduser()
            with fn.open('rt') as f:
                self.logger_info(f'loading {fn} ...')
                if mode == 'reaclib':
                    chapter = int(f.readline())
                    f.readline()
                    f.readline()
                else:
                    chapter = None
                while True:
                    try:
                        rate = ReacLibRecord(f, chapter)
                    except EmptyRecord:
                        # no further record in this file
                        break
                    # here, for now, all are decays; ReacLibRecord sets
                    # ``decay`` from is_decay(), so this default only applies
                    # if that is ever left undetermined
                    if rate.decay is None:
                        rate.decay = True
                    rates.append(rate)
                self.logger_info('... done.')
        self.rates = rates
        self.close_logger(f'loaded {len(rates)} reactions in ')

    def load_nuclear_data(self,
                          filename = '~/kepler/fission/webnucleo_nuc_v2.0.xml',
                          mode = 'webnucleo',
                          ):
        """
        Load nuclear data for ReacLib from a webnucleo XML file.

        Sets ``self.nucdata`` to a list of ``ReacLibNuc`` records, one per
        ``<nuclide>`` element, in file order.  Each partition-function
        table must match the fixed T9 grid below.
        """
        self.setup_logger(silent = False)
        assert mode == 'webnucleo'

        self.logger_info(f'loading {filename} ...')
        tree = ET.parse(Path(filename).expanduser())
        t9grid = np.array(
            [
                0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
                1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
            ])

        nucdata = [ReacLibNuc(n, t9grid) for n in tree.findall('nuclide')]
        ndata = len(nucdata)

        nt9 = len(t9grid)
        self.logger_info(f'{nt9} temperature points for {ndata} nuclei.')

        self.nucdata = nucdata
        self.close_logger(f'loaded {ndata} nuclei in ')

    def combine_records(self):
        """
        Merge records that describe the same reaction.

        Records that share reactants, products, formula, decay flag, and
        validity limits are merged into the first of them: the rate
        coefficients (``fcoef`` without the two limits) and ``icoef`` of
        the others are appended to it.  Because the record type is not
        part of the key, this combines, e.g., resonant and non-resonant
        contributions of the same reaction.  Order of first occurrence is
        preserved.
        """
        self.logger_info('Combining rates ...')
        n_in = len(self.rates)
        groups = {}
        for r in self.rates:
            # TODO - may need mapping of compatible formulae
            label = (
                tuple(r.reaction.inlist),
                tuple(r.reaction.outlist),
                r.formula,
                r.decay,
                tuple(r.limits),
            )
            groups.setdefault(label, []).append(r)
        crates = []
        for first, *others in groups.values():
            # TODO - some more checks; merging is only valid for some formulae
            for other in others:
                first.fcoef.extend(other.fcoef[2:])
                first.icoef.extend(other.icoef)
            crates.append(first)
        self.rates = crates
        n_out = len(self.rates)
        self.logger_info('Combining {:d} rates to {:d}'.format(n_in, n_out))

    @CachedAttribute
    def decaydata(self):
        """List of ``DecayRate`` objects for all rates flagged as decays."""
        decays = [r.get_decay() for r in self.rates if r.decay]
        self.logger_info(f'Found {len(decays)} decays.')
        return decays

class ReacLibRecord(CDatRec):
    """
    A single ReacLib rate record read from an open text file.

    Rate coefficients are stored in ``fcoef``: the first two entries are
    the T9 validity limits (see ``limits``), followed by one or more sets
    of seven REACLIB fit coefficients a0..a6.
    """
    # TODO - derive from a general ReacLib base class

    # chapter -> (number of reactants, number of products)
    chapter_info = {
         1: (1, 1),
         2: (1, 2),
         3: (1, 3),
         4: (2, 1),
         5: (2, 2),
         6: (2, 3),
         7: (2, 4),
         8: (3, 1),
         9: (3, 2),
        10: (4, 2),
        11: (1, 4),
        }
    # number of 13-character coefficient fields on each coefficient line
    n_coef = (4, 3)
    type_names = {
        'n': 'non-resonant',
        'r': 'resonant',
        'w': 'weak',
        's': 'spontaneous',
        ' ': 'non-resonant',
        }

    def __init__(self, f, chapter = None, check = True):
        """
        Read one ReacLib data record from the open text file *f*.

        Based on description in

        https://groups.nscl.msu.edu/jina/reaclib/db/help.php?topic=reaclib_format&intCurrentNum=0

        Parameters
        ----------
        f : text file
            positioned at the start of a record.
        chapter : int, optional
            ReacLib chapter of the record.  If ``None``, the record must
            start with its own chapter line.
        check : bool
            assert conservation of mass number and of charge (weak rates
            may change the total charge by one).

        Raises
        ------
        EmptyRecord
            if the file is exhausted before a record starts.
        """
        if chapter is None:
            line = f.readline()
            if len(line) == 0:
                raise EmptyRecord()
            chapter = int(line.strip())
            firstline = False
        else:
            firstline = True
        # TODO - more work is needed
        # self.formula = 14 # chapter - binary rates from bdat
        self.formula = 30 # fission/spallation rates
        self.id = 0 # no dependent rates for now
        # we may want to use different formula if only 1st coefficient
        # need to collect reverse rate and formula/extra data based on chapter
        # need to add validity T range
        self.t9min = 0.

        n_in, n_out = self.chapter_info[chapter]
        n = n_in + n_out
        line = f.readline()
        if len(line) == 0 and firstline:
            raise EmptyRecord()
        # nuclide names are 5-character fields starting at column 5
        nuc = [line[5 + 5 * i:5 + 5 * (i + 1)] for i in range(n)]
        nuc_in = isotope.ufunc_ion(nuc[:n_in])

        # deal with panov modification to reaclib: a product field
        # '<count>#<name>' stands for <count> copies of <name>
        nuc_out = []
        for name in nuc[n_in:]:
            if '#' in name:
                count, target = name.split('#')
                nuc_out.extend([target] * int(count))
            else:
                nuc_out.append(name)

        nuc_out = isotope.ufunc_ion(nuc_out)
        self.label = line[43:47]
        self.type = line[47]

        # we added type ' ' not in REACLIB documentation
        assert self.type in ('n', 'r', 'w', 's', ' '), 'ERROR type: '+ line
        reverse = line[48]
        assert reverse in ('v', ' '), 'ERROR reverse: ' + line
        # NOTE: original remark "NO!!! This means now something else";
        # evalf/is_decay drop the last two coefficients of reverse records
        self.reverse = reverse == 'v'
        self.loss = False
        self.Q = float(line[52:64])
        fcoef = []
        for nci in self.n_coef:
            line = f.readline()
            for i in range(nci):
                fcoef.append(line[i * 13:(i + 1) * 13])
        fcoef = [float(fc) for fc in fcoef]
        self.reaction = Reaction(nuc_in, nuc_out, check_Z=False, check_l=False)
        # placeholder limits, needed by is_decay below
        self.fcoef = [0., 0.] + fcoef
        # for bdat we need to add T limits according to rate number
        self.decay = self.is_decay()
        if self.decay:
            t9min = 0.
            t9max = 1.e99
        else:
            t9min = 0.01
            t9max = 10.
        self.fcoef[0:2] = [t9min, t9max]
        self.icoef = list()
        self.flags = 0

        # some check - this does increase time by 35%
        if check:
            assert np.sum(isotope.ufunc_A(nuc_in)) == np.sum(isotope.ufunc_A(nuc_out))
            sum_Z_in = np.sum(isotope.ufunc_Z(nuc_in))
            sum_Z_out = np.sum(isotope.ufunc_Z(nuc_out))
            if self.type == 'w':
                assert sum_Z_in in (sum_Z_out - 1, sum_Z_out + 1)
            else:
                assert sum_Z_in == sum_Z_out

    @staticmethod
    def _t9_vector(t9):
        """
        Return ``[1, T9**-1, T9**(-1/3), T9**(1/3), T9, T9**(5/3), ln T9]``.

        *t9* may be a scalar or an array; for an array the result has
        shape ``(7,) + t9.shape``.
        """
        t9 = np.asarray(t9, dtype=float)
        t913 = t9 ** (1 / 3)
        t923 = t913 ** 2
        t9m1 = 1 / t9
        return np.array([
            np.ones_like(t9),
            t9m1,
            t923 * t9m1,
            t913,
            t9,
            t923 * t9,
            np.log(t9),
            ])

    def ctrho(self, t=0., d=1.):
        """
        Return the 7-element REACLIB temperature vector at temperature *t*.

        ``t9 = t * 1e-9``; the vector is
        ``[1, 1/t9, t9**(-1/3), t9**(1/3), t9, t9**(5/3), ln t9]`` so that
        ``exp(dot(a0..a6, vector))`` evaluates one REACLIB fit.  *d* is not
        used here; the density factor is applied in ``evalf``.
        """
        return self._t9_vector(t * 1.e-9)

    def evalf(self, t, d = 1.):
        """
        Return the rate at temperature *t* and density *d*.

        Sums ``exp(dot(a0..a6, ctrho(t)))`` over all complete
        7-coefficient sets in ``fcoef[2:]`` (for ``reverse`` records the
        last two coefficients are dropped first) and multiplies by
        ``d ** (number of reactants - 1)``.
        """
        fcoef = np.asarray(self.fcoef[2:])
        nfrate = len(fcoef)
        if self.reverse:
            nfrate = nfrate - 2
        t9v7 = self.ctrho(t, d)
        frate = 0.
        for j0 in range(0, 7 * (nfrate // 7), 7):
            frate = frate + np.exp(np.dot(fcoef[j0:j0 + 7], t9v7))
        return frate * d ** (len(self.reaction.nuc_in) - 1)

    def is_decay(self):
        """
        Return whether this record is to be treated as a decay.

        Only single-reactant records can be decays.  Looking at the a1 and
        a2 coefficients of each 7-coefficient set (for ``reverse`` records
        the last two coefficients are ignored):
          * if all a1 and all a2 are zero, the record is a decay;
          * otherwise a set with a1 > 0, or with a1 == 0 and a2 > 0, makes
            it not a decay, and at least one set with a1 == a2 == 0 is
            required for it to be a decay.
        """
        fcoef = np.asarray(self.fcoef[2:])
        nfrate = len(fcoef)
        if self.reverse:
            nfrate = nfrate - 2
        fcoef = fcoef[:nfrate]
        if len(self.reaction.nuc_in) != 1:
            return False
        decay = np.all(fcoef[1::7] == 0) and np.all(fcoef[2::7] == 0)
        if decay:
            return True
        nrate = nfrate // 7
        decay = True
        nonzero = False
        for i in range(nrate):
            j = i * 7
            if fcoef[j+1] > 0:
                decay = False
                break
            elif fcoef[j+1] == 0 and fcoef[j+2] > 0:
                decay = False
                break
            if np.all(fcoef[j+1:j+3] == 0):
                nonzero = True
        if decay and not nonzero:
            return False
        return decay

    def __str__(self):
        s = str(self.reaction)
        s1 = self.type_names[self.type]
        if self.reverse:
            s1 += ', reverse'
        s += ' (' + s1 + ')'
        return s

    __repr__ = __str__

    @property
    def limits(self):
        """The ``[t9min, t9max]`` validity limits stored in ``fcoef[:2]``."""
        return self.fcoef[:2]

    def eval(self, t9):
        """
        Return the rate at *t9* (temperature in GK), without density factor.

        Each complete 7-coefficient set a0..a6 in ``fcoef[2:]`` contributes
        ``exp(a0 + a1/T9 + a2 T9^(-1/3) + a3 T9^(1/3) + a4 T9 + a5 T9^(5/3)
        + a6 ln T9)``; sets appended by ``combine_records`` are summed.
        *t9* may be a scalar or an array.
        """
        fcoef = np.asarray(self.fcoef[2:], dtype=float)
        t9v7 = self._t9_vector(t9)
        rate = 0.
        for j0 in range(0, 7 * (len(fcoef) // 7), 7):
            rate = rate + np.exp(np.dot(fcoef[j0:j0 + 7], t9v7))
        return rate

    def get_decay(self):
        """Return a ``DecayRate`` with rate ``exp(a0)`` of the first coefficient set."""
        assert self.is_decay()
        return DecayRate(self.reaction.inlist, self.reaction.outlist, np.exp(self.fcoef[2]))


class WebNucleo(Logged):
    """
    Nuclear data read from a webnucleo XML file, keyed by ion.

    ``data`` maps each ion to its ``ReacLibNuc`` record, in file order.
    """

    def __init__(
            self,
            filename = '~/kepler/fission/webnucleo_nuc_v2.0.xml',
            mode = 'webnucleo',
            ):
        """
        Load nuclear data for ReacLib.

        Parameters
        ----------
        filename : str or Path
            XML file to parse; ``~`` is expanded.
        mode : str
            only 'webnucleo' is supported.
        """
        self.setup_logger(silent = False)
        assert mode == 'webnucleo'

        self.logger_info(f'loading {filename} ...')
        root = ET.parse(Path(filename).expanduser())
        grid = np.array(
            [
                0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
                1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
            ])

        records = (ReacLibNuc(element, grid) for element in root.findall('nuclide'))
        self.data = OrderedDict((record.ion, record) for record in records)
        count = len(self.data)

        self.logger_info(f'{len(grid)} temperature points for {count} nuclei.')
        self.close_logger(f'loaded {count} nuclei in ')

    @property
    def nucdata(self):
        """List of the ions (keys of ``data``), in file order."""
        return list(self.data)

    def __getitem__(self, key):
        """Return the ``ReacLibNuc`` record stored for ion *key*."""
        return self.data[key]

class ReacLibNuc(CDatNuc):
    """
    Nuclear data record for one nuclide from a webnucleo XML file.

    Holds A, Z, N, spin ``S``, mass excess ``ME``, source ``label``, the
    tabulated log10 partition function (as ``fcoef``) and ``Q``.
    """
    reaclib_type = 5

    def __init__(self, nuc, t9grid):
        """
        Parameters
        ----------
        nuc : xml.etree.ElementTree.Element
            ``<nuclide>`` element with children ``a``, ``z``, ``source``,
            ``mass_excess``, ``spin`` and ``partf_table`` (``point``
            elements with ``t9`` and ``log10_partf`` children).
        t9grid : array_like
            expected temperature grid; the table's ``t9`` values must
            match it in length and value.
        """
        A = int(nuc.find('a').text)
        Z = int(nuc.find('z').text)
        ion = isotope.ion(A=A, Z=Z, check=False)
        assert ion != isotope.VOID, f'INVALID {Z=} {A=}'
        label = nuc.find('source').text
        ME = float(nuc.find('mass_excess').text)
        S = float(nuc.find('spin').text)
        self.ion = ion
        self.A = A
        self.Z = Z
        self.N = A - Z
        self.formula = self.reaclib_type
        self.S = S
        self.ME = ME
        self.E = 0
        self.label = label

        points = nuc.find('partf_table').findall('point')
        t9 = np.array([float(p.find('t9').text) for p in points])
        log10_partf = np.array([float(p.find('log10_partf').text) for p in points])
        # check the length first so a mismatch gives an assertion, not a
        # numpy broadcasting error
        assert len(t9) == len(t9grid) and np.allclose(t9, t9grid)
        assert np.all(log10_partf >= 0)
        self.fcoef = log10_partf
        # NOTE(review): the constants presumably are the mass excesses of
        # 1H and the free neutron in MeV, making Q the nuclear binding
        # energy (assuming ``mass_excess`` is in MeV) -- confirm.
        self.Q = round(
            self.Z * 7.28898454697355
            + self.N * 8.071317791830353
            - self.ME, 3)
        self.icoef = list()