"""
This module encapsulates the spacy pipeline customization code
"""
import json
import os
import pickle
import re
import timeit
import traceback
from collections import defaultdict
from distutils.version import LooseVersion

import pandas as pd
import pkg_resources
import spacy
from flashtext import KeywordProcessor
from spacy.attrs import ENT_IOB
from spacy.matcher import PhraseMatcher
from spacy.tokenizer import Tokenizer
from spacy.tokens import Doc, Span, Token

from CvEEConfigHelper import (
    DOCUMENT_RESPONSE_CODES,
    DOCUMENT_CLASSIFY_MODELS_REG_KEY,
    DOCUMENT_CLASSIFY_MODELS_MAP,
    loadRegValue,
)
from CvEEDocumentClassifier import Categorizer
from CvEEUtil import util_get_dict_attr
from CvNERAddressComponent import AddressComponent
from CvNERFilterComponent import FilterEntsComponent
from CvNERSolrTaggerComponent import SolrTaggerComponent
from CvPyUtils import ensureUtf, printelapsedtime
from CvSpacyBusinessKeywordFilter import *

# Site-packages directory, resolved once when this module is imported.
PYTHON_SITE_PACKAGES_DIR = getPythonSitePackagesDir()

# Tagger is importable from spacy.pipeline for spaCy versions after 1.9.9
# and from spacy.tagger for older releases.
if LooseVersion(pkg_resources.get_distribution("spacy").version) <= LooseVersion("1.9.9"):
    from spacy.tagger import Tagger
else:
    from spacy.pipeline import Tagger

class ClassifyDocsComponent(object):
    """
    Pipeline component that stores a document category on each doc.

    Not in use anymore.

    When ``params["use_document_classifier"] == True``, the classifier models
    listed as JSON in the DOCUMENT_CLASSIFY_MODELS_REG_KEY registry value are
    loaded into a ``Categorizer``. Each doc then gets its predicted category
    in ``doc._._document_category`` (extension default: "OTHERS"). If the
    registry value cannot be parsed or the classifier cannot be built, the
    error is logged and classification is disabled instead of failing later
    on every doc.
    """

    def __init__(self, nlp, params=None):
        if params is None:
            params = {}
        self.nlp = nlp
        self.LOGGER_Generic = params["logger"]
        self.perfCounter = util_get_dict_attr(params, "perfCounter")
        # Compared with == (not `is`) so that values equal to True, e.g. 1,
        # also enable the classifier.
        self.use_document_classifier = params.get("use_document_classifier") == True
        if self.use_document_classifier:
            self._load_classifier()

    def _load_classifier(self):
        """Build the Categorizer from the registry value; disable classification on failure."""
        DOCUMENT_CLASSIFY_MODELS_DIR = loadRegValue(
            DOCUMENT_CLASSIFY_MODELS_REG_KEY, "", type=str
        )
        try:
            models = json.loads(DOCUMENT_CLASSIFY_MODELS_DIR)
            self.classify = Categorizer(models)
        except Exception as e:
            self.LOGGER_Generic.error(
                "Failed to parse JSON from sEEDocumentClassifyModelsInstallDir registry key. Exception : {}".format(
                    e
                )
            )
            # Without self.classify, __call__ could not classify anything.
            self.use_document_classifier = False
            return
        # Register the extension once here rather than on every doc: depending
        # on the spaCy version, re-registering an existing extension raises.
        if not Doc.has_extension("_document_category"):
            Doc.set_extension("_document_category", default="OTHERS")

    def __call__(self, doc):
        if self.perfCounter is not None:
            self.perfCounter.start_stopwatch("Document_Category")
        try:
            if self.use_document_classifier:
                doc._._document_category = self.classify.get_document_type(doc.text)
        finally:
            # Stop the stopwatch even if classification raises.
            if self.perfCounter is not None:
                self.perfCounter.stop_stopwatch("Document_Category", len(doc.text))
        return doc


