import CvEETimeout
import pickle
import re
import gib_detect_train
import json
import pandas as pd
from CvEEConfigHelper import loadRegValue, getPythonSitePackagesDir, NER_GLOBAL
from spacy.tokens import Span
from expletives import badwords
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet
from CvEEUtil import util_get_dict_attr
from CvNERCustomModels import PersonTechTermFilter
from wordsegment import load, segment


class FilterEntsComponent(object):
    def __init__(self, loadgibdetector=True, params=None):
        """Load the models, word lists and exclusion lists used to filter entities.

        Args:
            loadgibdetector: Accepted for backward compatibility; this
                constructor does not read it.
            params: Configuration dict. Must contain ``"logger"``; may contain
                ``"perfCounter"``. The same dict is handed to
                ``PersonTechTermFilter``.

        Raises:
            KeyError: If ``"logger"`` is missing from ``params``.
        """
        # Use None rather than a shared mutable dict as the default argument.
        if params is None:
            params = {}
        self.LOGGER_Generic = params["logger"]
        # NOTE(review): pickle can execute arbitrary code while loading; the
        # model file must only ever come from a trusted location.
        with open("gib_model_py37.pki", "rb") as model_file:
            self.model_data = pickle.load(model_file, encoding="utf-8")
        self.excludepattern = re.compile(r"^\s+$")
        self.exclude_chars_regex = r"[0-9!\"#\$%&\(\)\*\+\/:;<=>\?@\[\\\]\^_`\{\|\}~,]+"
        # Leading run of non-letters. The previous pattern ended with a stray
        # "]" and therefore only matched when the run was followed by a
        # literal "]".
        self.exclude_chars_beggining = "^[^a-zA-Z]+"
        self.exclude_chars_end = "[^a-zA-Z]+$"
        self.exclude_chars = re.compile(self.exclude_chars_regex)
        self.perfCounter = params.get("perfCounter", None)
        from nltk.corpus import words

        self.most_common_words = set(words.words())
        self.lemmatizer = WordNetLemmatizer()
        self.person_filter = PersonTechTermFilter(params)
        self.word_exclusion_list = dict()
        # load wordsegment module
        load()
        # sEESPNERExcludeList = '{"PERSON" : ["PERSON_exclusion.list"] }'
        # Maps an entity label to a list of files; every line of those files
        # becomes one lower-cased entry in that label's exclusion set, which is
        # later used to drop false positives.
        exclude_list = loadRegValue(
            "sEESPNERExcludeList", '{"PERSON" : ["PERSON_exclusion.list"] }', type=str
        )
        if exclude_list != "":
            try:
                exclude_list = json.loads(exclude_list)
                for label, file_list in exclude_list.items():
                    label_words = self.word_exclusion_list.setdefault(label, set())
                    for file_name in file_list:
                        with open(file_name) as f:
                            for line in f:
                                # Strip only the line terminator: slicing off
                                # the last character would truncate a final
                                # line that has no trailing newline and would
                                # keep "\r" from CRLF files.
                                word = line.rstrip("\r\n").lower()
                                # Blank lines must not put "" into the set.
                                if word:
                                    label_words.add(word)
            except CvEETimeout.Timeout.ProcessingTimedOut:
                # Processing timeouts must reach the caller untouched.
                raise
            except Exception as e:
                # Best effort: a broken registry value or unreadable file is
                # logged and whatever was loaded so far stays in effect.
                self.LOGGER_Generic.error(
                    "Failed to parse JSON from sEESPNERExcludeList registry key. Exception : {}".format(
                        e
                    )
                )

    def __call__(self, doc):
        """Filter the entities of ``doc`` and write the survivors back to ``doc.ents``.

        For every entity (after ``merge_entities`` has run) the method:

        * trims leading/trailing non-letters for PERSON and LOCATION text,
        * drops layout markers, empty text and labels in the exclude list,
        * drops repeated text per label unless ``NER_GLOBAL["is_structure"]``
          is set,
        * drops entities hitting the per-label word exclusion list,
        * applies the PERSON / LOCATION specific filters,
        * drops entities containing a token found in ``badwords``.

        When several surviving entities start at the same token, only the one
        whose label comes first in the label priority list is kept.

        Args:
            doc: The document whose ``ents`` are filtered in place.

        Returns:
            The same ``doc`` object.
        """
        if self.perfCounter is not None:
            self.perfCounter.start_stopwatch("Filter_Entities")
        allents = []
        # Label priority: a lower index wins when entities share a start token.
        # Labels not listed here are appended as they are encountered.
        entitytaglist = ["ORG", "PRODUCT", "SKILL", "PERSON"]
        excludelist = (
            "WORK_OF_ART",
            "FACILITY",
            "EVENT",
            "LAW",
            "PERCENT",
            "QUANTITY",
            "ORDINAL",
            "CARDINAL",
        )
        layout_markers = ("CVTAB", "CVNEWLINE", "CVHEADER", "CVFOOTER")
        customtermpatterns = None
        # build only once per document as it is a document level attribute
        if doc.has_extension("customterm"):
            termlist = [
                term.text
                for key in list(doc._.customterm.keys())
                if not key.endswith("matcher")
                for term in doc._.customterm[key]
            ]
            if termlist:
                customtermpatterns = re.compile(
                    r"\b(?:" + "|".join(re.escape(s) for s in termlist) + r")\b", re.IGNORECASE
                )
        ents_dict = dict()
        self.merge_entities(doc)
        for ent in doc.ents:
            entity_label = ent.label_
            entity_text = ent.text.strip()
            if entity_label in ("PERSON", "LOCATION"):
                # remove special characters from beginning and end of text
                entity_text = re.sub(self.exclude_chars_beggining, "", entity_text)
                entity_text = re.sub(self.exclude_chars_end, "", entity_text)
            if entity_text in layout_markers:
                continue
            # Drop empty entities. The text is already stripped, so the
            # whitespace-only pattern alone can never match; the explicit
            # emptiness test is what actually catches them.
            if not entity_text or self.excludepattern.search(entity_text) is not None:
                continue
            if entity_label in excludelist:
                continue
            # removing duplicate entities from each label
            # in case of structure documents, we need duplicate count as well to get the correct found percentage
            # (`== False` is kept as-is because the stored value's type is not visible here)
            if NER_GLOBAL["is_structure"] == False:
                seen_texts = ents_dict.setdefault(entity_label, set())
                lowered_text = entity_text.lower()
                if lowered_text in seen_texts:
                    continue
                seen_texts.add(lowered_text)

            # removing entities which are in word_exclusion_list
            if (
                entity_label in self.word_exclusion_list
                and ent.has_extension("is_person_with_surname")
                and ent._.is_person_with_surname == False
            ):
                lemmas = {
                    self.lemmatizer.lemmatize(word.lower()) for word in entity_text.split(" ")
                }
                if not lemmas.isdisjoint(self.word_exclusion_list[entity_label]):
                    continue

            if entity_label not in entitytaglist:
                entitytaglist.append(entity_label)

            # PERSON entity filtering
            if (
                entity_label == "PERSON"
                and self.filter_person_entity(doc, ent, entity_text, customtermpatterns) == False
            ):
                continue

            # GPE entity filtering
            if (
                entity_label == "LOCATION"
                and self.filter_gpe_entity(doc, ent, entity_text) == False
            ):
                continue

            # drop entities containing a profane token
            if any(token in badwords for token in entity_text.split()):
                continue
            allents.append(
                (ent.start, ent.end, entity_label, entitytaglist.index(entity_label), ent)
            )

        doc.ents = []
        if allents:
            # Columns: 0=start, 1=end, 2=label, 3=label priority, 4=entity.
            # Sort by start then priority and keep the first row per start.
            ents_frame = pd.DataFrame(allents)
            ents_frame.sort_values([0, 3], inplace=True)
            ents_frame.drop_duplicates(0, inplace=True)
            doc.ents = ents_frame[4].tolist()
        if self.perfCounter is not None:
            self.perfCounter.stop_stopwatch("Filter_Entities", len(doc.text))
        return doc

    def filter_person_entity(self, doc, ent, entity_text, customtermpatterns):
        global NER_GLOBAL
        # drop person entities with special characters
        if self.exclude_chars.search(entity_text) is not None:
            return False
        if len(entity_text) < 3:  # drop person entities of length less than 3
            return False
        # drop gibberish entities
        model_mat = self.model_data["mat"]
        threshold = self.model_data["thresh"]
        if gib_detect_train.avg_transition_prob(entity_text, model_mat) < threshold:
            return False

        """ check the casing for the person name
            reject in following
            1. if person name is having more than one upper case character in a word, except when all characters are capital.
            2. if person name is not title case and having a upper case somewhere in the word. 
            (we will miss some names like Mike DeMount in this case.)
        """
        for person_token in entity_text.split(" "):
            # drop lower case person names
            if person_token.islower():
                return False
            # check if word in not all caps
            if person_token.isupper() == False:
                # check if more than one char is capital
                caps_char_count = sum([1 if ch.isupper() else 0 for ch in person_token])
                if caps_char_count > 1:
                    return False
                # check if not title case and having an upper case somewhere
                if person_token.istitle() == False and caps_char_count > 0:
                    return False

        """
        a more strict casing check across tokens
        all tokens should be either in title case or in smaller case
        drop any person entity which is not following this casing
        avoid person names starting with abbreviations or having a single letter in between
        TODO needs to be revisit as it can drop some true positives as well
        """
        abbr_or_singleletter_token = sum(
            [
                1
                for person_token in entity_text.split(" ")
                if len(re.findall("\.+", person_token)) > 0 or len(person_token) == 1
            ]
        )
        if abbr_or_singleletter_token == 0:
            title_case_tokens = sum(
                [1 for person_token in entity_text.split(" ") if person_token.istitle()]
            )
            if title_case_tokens > 0 and title_case_tokens != len(entity_text.split(" ")):
                return False

        if ent.has_extension("is_person_with_surname") and ent._.is_person_with_surname == False:
            syn = wordnet.synsets(entity_text)
            if syn is not None and len(syn) > 0 and len(syn[0].examples()) > 0:
                return False
            # drop person entities which are common english words
            person_split = [
                self.lemmatizer.lemmatize(word.lower()) for word in entity_text.split(" ")
            ]
            if len(set(person_split).intersection(self.most_common_words)) > 0:
                return False

            """
            do following check for each token in current entity
            1. drop if email, url or number.
            2. drop entity if it is not a proper noun and not found by our PERSON gazetteer
            3. drop entity if first three letters is not having any vowel in it and not found by person gz.
            """
            is_person_gz = False
            for token_idx in range(ent.start, ent.end):
                current_token = doc[token_idx]
                if current_token.like_email or current_token.like_url or current_token.like_num:
                    return False
                if len(str(current_token)) == 1 or len(re.findall("\.+", str(current_token))) > 0:
                    continue
                if (
                    current_token.has_extension("entity_label_")
                    and current_token._.entity_label_ is not None
                    and "PERSON_GZ" not in current_token._.entity_label_
                    and (
                        current_token.pos_ != "PROPN"
                        or current_token.tag_ != "NNP"
                        or len(
                            set(current_token.text[:3]).intersection(set(["a", "e", "i", "o", "u"]))
                        )
                        == 0
                    )
                ):
                    return False
                if (
                    current_token._.entity_label_ is not None
                    and "PERSON_GZ" in current_token._.entity_label_
                ):
                    is_person_gz = True

            # in case of structure documents, avoid person tech filter as lots of entities are getting dropped in that case
            # TODO: we need to rebuild the model to look for the context as well
            if (
                NER_GLOBAL["is_structure"] == False
                and is_person_gz == False
                and len(entity_text.split()) == 1
                and self.person_filter.filter_person(entity_text.lower()) != "person"
            ):
                return False

            if is_person_gz == False:
                for token in person_split:
                    if len(segment(token)) > 1:
                        return False

            if customtermpatterns is not None:
                ispartialpersongzmatch = False
                for currentstart in range(ent.start, ent.end):
                    if (
                        doc[currentstart]._.entity_label_ is not None
                        and "PERSON_GZ" in doc[currentstart]._.entity_label_
                    ):
                        ispartialpersongzmatch = True
                if not ispartialpersongzmatch and customtermpatterns.findall(
                    ent.text
                ):  # finds partial match with custom patterns but not with person gz
                    return False
                for key in list(doc._.customterm.keys()):
                    if key.endswith("matcher"):
                        matcher = doc._.customterm[key][0]
                        nlp = doc._.customterm[key][1]
                        if matcher is not None:
                            if len(matcher(nlp.make_doc(ent.text))) > 0:
                                return False

        return True

    def filter_gpe_entity(self, doc, ent, entity_text):
        """Decide whether a LOCATION entity should be kept.

        Args:
            doc: Document containing the entity; indexed by token position.
            ent: The entity span (only ``start`` is read).
            entity_text: Cleaned entity text to validate.

        Returns:
            ``False`` when the entity should be dropped, ``True`` otherwise.
        """
        if len(entity_text) < 3:
            return False
        # TODO: this will be improved based on how bad spacy is performing for GPE.
        # drop lower case words
        if entity_text.islower():
            return False

        lowered_text = entity_text.lower()

        # drop words which have an english meaning based on their synsets examples
        if any(len(syn.examples()) > 0 for syn in wordnet.synsets(lowered_text)):
            return False

        # drop single-word locations that are plain english words and were not
        # tagged as proper nouns.
        # NOTE(review): this condition used to test `len(entity_text) == 1`,
        # which can never be true after the length guard above, so the filter
        # was dead code. It now tests the number of words, which matches the
        # single-token POS lookup below — confirm this is the intended rule.
        if (
            lowered_text in self.most_common_words
            and len(entity_text.split()) == 1
            and doc[ent.start].pos_ != "PROPN"
        ):
            return False
        return True

    def merge_entities(self, doc):
        """Join PERSON entities that directly follow one another.

        A PERSON entity whose start position (after skipping tokens whose text
        is a single space) equals the end of an earlier PERSON entity is
        replaced by a new span reaching from that earlier entity's start to
        its own end. Every PERSON entity processed is appended to ``doc.ents``.
        Afterwards each newly built span is merged, and the ``entity_label_``
        values its tokens carried are written back onto the merged span's
        tokens.

        Args:
            doc: The document whose entities are inspected and updated.
        """
        # label id -> {end position: start position} of PERSON spans seen so far
        start_by_end = {}
        original_ents = doc.ents
        token_count = len(doc)
        joined_spans = []
        for candidate in original_ents:
            if candidate.label_ != "PERSON":
                continue
            span_start = candidate.start
            span_end = candidate.end
            label_id = candidate.label
            known_ends = start_by_end.setdefault(label_id, {span_end: span_start})
            # step over single-space tokens before looking for an adjoining span
            probe = span_start
            while probe < token_count and doc[probe : probe + 1].text == " ":
                probe += 1
            if probe in known_ends:
                # an earlier PERSON span ends here: extend back to its start
                span_start = known_ends[probe]
                candidate = Span(doc, span_start, span_end, label=label_id)
                joined_spans.append(candidate)
            known_ends[span_end] = span_start
            doc.ents = list(doc.ents) + [candidate]

        for joined in joined_spans:
            # remember the tokens' entity_label_ values, they are lost by the merge
            saved_labels = [
                tok._.entity_label_
                for tok in joined
                if tok.has_extension("entity_label_") and tok._.entity_label_ is not None
            ]

            joined.merge()

            # put the remembered labels back on the merged span's tokens
            if saved_labels:
                for tok in joined:
                    tok._.entity_label_ = saved_labels
