import os
import shelve
import timeit
import json
import spacy
import shutil
from spacy.cli import train
from spacy.matcher import Matcher
from spacy.tokens import Doc, Span, Token
from CvCAGenericLogger import get_logger_handler
from CvEEConfigHelper import *
from CvEEUtil import *
from CvPyUtils import printelapsedtime, ensureUtf
from CvSpacyModelTrainer import *
from ilock import ILock, ILockException
import traceback


class CustomComponent(object):
    """spaCy pipeline component that trains, or applies, a custom entity model.

    The component is configured from ``name``, which must have exactly four
    ``_``-separated parts, e.g. ``cc_cvterms_cvterm_test``:

    * part 2 -> ``componentname``: model directory
      ``CustomTrainedModels/<componentname>`` and bootstrap keyword file
      ``<componentname>.list``;
    * part 3 -> ``label_``: the entity label collected / produced;
    * part 4 -> ``mode``: ``"train"`` harvests keyword-matched examples and
      trains a model; any other value runs the trained model.

    Both modes are gated by the ``bEETrainmodel_<componentname>`` setting read
    with ``loadRegValue``. Train mode only acts while it is 1. The other mode
    only acts while it is 0, and training resets it to 0 once a model reaches
    ``minaccuracy``.
    """

    def __init__(
        self, primarynlp, minaccuracy=50, testsplit=30, name="", params=None
    ):
        """Parse ``name``, register ``Doc._.customterm`` and load model + matcher.

        :param primarynlp: main pipeline. Its vocab is shared with the custom
            model, and its path is used as the vectors source when training.
        :param minaccuracy: minimum f-measure a newly trained model must reach
            to be saved (train mode only).
        :param testsplit: stored in train mode; not otherwise used in this class.
        :param name: ``<prefix>_<componentname>_<label>_<mode>``.
        :param params: dict expected to hold ``"logger"`` and optionally
            ``"perfCounter"``. A fresh dict is used when omitted.

        Any error is logged rather than raised, which leaves the component
        partially initialised.
        """
        # Assigned first so the error handler below can always reach the logger.
        self.params = {} if params is None else params
        try:
            _, self.componentname, self.label_, self.mode = name.split("_")
            self.primarynlp = primarynlp
            # Register once. Re-registering an existing extension can raise in
            # spaCy 2.x, and several instances of this class may be created. The
            # shared default dict is never mutated: _annotate() assigns a
            # per-doc copy instead.
            if not Doc.has_extension("customterm"):
                Doc.set_extension("customterm", default={})
            self.modelname = os.path.join("CustomTrainedModels", self.componentname)
            # cannot have 2 components with same name in the pipeline, so append mode
            self.__name__ = self.componentname + self.mode
            self.perfCounter = util_get_dict_attr(self.params, "perfCounter")
            self.trainingstart = "bEETrainmodel_" + self.componentname
            if self.mode == "train":
                self.minaccuracy = minaccuracy
                self.testsplit = testsplit
                self.n_iter = 10
                self.examplebuffer = []
                self.bufferflush = 50  # TODO: make configurable
                self.trainthreshold = 4000  # TODO: make configurable
            self.loadmodel()
            self.loadmatcher()
        except Exception as e:
            tb = traceback.format_exc()
            self.params["logger"].error("Error in CvTerms component {}. {}".format(e, tb))

    def loadmodel(self):
        """Load the custom model into ``self.nlp``, or set it to ``None``.

        The model is looked up in this order:

        1. ``CustomTrainedModels/<componentname>`` in the current directory;
        2. ``<site-packages>/<componentname>``;
        3. train mode only: the base model path read from
           ``SPACY_MODELS_REG_KEY``, falling back to ``en_core_web_lg`` under
           site-packages.

        A directory only counts if it contains ``meta.json``. Every model is
        loaded with ``primarynlp``'s vocab. In train mode nothing is loaded
        unless the training setting is 1.
        """
        if self.mode == "train" and loadRegValue(self.trainingstart, 0) != 1:
            self.nlp = None
            return

        modeldir = None
        if os.path.exists(os.path.join(self.modelname, "meta.json")):
            modeldir = self.modelname
        else:
            sitepackages = getPythonSitePackagesDir()
            installed = os.path.join(sitepackages, self.componentname)
            if os.path.exists(os.path.join(installed, "meta.json")):
                modeldir = installed
            elif self.mode == "train":
                # No custom model yet: start training from the base model.
                modeldir = loadRegValue(
                    SPACY_MODELS_REG_KEY,
                    os.path.join(sitepackages, SPACY_MODELS_MAP["en_core_web_lg"]),
                    type=str,
                )

        if modeldir is None:
            self.nlp = None  # no custom trained model found
        else:
            # for now language is english, enhance when more languages are needed
            self.nlp = spacy.load(modeldir, vocab=self.primarynlp.vocab)

    def loadmatcher(self):
        """Build ``self.matcher`` from the keyword file ``<componentname>.list``.

        Each non-blank line holds one keyword or multi-word phrase. Phrases are
        tokenized with the model's own tokenizer, so a multi-word entry becomes
        a multi-token pattern that matches the same way ``_matchentities``
        tokenizes text. Matching is case-insensitive (``LOWER``). This is a
        no-op when no model is loaded.
        """
        if self.nlp is None:
            return
        self.matcher = Matcher(self.nlp.vocab)
        self.bootstrapfile = self.componentname + ".list"
        st = timeit.default_timer()
        count = 0
        exceptioncount = 0
        # The whole file is read at once; set() drops exact duplicate lines.
        with open(self.bootstrapfile) as bsf:
            lines = set(bsf.read().splitlines())
        added = set()
        for line in lines:
            phrase = line.strip()
            if not phrase:
                continue
            try:
                lowered = tuple(token.lower_ for token in self.nlp.make_doc(phrase))
                if lowered in added:  # e.g. "Foo" and "foo" give the same pattern
                    continue
                added.add(lowered)
                self.matcher.add(
                    self.label_, None, [{"LOWER": text} for text in lowered]
                )
                count += 1
            except ValueError:
                exceptioncount += 1
        printelapsedtime(
            st,
            "Time to add phrases for gazzetteer {}, length {}, exceptions {}".format(
                self.__name__, count, exceptioncount
            ),
        )

    def __call__(self, doc):
        """Process ``doc`` and return it.

        Train mode: while the training setting is 1, take ``NERTrainingLock``
        (1 s timeout) and harvest training examples from ``doc``, training once
        enough are stored. If the lock is not obtained, the doc passes through.
        ``doc`` itself is never modified in train mode. To initiate training
        again, set the setting back to 1.

        Other modes: while the training setting is 0, run the custom model and
        store its entities and ``(matcher, nlp)`` in ``doc._.customterm``.
        """
        if self.mode == "train":
            if loadRegValue(self.trainingstart, 0) == 1:
                try:
                    with ILock("NERTrainingLock", timeout=1):
                        # Re-check under the lock: training may have finished
                        # (and reset the setting) while we were waiting.
                        if loadRegValue(self.trainingstart, 0) == 1:
                            self._collectandtrain(doc)
                except ILockException:
                    self.params["logger"].info("Training Lock not obtained, passing through...")
            return doc

        if loadRegValue(self.trainingstart, 0) == 0:
            self._annotate(doc)
        return doc

    def _collectandtrain(self, doc):
        """Buffer keyword-matched lines of ``doc`` as examples; train when enough exist.

        The caller holds ``NERTrainingLock``. Harvesting of ``doc`` stops after a
        training attempt.
        """
        if self.nlp is None:
            self.loadmodel()
            self.loadmatcher()
        stopwatch = self.__name__ + "_Train"
        if self.perfCounter is not None:
            self.perfCounter.start_stopwatch(stopwatch)
        try:
            for sentence in doc.sents:
                for line in sentence.text.splitlines():
                    text = line.strip()
                    # Skip blank lines, and lines over 500 chars which can crash training.
                    if not text or len(text) > 500:
                        continue
                    entities = self._matchentities(text)
                    if entities:
                        # Store the stripped text: the offsets were computed on it.
                        self.examplebuffer.append((text, {"entities": entities}))
                    # Persist in batches (default 50). `>=` rather than
                    # `% == 0`, so an empty buffer never triggers a flush.
                    if len(self.examplebuffer) >= self.bufferflush:
                        trainingexamples = self._flushexamples()
                        if len(trainingexamples) > self.trainthreshold:
                            self._trainandevaluate(trainingexamples)
                            return
        finally:
            if self.perfCounter is not None:
                self.perfCounter.stop_stopwatch(stopwatch, len(doc.text))

    def _matchentities(self, text):
        """Return ``[(start_char, end_char, label_), ...]`` for keyword matches in ``text``.

        The Matcher reports token indices, which are converted to character
        offsets because the ``{"entities": ...}`` annotation format uses
        character offsets. Overlapping matches are reduced to the longest
        non-overlapping set, since overlapping entities cannot be annotated.
        """
        tokens = self.nlp.make_doc(text)
        ranges = [
            (start, end)
            for matchlabelid, start, end in self.matcher(tokens)
            if self.nlp.vocab.strings[matchlabelid] == self.label_
        ]
        entities = []
        for start, end in self._nonoverlapping(ranges):
            span = tokens[start:end]
            entities.append((span.start_char, span.end_char, self.label_))
        return entities

    @staticmethod
    def _nonoverlapping(ranges):
        """Keep the longest non-overlapping ``(start, end)`` ranges, sorted by start."""
        kept = []
        occupied = set()
        # Longest first; ties broken by the earlier start.
        for start, end in sorted(set(ranges), key=lambda r: (r[0] - r[1], r[0])):
            positions = range(start, end)
            if occupied.isdisjoint(positions):
                occupied.update(positions)
                kept.append((start, end))
        return sorted(kept)

    def _flushexamples(self):
        """Append buffered examples to the shelf and return every stored example.

        The shelf file is named ``self.__name__`` and keyed by the label. The
        buffer is cleared only after the shelf has been written and closed.
        """
        label = self.label_.encode("ascii", "ignore")
        if not isinstance(label, str):
            # Python 3: encode() yields bytes, but shelve keys must be str.
            label = label.decode("ascii")
        self.trainingshelf = shelve.open(self.__name__)
        try:
            if label in self.trainingshelf:
                trainingexamples = self.trainingshelf[label]
            else:
                trainingexamples = []
            trainingexamples.extend(self.examplebuffer)
            self.trainingshelf[label] = trainingexamples
        finally:
            self.trainingshelf.close()
        self.examplebuffer = []
        return trainingexamples

    def _trainandevaluate(self, trainingexamples):
        """Train a staged model on ``trainingexamples`` and promote it if accurate enough.

        On success, the model replaces ``self.nlp``, is saved to
        ``self.modelname`` and the training setting is reset to 0. The stage
        directory is always removed.
        """
        # Presumably writes the <componentname>_training.json / _dev.json files
        # passed to train() below, and returns the test-set path — confirm.
        testingsetpath = createtraintestdevsets(
            self.primarynlp,
            trainingexamples,
            self.componentname,
        )
        stagedir = os.path.join(self.modelname, "stage")
        if not os.path.exists(stagedir):
            os.makedirs(stagedir)
        try:
            train(
                lang="en",
                output_dir=stagedir,
                train_data=self.componentname + "_training.json",
                dev_data=self.componentname + "_dev.json",
                vectors=str(self.primarynlp.path),
                no_parser=True,
                no_tagger=True,
                n_iter=self.n_iter,
            )
            nlptrained = LoadModel(os.path.join(stagedir, "model-final"))
            precision, recall, fmeasure = evaluate(nlptrained, testingsetpath)
            self.params["logger"].info(
                "Length of training examples {0}, precision : {1:.3f}, recall : {2:.3f}, fmeasure : {3:.3f}".format(
                    len(trainingexamples),
                    precision,
                    recall,
                    fmeasure,
                )
            )
            # NOTE(review): minaccuracy defaults to 50, which suggests evaluate()
            # reports fmeasure on a 0-100 scale — confirm.
            if fmeasure >= self.minaccuracy:
                self.nlp = nlptrained
                # later we can add versioning code to maintain 3 or configurable number of versions
                SaveModel(nlptrained, self.modelname)
                # The saved vocab is dropped; loadmodel() loads models with
                # primarynlp's vocab instead.
                vocabdir = os.path.join(self.modelname, "vocab")
                if os.path.isdir(vocabdir):
                    shutil.rmtree(vocabdir)
                # Stop training only when the accuracy criterion is satisfied.
                setRegValue(self.trainingstart, 0)
        finally:
            shutil.rmtree(stagedir, ignore_errors=True)

    def _annotate(self, doc):
        """Store this component's entities and ``(matcher, nlp)`` in ``doc._.customterm``.

        Sets ``customterm[label_]`` to the custom model's ``label_`` entities
        (longer than one non-blank char) and ``customterm[label_ + "matcher"]``.
        The entities are spans of the custom model's own doc, so their offsets
        only line up with ``doc`` while both pipelines tokenize identically.
        """
        stopwatch = self.__name__ + "_Test"
        if self.perfCounter is not None:
            self.perfCounter.start_stopwatch(stopwatch)
        try:
            if self.nlp is None:
                self.loadmodel()
                if self.nlp is None:  # no trained model yet
                    return
                # loadmatcher() was a no-op while no model existed.
                self.loadmatcher()
            termdoc = self.nlp(doc.text)
            entities = [
                ent
                for ent in termdoc.ents
                if ent.label_ == self.label_ and len(ent.text.strip()) > 1
            ]
            # Copy, never mutate, the value: it may be the extension's shared default.
            customterm = dict(doc._.customterm)
            customterm[self.label_] = entities
            customterm[self.label_ + "matcher"] = (self.matcher, self.nlp)
            doc._.customterm = customterm
        finally:
            if self.perfCounter is not None:
                self.perfCounter.stop_stopwatch(stopwatch, len(doc.text))