class PrintComponent(object):
    """
    Debug pipeline component that logs a doc's token count and text.

    ``stagenum`` labels this component's log lines; ``params["logger"]`` is
    the logger written to (required). The doc is returned unchanged.
    """

    def __init__(self, stagenum=1, params=None):
        # None instead of a shared mutable {} default.
        if params is None:
            params = {}
        self.stage = stagenum
        self.LOGGER_Generic = params["logger"]

    def __call__(self, doc):
        logger = self.LOGGER_Generic
        logger.info("STAGE {}, doc has {} tokens".format(self.stage, len(doc)))
        logger.info(doc.text)
        logger.info("Finished STAGE {}".format(self.stage))
        return doc


class EntityBeamParserComponent(object):
    """
    Pipeline component that runs named-entity recognition on a doc.

    The NER component used is the model's own "ner" pipe. The exception is
    when ``sp_ner_usetrainedmodels`` is 1 and the doc's ``_document_category``
    extension names a category whose model was loaded from the
    sEECategoryBasedModels registry value. Docs categorised as "LOGS" are
    returned untouched.

    With ``bFetchProbability`` == 1 the NER component is run as a beam parse.
    Per-entity scores are summed over the beam parses, and candidates scoring
    above ``sp_ner_probcutoff`` are appended to ``doc.ents`` with the score in
    ``span._.nerscore``. Otherwise the NER component is applied directly.

    Finally, one-word PERSON entities whose lemma differs from their text are
    dropped from ``doc.ents``.
    """

    def __init__(self, nlp, beam_width=16, beam_density=0.0001, params=None):
        if params is None:
            params = {}
        self.nlp = nlp
        self.LOGGER_Generic = params["logger"]
        self.beam_width = beam_width
        self.usebeam = util_get_dict_attr(params, "bFetchProbability", 0)
        self.scorecutoff = util_get_dict_attr(params, "sp_ner_probcutoff", 0.0)
        self.perfCounter = util_get_dict_attr(params, "perfCounter")
        self.beam_density = beam_density
        self.usetrainedmodels = util_get_dict_attr(params, "sp_ner_usetrainedmodels", 0)
        # category name -> NER component. Kept in its own dict so a category
        # name can never overwrite, or be mistaken for, an instance attribute.
        self.category_models = {}

        if self.usetrainedmodels == 1:
            self._load_category_models()
        self.default_model = nlp.get_pipe("ner")
        self.entity = self.default_model
        # Register once; re-registering an existing extension can raise.
        if not Span.has_extension("nerscore"):
            Span.set_extension("nerscore", default=0.0)

    def _load_category_models(self):
        """
        Load the per-category NER components listed in sEECategoryBasedModels.

        sEECategoryBasedModels example value :
        {"RESUMES" : {"model_path" : "en_resumes_trained_lg"},"LOGS" : {"model_path" : "en_logs_trained_lg"},"OTHERS" : {"model_path" : "en_core_web_lg"}}

        Errors are logged. Categories loaded before the error are kept.
        """
        categoryModels = loadRegValue(CATEGORY_BASED_MODELS_REG_KEY, "", type=str)
        if categoryModels == "":
            return
        try:
            site_packages_dir = getPythonSitePackagesDir()
            categoryModelsDict = json.loads(categoryModels)
            for category, model in categoryModelsDict.items():
                # remove_pipe returns (name, component); only the NER
                # component is kept, sharing this pipeline's vocab.
                _, ner = spacy.load(
                    os.path.join(site_packages_dir, model["model_path"]),
                    vocab=self.nlp.vocab,
                ).remove_pipe("ner")
                self.category_models[category] = ner
        except Exception as e:
            self.LOGGER_Generic.error(
                "Failed to parse JSON from sEECategoryBasedModels registry key. Exception {}".format(
                    e
                )
            )

    def __call__(self, doc):
        if self.perfCounter is not None:
            self.perfCounter.start_stopwatch("Beam_Parser")
        try:
            doc = self._recognize(doc)
        finally:
            # Reached on every path, including the early "LOGS" return, so
            # the stopwatch is always stopped.
            if self.perfCounter is not None:
                self.perfCounter.stop_stopwatch("Beam_Parser", len(doc.text))
        return doc

    def _recognize(self, doc):
        """Pick the NER component for ``doc``, apply it and filter the result."""
        self.entity = self.default_model
        if doc._.has("_document_category"):
            category = doc._._document_category
            if category == "LOGS":
                return doc
            self.entity = self.category_models.get(category, self.default_model)

        if self.usebeam == 1:
            self._add_beam_entities(doc)
        else:
            doc = self.entity(doc)

        doc.ents = [ent for ent in doc.ents if not self._is_dropped_person(ent)]
        return doc

    def _add_beam_entities(self, doc):
        """Beam-parse ``doc`` with ``self.entity``; append candidates scoring above the cutoff."""
        # Use the selected NER component itself: beam_parse and moves live on
        # the component, not on the Language object.
        ner = self.entity
        # NOTE(review): [0][0] assumes beam_parse returns (beams, ...) for the
        # installed spaCy version, with beams[0] belonging to ``doc``. Confirm.
        beam = ner.beam_parse(
            [doc], beam_width=self.beam_width, beam_density=self.beam_density
        )[0][0]
        entity_scores = defaultdict(float)
        for score, ents in ner.moves.get_beam_parses(beam):
            for start, end, label in ents:
                entity_scores[(start, end, label)] += score

        json_dump = {"accepted_entities": [], "rejected_entities": []}
        for (start, end, label_), finalscore in entity_scores.items():
            label = self.nlp.vocab.strings[label_]
            if end == -1:  # means till end of the document
                end = len(doc)
            predictedentity = Span(doc, start, end, label)
            record = {
                "label": predictedentity.label_,
                "text": predictedentity.lower_,
                "probability": finalscore,
            }
            if finalscore > self.scorecutoff:
                predictedentity._.nerscore = finalscore
                # NOTE(review): overlapping candidates may be rejected by the
                # doc.ents setter depending on the spaCy version.
                doc.ents = list(doc.ents) + [predictedentity]
                json_dump["accepted_entities"].append(record)
            else:
                json_dump["rejected_entities"].append(record)
        self.LOGGER_Generic.debug("Beam Parser entities [{}]".format(json.dumps(json_dump)))

    @staticmethod
    def _is_dropped_person(ent):
        """True for a one-word PERSON entity whose lemma differs from its text."""
        return (
            ent.label_ == "PERSON"
            and len(ent.text.split()) == 1
            and ent.lemma_ != ent.text
        )


