"""
This module encapsulates the spaCy NER training, evaluation, and model save/load code used in Commvault NLP solutions.
"""
import random
import numpy as np
import spacy
import random
import numpy.random
from spacy import util
from pathlib import Path
from spacy.util import minibatch, decaying, compounding
from spacy.gold import GoldParse, GoldCorpus
from spacy.scorer import Scorer
import pkg_resources
from distutils.version import LooseVersion
from CvCAGenericLogger import get_logger_handler

# Rotation settings handed to the shared Commvault logger factory:
# keep 5 backup files, each up to 5 MiB.
logger_options = dict(
    ROTATING_BACKUP_COUNT=5,
    ROTATING_MAX_BYTES=5 * 1024 * 1024,
)

# Module-wide logger used by every training/evaluation helper below.
LOGGER_Generic = get_logger_handler(
    "CvCIEntityExtractionCsApi.dll",
    "ContentAnalyzer",
    logger_options,
)


def get_batches(train_data, model_type):
    """Return a minibatch iterator over *train_data* with a growing batch size.

    The batch size starts at 1 and compounds by a factor of 1.001 per batch,
    capped at a per-component ceiling ("tagger" 32, "parser" 16, "ner" 16,
    "textcat" 64). The ceiling is halved once for fewer than 1000 examples and
    halved again for fewer than 500. An unknown *model_type* raises KeyError.
    """
    ceiling = {"tagger": 32, "parser": 16, "ner": 16, "textcat": 64}[model_type]
    n_examples = len(train_data)
    # Smaller datasets get smaller batches.
    for threshold in (1000, 500):
        if n_examples < threshold:
            ceiling /= 2
    growing_size = compounding(1, ceiling, 1.001)
    return minibatch(train_data, size=growing_size)


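"""
Arguments for TrainNER
----------------------
nlp : spaCy Language model/pipeline to train.
TRAIN_DATA : List of tuples of the form (raw_text, annotations), where
            annotations is a dict whose "entities" key holds a list of entity
            tuples. Each entity tuple has the form
            (start_offset, end_offset, entity_label_name).
"""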


def TrainNER(nlp, TRAIN_DATA):
    """Train the NER component of *nlp* on TRAIN_DATA and return *nlp*.

    An ``ner`` pipe is appended when the pipeline has none, and every label
    found in TRAIN_DATA is registered on it. Training runs for at most 15
    epochs. It stops early when an epoch reports zero loss or when the
    plateau heuristic in ``_ner_loss_plateaued`` fires.

    For spaCy > 1.9.9 the ``nlp.update`` minibatch API is used with the other
    pipes disabled. Older versions use the legacy per-document ``GoldParse``
    loop.

    Arguments
    ---------
    nlp : spaCy Language pipeline to train (modified in place).
    TRAIN_DATA : iterable of (raw_text, annotations) pairs, where
        annotations["entities"] is a list of
        (start_offset, end_offset, entity_label_name) tuples.
        The caller's container is not reordered; a shuffled copy is used.

    Returns
    -------
    The same ``nlp`` object after training.
    """
    if "ner" not in nlp.pipe_names:
        ner = nlp.create_pipe("ner")
        nlp.add_pipe(ner, last=True)

    # Private copy: shuffling must not reorder the caller's data, and this
    # also lets callers pass tuples or other iterables.
    train_data = list(TRAIN_DATA)

    ner = nlp.entity
    # `or []` tolerates examples whose "entities" key is missing or None.
    labels = {
        ent[2]
        for _, annotations in train_data
        for ent in (annotations.get("entities") or [])
    }
    for label in labels:
        ner.add_label(label)

    n_iter = 15
    if LooseVersion(pkg_resources.get_distribution("spacy").version) > LooseVersion("1.9.9"):
        _train_ner_v2(nlp, train_data, n_iter)
    else:
        _train_ner_v1(nlp, train_data, n_iter)
    LOGGER_Generic.info("finished training")
    return nlp


def _ner_loss_plateaued(losshistory, window=10):
    """Return True when *losshistory* (floats, one per epoch) meets the early-stop rule.

    The rule, kept from the original code, has three conditions. More than
    *window* losses must be recorded. The loss *window* epochs ago must be
    exactly 1.0. The net change over the last *window* losses must be zero;
    the sum of consecutive differences equals last minus first.
    """
    # NOTE(review): the `== 1.0` condition looks arbitrary; confirm the intent.
    return (
        len(losshistory) > window
        and losshistory[-window] == 1.0
        and np.sum(np.diff(losshistory[-window:])) == 0
    )


def _train_ner_v2(nlp, train_data, n_iter):
    """Train only the ``ner`` pipe with the spaCy 2.x ``nlp.update`` API."""
    losshistory = []
    # get names of other pipes to disable them during training
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
    with nlp.disable_pipes(*other_pipes):  # only train NER
        optimizer = nlp.begin_training()
        # Create the dropout schedule once so it keeps decaying across epochs.
        # Re-creating it every epoch restarted it at 0.6 each time.
        dropout = decaying(0.6, 0.2, 1e-4)
        for itn in range(n_iter):
            random.shuffle(train_data)
            losses = {}
            for batch in get_batches(train_data, "ner"):
                texts, annotations = zip(*batch)
                nlp.update(
                    texts,  # batch of texts
                    annotations,  # batch of annotations
                    drop=next(dropout),  # dropout - make it harder to memorise data
                    sgd=optimizer,  # callable to update weights
                    losses=losses,
                )
            print(
                "Iteration={}, loss={}".format(itn, losses)
            )  # TODO logger at high debug level
            if itn % 10 == 0:
                LOGGER_Generic.info("iteration:{}, losses{}".format(itn, losses))
            # `losses` is a dict keyed by pipe name. Compare the NER loss
            # itself, because comparing the dict to 0 was always False.
            loss = losses.get("ner", 0.0)
            if loss == 0:
                LOGGER_Generic.info("breaking at iteration:{} as no loss".format(itn))
                break
            losshistory.append(loss)
            if _ner_loss_plateaued(losshistory):
                LOGGER_Generic.info(
                    "did not learn anything new for the last 10 iterations stopping training, current itn:{}, loss:{}".format(
                        itn, loss
                    )
                )
                break


def _train_ner_v1(nlp, train_data, n_iter):
    """Train the entity recognizer with the legacy spaCy 1.x per-document API."""
    losshistory = []
    random.seed(0)
    # You may need to change the learning rate. It's generally difficult to
    # guess what rate you should set, especially when you have limited data.
    nlp.entity.model.learn_rate = 0.001
    for itn in range(n_iter):
        random.shuffle(train_data)
        loss = 0.0
        for raw_text, entity_offsets in train_data:
            # Decode only byte strings. On Python 3 `str` has no `decode`, so
            # the old `isinstance(raw_text, str)` check crashed on every text.
            if isinstance(raw_text, bytes):
                raw_text = raw_text.decode("utf-8")
            doc = nlp.make_doc(raw_text)
            gold = GoldParse(doc, entities=entity_offsets["entities"])
            nlp.tagger(doc)
            loss += nlp.entity.update(doc, gold)
        # Log once per epoch; previously this ran after every single document.
        if itn % 10 == 0:
            LOGGER_Generic.info("iteration:{}, loss:{:.2f}".format(itn, loss))
        if loss == 0:
            LOGGER_Generic.info("breaking at iteration:{} as no loss".format(itn))
            break
        losshistory.append(loss)
        if _ner_loss_plateaued(losshistory):
            LOGGER_Generic.info(
                "did not learn anything new for the last 10 iterations stopping training, current itn:{}, loss:{}".format(
                    itn, loss
                )
            )
            break

    # This step averages the model's weights. This may or may not be good for
    # your situation --- it's empirical.
    nlp.end_training()


def TestNER(nlp, TEST_DATA):
    """Score the entity predictions of *nlp* against gold annotations.

    Arguments
    ---------
    nlp : trained spaCy pipeline used to annotate each text.
    TEST_DATA : iterable of (text, annotations) pairs, where
        annotations["entities"] is a list of
        (start_offset, end_offset, entity_label_name) tuples.

    Returns
    -------
    (precision, recall, f-measure) for entities, as reported by spaCy's
    ``Scorer``. The same values are also logged rounded to two decimals.
    """
    predscorer = Scorer()
    for text, entityoffsets in TEST_DATA:
        # Decode only byte strings. On Python 3 `str` has no `decode`, so the
        # old `isinstance(text, str)` check crashed on every text example.
        if isinstance(text, bytes):
            text = text.decode("utf-8")
        doc = nlp(text)
        # The gold standard is built on a fresh, unannotated tokenization.
        gold = GoldParse(nlp.make_doc(text), entities=entityoffsets["entities"])
        predscorer.score(doc, gold)

    precision = "{:.2f}".format(predscorer.ents_p)
    recall = "{:.2f}".format(predscorer.ents_r)
    fmeasure = "{:.2f}".format(predscorer.ents_f)
    LOGGER_Generic.info("precision={}, recall={}, fmeasure={}".format(precision, recall, fmeasure))
    return predscorer.ents_p, predscorer.ents_r, predscorer.ents_f


def evaluate(nlp, data_path, gpu_id=-1, gold_preproc=False, displacy_path=None, displacy_limit=25):
    """
    Evaluate the entity recognizer of *nlp* on a spaCy-format gold corpus.

    Arguments
    ---------
    nlp : spaCy pipeline to evaluate.
    data_path : location of the evaluation data. It is passed as both the
        train and the dev path of the ``GoldCorpus``; only the dev docs are
        scored.
    gpu_id : when >= 0, ``util.use_gpu(gpu_id)`` is called first; negative
        values skip that call.
    gold_preproc : forwarded to ``GoldCorpus.dev_docs``.
    displacy_path, displacy_limit : accepted for signature compatibility but
        currently ignored; no displaCy rendering is performed.

    Returns
    -------
    (precision, recall, f-measure) for entities from ``nlp.evaluate``.

    Raises
    ------
    IOError if *data_path* does not exist.
    """
    # Fix both RNGs so repeated evaluations are reproducible.
    seed = 0
    random.seed(seed)
    numpy.random.seed(seed)
    if gpu_id >= 0:
        util.use_gpu(gpu_id)
    util.set_env_log(False)
    data_path = util.ensure_path(data_path)
    # Fail early with a clear message instead of an obscure error from deep
    # inside GoldCorpus.
    if not data_path.exists():
        raise IOError("Evaluation data not found: {}".format(data_path))
    corpus = GoldCorpus(data_path, data_path)
    dev_docs = list(corpus.dev_docs(nlp, gold_preproc=gold_preproc))
    scorer = nlp.evaluate(dev_docs, verbose=False)
    return scorer.ents_p, scorer.ents_r, scorer.ents_f


def SaveModel(nlp, output_dir):
    """Serialize *nlp* to *output_dir* with ``nlp.to_disk``.

    Missing directories, including intermediate parents, are created first.
    Nothing happens when *output_dir* is None.
    """
    if output_dir is None:
        return
    output_dir = Path(output_dir)
    if not output_dir.exists():
        # parents=True so nested output locations can be created in one call.
        output_dir.mkdir(parents=True)
    nlp.to_disk(output_dir)
    # A successful save is informational; logging it as an error polluted the
    # error log.
    LOGGER_Generic.info("Saved model to {}".format(output_dir))


def LoadModel(output_dir):
    """Load a spaCy pipeline from *output_dir* via ``spacy.load``.

    Returns the loaded pipeline. If ``spacy.load`` raises IOError, an error is
    logged and None is returned instead.
    """
    LOGGER_Generic.info("Loading from {}".format(output_dir))
    try:
        return spacy.load(output_dir)
    except IOError:
        LOGGER_Generic.error("No trained model at {}".format(output_dir))
        return None