class PersonEntityComponent(object):
    """
    Pipeline component that adds PERSON entities found by a flashtext keyword
    processor unpickled from ``flashtext_person.pkl``.

    A keyword match that starts exactly one character after a previous match
    ended (e.g. first name, space, last name) is chained into one span. Each
    candidate followed by a token whose span has ``_.is_surname`` set also
    yields an extended candidate flagged ``_.is_person_with_surname``. Only
    multi-word candidates are appended to ``doc.ents`` and merged into single
    tokens. Failures on individual candidates are logged at debug level and
    skipped, so one bad candidate doesn't abort the doc.
    """

    def __init__(self, nlp, params=None):
        if params is None:
            params = {}
        self.nlp = nlp
        self.LOGGER_Generic = params["logger"]
        self.label = nlp.vocab.strings["PERSON"]
        # NOTE(review): unpickling can execute arbitrary code; this file must
        # come from a trusted install location.
        with open("flashtext_person.pkl", "rb") as f:
            self.keyword_processor = pickle.load(f)

    def __call__(self, doc):
        entities = self._keyword_entities(doc)
        entities += self._surname_extensions(doc, entities)
        spans = [entity for entity in entities if len(entity.text.split(" ")) > 1]
        for span in spans:
            try:
                doc.ents = list(doc.ents) + [span]
            except Exception as e:
                self._debug("Could not add PERSON entity", span, e)
        # Merging is attempted for every multi-word candidate, even those
        # whose addition to doc.ents failed above.
        for span in spans:
            try:
                span.merge()
            except Exception as e:
                self._debug("Could not merge PERSON entity", span, e)
        return doc

    def _keyword_entities(self, doc):
        """Return PERSON spans for the keyword matches in ``doc``, chaining adjacent matches."""
        matches = self.keyword_processor.extract_keywords(doc.text, span_info=True)
        # end char offset -> start char offset of the (chained) span ending there
        endings = {}
        entities = []
        for _, start, end in matches:
            if start - 1 in endings:
                start = endings[start - 1]
            endings[end] = start
            try:
                entity = doc.char_span(start, end, label=self.label)
            except Exception as e:
                self._debug("Could not build PERSON span", (start, end), e)
                continue
            # char_span returns None when the offsets don't align with tokens.
            if entity is not None:
                entities.append(entity)
        return entities

    def _surname_extensions(self, doc, entities):
        """Return ``entities`` extended by one following token flagged as a surname."""
        extended = []
        for entity in entities:
            try:
                next_span = doc[entity.end : entity.end + 1]
                if next_span._.is_surname == True:
                    new_entity = Span(doc, entity.start, entity.end + 1, label=self.label)
                    new_entity._.is_person_with_surname = True
                    extended.append(new_entity)
            except Exception as e:
                # e.g. the is_surname / is_person_with_surname extensions are
                # only registered by SurnameAttributeComponent / GazzeteerComponent.
                self._debug("Could not extend PERSON entity with surname", entity, e)
        return extended

    def _debug(self, what, item, error):
        """Log a skipped candidate at debug level."""
        self.LOGGER_Generic.debug("{} [{}]: {}".format(what, item, error))


class SurnameAttributeComponent(object):
    """
    Pipeline component that flags surname matches with ``span._.is_surname``.

    Surnames come from a flashtext keyword processor unpickled from
    ``surnames_extracted.pkl``. Matches whose character offsets don't align
    with token boundaries are skipped. Registers the ``is_surname`` Span
    extension (default False) if it isn't registered yet.
    """

    def __init__(self, nlp, params=None):
        if params is None:
            params = {}
        self.nlp = nlp
        self.LOGGER_Generic = params["logger"]
        self.label = nlp.vocab.strings["SURNAME"]
        # NOTE(review): unpickling can execute arbitrary code; this file must
        # come from a trusted install location.
        with open("surnames_extracted.pkl", "rb") as f:
            self.keyword_processor = pickle.load(f)
        # Register once; re-registering (e.g. a second pipeline load) can raise.
        if not Span.has_extension("is_surname"):
            Span.set_extension("is_surname", default=False)

    def __call__(self, doc):
        person_label = self.nlp.vocab.strings["PERSON"]
        for match in self.keyword_processor.extract_keywords(doc.text, span_info=True):
            try:
                _, start, end = match
                span = doc.char_span(start, end, person_label)
                if span is None:
                    # Offsets don't fall on token boundaries.
                    continue
                doc[span.start : span.end]._.is_surname = True
            except Exception as e:
                self.LOGGER_Generic.debug("Could not flag surname match {}: {}".format(match, e))
        return doc


class SetEntityLabels(object):
    """
    Copy each entity's label onto the tokens it covers.

    For every span in ``doc.ents``, ``span.label_`` is appended to the
    ``token._.entity_label_`` list of each covered token. The list is created
    on first use; tokens outside any entity keep ``None``. Later components
    use these lists to resolve entity conflicts. (Original rationale: spaCy
    did not populate ``label_`` per token at the time.)
    """

    def __init__(self, nlp):
        self.nlp = nlp
        # Register once. Re-registering an existing extension can raise, and
        # CustomPipeline would then log and skip this component on a second
        # pipeline load.
        if not Token.has_extension("entity_label_"):
            Token.set_extension("entity_label_", default=None)

    def __call__(self, doc):
        for ent in doc.ents:
            for i in range(ent.start, ent.end):
                token_ext = doc[i]._
                if token_ext.entity_label_ is None:
                    token_ext.entity_label_ = []
                token_ext.entity_label_.append(ent.label_)
        return doc


def fix_space_tags(doc):
    """
    Force the entity IOB tag of every whitespace token in ``doc`` to 'O'.

    The doc is updated in place through ``doc.from_array`` and returned.

    TODO: quick workaround; revisit once spaCy fixes this internally:
    https://github.com/explosion/spaCy/issues/2870
    """
    # ENT_IOB codes: 0 = no tag set, 1 = 'I', 2 = 'O'
    outside_tag = 2
    iob_values = doc.to_array([ENT_IOB])
    space_positions = [pos for pos, tok in enumerate(doc) if tok.is_space]
    for pos in space_positions:
        iob_values[pos] = outside_tag
    column = iob_values.reshape((len(doc), 1))
    doc.from_array([ENT_IOB], column)
    return doc


class GazzeteerComponent(object):
    """
    Phrase-matching gazetteer for a single entity type.

    We have moved all the gazzeteers to solr. This component will not get used anymore.

    Patterns are read one per non-blank line from ``patterns_<entityname>.list``
    (UTF-8) and added to a PhraseMatcher. On each doc:

    * every match becomes a span labelled ``entityname``; a match starting at
      the token where a previous match ended is merged with it;
    * ``span._.<entityname lowercased>score`` is set to 1 and each covered
      token gets ``<entityname>_GZ`` appended to ``token._.entity_label_``
      (that extension is registered by SetEntityLabels);
    * for PERSON, a match followed by a token flagged ``_.is_surname`` also
      yields a span extended by that token;
    * only multi-word spans are appended to ``doc.ents``.

    Errors during matching post-processing are logged, and the doc is returned
    as far as it got.
    """

    def __init__(self, nlp, entityname, params=None):
        if params is None:
            params = {}
        self.vocab = nlp.vocab
        self.entityname = entityname
        self.label = nlp.vocab.strings[entityname]
        self.__name__ = entityname + "GazzeteerComponent"
        self.matcher = PhraseMatcher(nlp.vocab)
        self.nlp = nlp
        self.LOGGER_Generic = params["logger"]
        self.perfCounter = util_get_dict_attr(params, "perfCounter")
        self.LOGGER_Generic.info("Adding patterns for gazzetteer {}".format(entityname))
        st = timeit.default_timer()
        exceptioncount = 0
        for pattern in self.getpatterns(entityname):
            try:
                self.matcher.add(entityname, None, pattern)
            except ValueError:
                exceptioncount += 1
        self.scorevar = self.entityname.lower() + "score"
        # Several gazetteers share these extensions ("gz" adds one per entity
        # type), so register each only once. Re-registering can raise.
        for ext_name, default in (
            ("pent", ""),
            ("is_person_with_surname", False),
            (self.scorevar, 0.0),
        ):
            if not Span.has_extension(ext_name):
                Span.set_extension(ext_name, default=default)
        printelapsedtime(
            st,
            "Time to add phrases for gazzetteer {}, exceptions {}".format(
                entityname, exceptioncount
            ),
        )

    def getpatterns(self, entityname):
        """Yield ``nlp.make_doc`` of each non-blank, stripped line of ``patterns_<entityname>.list``."""
        fname = "patterns_{}.list".format(entityname)
        # Decode at the I/O boundary: Python 3 str has no .decode().
        with open(fname, encoding="utf-8") as entityf:
            lines = entityf.read().split("\n")
        for line in lines:
            line = line.strip()
            if line:
                yield self.nlp.make_doc(line)

    def __call__(self, doc):
        if self.perfCounter is not None:
            self.perfCounter.start_stopwatch("Gazzeteer")
        matches = self.matcher(doc)
        try:
            entities = self._match_entities(doc, matches)
            if self.entityname == "PERSON":
                entities += self._surname_extensions(doc, entities)
            for entity in entities:
                try:
                    if len(entity.text.split(" ")) > 1:
                        doc.ents = list(doc.ents) + [entity]
                except Exception:
                    # Best effort: e.g. a span conflicting with existing ents.
                    pass
            # Matched spans are deliberately NOT merged into single tokens:
            # merging affected subsequent gazetteer searches.
        except Exception as e:
            tb = traceback.format_exc()
            self.LOGGER_Generic.info("Exception in GazzeteerComponent {}\n{}".format(e, tb))
        if self.perfCounter is not None:
            self.perfCounter.stop_stopwatch("Gazzeteer", len(doc.text))
        return doc

    def _match_entities(self, doc, matches):
        """Turn matcher output into labelled spans and tag the covered tokens."""
        entities = []
        # token end index -> start index of the (possibly merged) span ending there
        endings = {}
        tag = self.entityname + "_GZ"
        for _, start, end in matches:
            if start in endings:
                # Adjacent to a previous match: extend that span instead.
                start = endings[start]
            entity = Span(doc, start, end, label=self.label)
            setattr(entity._, self.scorevar, 1)
            for i in range(start, end):
                token_ext = doc[i]._
                if token_ext.entity_label_ is None:
                    token_ext.entity_label_ = []
                token_ext.entity_label_.append(tag)
            endings[end] = start
            entities.append(entity)
        return entities

    def _surname_extensions(self, doc, entities):
        """Return ``entities`` extended by one following token flagged as a surname."""
        extended = []
        for entity in entities:
            try:
                next_span = doc[entity.end : entity.end + 1]
                if next_span._.is_surname == True:
                    new_entity = Span(doc, entity.start, entity.end + 1, label=self.label)
                    new_entity._.is_person_with_surname = True
                    extended.append(new_entity)
            except Exception:
                # Best effort: is_surname is only registered by SurnameAttributeComponent.
                pass
        return extended


class CustomPipeline(object):
    """
    Callable that extends a loaded spaCy ``Language`` with configured components.

    ``pipeline`` is a comma-separated list of component options (leading and
    trailing commas are ignored), e.g. "tagger,entity,parser". Calling the
    instance first inserts ``fix_space_tags`` as "fix-ner" before "ner", then
    applies each option in order. Unknown options are ignored. A component
    that fails to be constructed or added is logged with its traceback and
    skipped. ``gzentities`` lists the entity types added by the "gz" option.
    """

    # single-gazetteer option -> entity type (the option is also the pipe name)
    _NAMED_GAZETTEERS = {
        "persongz": "PERSON",
        "productgz": "PRODUCT",
        "languagegz": "LANGUAGE",
        "orggz": "ORG",
        "skillgz": "SKILL",
        "gpegz": "GPE",
    }

    def __init__(self, pipeline, gzentities=None, params=None):
        if gzentities is None:
            # Fresh list per instance instead of one shared mutable default.
            gzentities = ["PERSON", "SKILL", "LANGUAGE", "ORG", "GPE", "PRODUCT"]
        if params is None:
            params = {}
        self.pipeline = pipeline.strip(",").split(",")
        self.gzentities = gzentities
        self.params = params
        self.perfCounter = None
        self.LOGGER_Generic = params["logger"]

    def __call__(self, nlp):
        printstage = 1
        nlp.add_pipe(fix_space_tags, name="fix-ner", before="ner")
        for pipe in self.pipeline:
            try:
                printstage = self._add_component(nlp, pipe, printstage)
            except Exception as e:
                tb = traceback.format_exc()
                self.LOGGER_Generic.info(
                    "Exception {} adding Component {}, continuing with next component\n{}".format(
                        e, pipe, tb
                    )
                )
        return nlp

    def _add_component(self, nlp, pipe, printstage):
        """Add the component(s) for option ``pipe`` to ``nlp``; return the next print stage."""
        params = self.params
        if pipe == "classify":
            nlp.add_pipe(ClassifyDocsComponent(nlp, params=params), first=True)
        elif pipe == "tagger":
            nlp.add_pipe(nlp.tagger)
        elif pipe == "parser":
            nlp.add_pipe(nlp.parser)
        elif pipe == "entity":
            nlp.add_pipe(nlp.entity)
        elif pipe == "set_entity_label":
            nlp.add_pipe(SetEntityLabels(nlp))
        elif pipe == "print":
            nlp.add_pipe(PrintComponent(printstage, params=params))
            printstage += 1
        elif pipe == "gz":  # simple option to add all gazzetteers
            for entityname in self.gzentities:
                nlp.add_pipe(GazzeteerComponent(nlp, entityname, params=params), last=True)
        elif pipe in self._NAMED_GAZETTEERS:
            nlp.add_pipe(
                GazzeteerComponent(nlp, self._NAMED_GAZETTEERS[pipe], params=params),
                last=True,
                name=pipe,
            )
        elif pipe.startswith("cc_"):
            nlp.add_pipe(CustomComponent(nlp, name=pipe, params=params), last=True)
        elif pipe == "sent":
            nlp.add_pipe(nlp.create_pipe("sentencizer"), first=True)
        elif pipe == "filtergz":
            nlp.add_pipe(FilterEntsComponent(params=params), last=True, name="filtergz")
        elif pipe == "beamparser":
            nlp.add_pipe(EntityBeamParserComponent(nlp, params=params), name="beamparser")
        elif pipe == "personentity":
            nlp.add_pipe(PersonEntityComponent(nlp, params=params), before="filtergz")
        elif pipe == "surname_attr":
            nlp.add_pipe(SurnameAttributeComponent(nlp, params=params))
        elif pipe == "solr_tagger":
            nlp.add_pipe(SolrTaggerComponent(nlp, params=params), name="solr_tagger")
        elif pipe == "address":
            nlp.add_pipe(AddressComponent(nlp, params=params), name="address")
        return printstage


def LoadLanguageModel(modelname="en_core_web_lg", pipeline="tagger,entity,parser", params=None):
    """
    Load spaCy model ``modelname`` and extend it through ``CustomPipeline``.

    :param modelname: name or path passed to ``spacy.load``
    :param pipeline: comma-separated component options for ``CustomPipeline``
    :param params: settings dict; must contain ``"logger"``. It is indexed
        directly, so the ``None`` default cannot actually be used.
    :return: the customised ``Language`` object
    """
    logger = params["logger"]
    logger.info("loading language model...")
    started = timeit.default_timer()
    builder = CustomPipeline(pipeline, params=params)
    nlp = builder(spacy.load(modelname))
    printelapsedtime(started, "Time to load model")
    logger.info(nlp.pipe_names)
    return nlp


def print_person(func, doc):
    """
    Debug helper: append the PERSON entities of ``doc`` to ``person_component_wise.txt``.

    A separator line containing ``func`` (e.g. the pipeline stage name) is
    written first, then each PERSON entity's text on its own line. The file is
    written as UTF-8, so non-ASCII names don't depend on the platform's default
    encoding. Returns ``doc`` unchanged.
    """
    header = "\n\n-------------------------------------------------------{}-------------------------------------------------------------\n\n".format(
        func
    )
    names = [ent.text + "\n" for ent in doc.ents if ent.label_ == "PERSON"]
    with open(r"person_component_wise.txt", "a", encoding="utf-8") as f:
        f.write(header + "".join(names))
    return doc
