import os, re, json
import pandas as pd
import pickle, requests
import nltk
from nltk.corpus import wordnet, words
from CvEEConfigHelper import PERSON_TECHTERM_CLASSIFIER_MODEL
from CvEEConfigHelper import getPythonSitePackagesDir
from CvEEConfigHelper import IS_ADDRESS_CLASSIFIER_MODEL
from CvNERSolrTaggerComponent import SolrTaggerComponent
from pathlib import Path


class PersonTechTermFilter(object):
    """Classify a single token as a person name or a technical term.

    A pickled classifier (``PERSON_TECHTERM_CLASSIFIER_MODEL``) is scored on
    the ordered feature dict built by :meth:`get_features`. The column order
    of that dict is what the model sees, so features must not be reordered or
    renamed without retraining.
    """

    def __init__(self, params=None):
        """Load the classifier and the surname keyword matcher.

        Args:
            params: dict that must contain ``"logger"``, an object exposing
                ``error`` and ``debug`` methods.

        Raises:
            KeyError: if ``params`` has no ``"logger"`` entry.

        Load failures are logged, not raised; ``classify`` / ``SURNAMES`` are
        then left as ``None`` and :meth:`filter_person` returns ``None``.
        """
        params = {} if params is None else params
        self.LOGGER_Generic = params["logger"]
        self.categories = ["tech_term", "person"]
        # Known defaults so a failed load does not leave attributes undefined.
        self.classify = None
        self.SURNAMES = None
        PYTHON_SITE_PACKAGES_DIR = getPythonSitePackagesDir()
        try:
            model_path = os.path.join(
                PYTHON_SITE_PACKAGES_DIR, Path(PERSON_TECHTERM_CLASSIFIER_MODEL)
            )
            # pickle can execute arbitrary code: only load trusted, shipped files.
            with open(model_path, "rb") as model_file:
                self.classify = pickle.load(model_file, encoding="latin-1")
            # NOTE(review): relative path, resolved against the process CWD.
            with open("surnames_extracted.pkl", "rb") as surnames_file:
                self.SURNAMES = pickle.load(surnames_file, encoding="utf-8")
        except Exception as e:
            self.LOGGER_Generic.error(
                "Failed to load person vs techincal term classifier model. Exception {}".format(e)
            )

    @staticmethod
    def merge_dict(train_dict=None, feature_dict=None):
        """Append each ``feature_dict`` value to the list under its key in ``train_dict``.

        Builds the column-of-lists layout ``pd.DataFrame.from_dict`` expects.
        Does nothing when ``feature_dict`` is ``None``. Mutates ``train_dict``
        in place and returns ``None``.
        """
        if feature_dict is None:
            return
        for key, value in feature_dict.items():
            train_dict.setdefault(key, []).append(value)

    @staticmethod
    def _is_processing_timeout(exc):
        """Return True if ``exc`` is (a subclass of) a class named ``ProcessingTimedOut``.

        Matched by name because ``CvEETimeout`` is not imported in this module;
        referencing it in an ``except`` clause raised ``NameError`` whenever any
        exception reached that clause.
        """
        return any(cls.__name__ == "ProcessingTimedOut" for cls in type(exc).__mro__)

    @staticmethod
    def _char_count_features(name):
        """Return ``alpha_<c>`` and ``alpha_<c1><c2>`` occurrence counts for a-z.

        Keys are produced in the order a..z, then aa..zz (the model's column
        order). ``str.count`` counts non-overlapping occurrences, exactly like
        ``re.findall`` on these metacharacter-free literal patterns.
        """
        letters = [chr(97 + i) for i in range(26)]
        counts = {}
        for first in letters:
            counts["alpha_" + first] = name.count(first)
        for first in letters:
            for second in letters:
                counts["alpha_" + first + second] = name.count(first + second)
        return counts

    def get_features(self, name):
        """Build the ordered feature dict for ``name`` (case-insensitive).

        Returns a dict of ints: first/last letter offsets from ``'a'``, vowel
        counts, per-letter and per-letter-pair counts, a surname hit flag, the
        length, and whether ``wordnet.morphy`` finds a base form.

        Raises:
            IndexError: if ``name`` is empty.
        """
        name = name.lower()
        vowels = {"a", "e", "i", "o", "u"}
        features = {}
        features["first_letter"] = ord(name[0]) - 97
        features["last_letter"] = ord(name[-1]) - 97
        features["vowel_in_first_three"] = len(set(name[:3]) & vowels)
        features["vowel_in_last_three"] = len(set(name[-3:]) & vowels)
        features["vowels"] = sum(1 for ch in name if ch in vowels)
        features.update(self._char_count_features(name))
        features["is_surname"] = (
            1 if len(self.SURNAMES.extract_keywords(name, span_info=False)) > 0 else 0
        )
        features["total_chars"] = len(name)
        features["is_root_word"] = int(wordnet.morphy(name) is not None)
        return features

    def filter_person(self, text):
        """Classify ``text`` as ``"person"`` or ``"tech_term"``.

        Returns ``"person"`` when column 1 of ``predict_proba`` exceeds 0.90,
        otherwise ``"tech_term"``. Returns ``None`` if feature extraction or
        prediction fails; the failure is logged at debug level. Exceptions
        named ``ProcessingTimedOut`` are re-raised.
        """
        try:
            test_dict = {}
            self.merge_dict(test_dict, self.get_features(text))
            test_df = pd.DataFrame.from_dict(test_dict)
            positive_probability = self.classify.predict_proba(test_df)[0][1]
            return self.categories[1 if positive_probability > 0.90 else 0]
        except Exception as e:
            if self._is_processing_timeout(e):
                raise
            self.LOGGER_Generic.debug(
                "Failed to classify with cv_person_techterm_model. Exception {}".format(e)
            )
            return None


class IsAddressClassifier(object):
    """Score a short text snippet with the pickled address classifier.

    Features (see :meth:`get_features`):
    - counts of Solr tagger matches in the GEO_IN / GEO_DEP / GEO_OTHERS buckets;
    - the number of NLTK English words;
    - the number of numeric tokens;
    - the number of location-abbreviation hits.

    Their insertion order is the column order the model sees.
    """

    # Pristine pattern template; filled with "|"-joined abbreviations at load.
    _abbreviations_template = r"\b({})\b"
    # Shared by all instances. Remains the unfilled template until
    # location_abbr.list has been read successfully.
    abbreviations_regex = _abbreviations_template
    # Shared by all instances; replaced with set(words.words()) in __init__.
    most_common_words = ""

    def __init__(self, params=None):
        """Load the NLTK word list, the classifier and the location abbreviations.

        Args:
            params: dict that must contain ``"logger"``, an object exposing
                ``error`` and ``debug`` methods.

        Raises:
            KeyError: if ``params`` has no ``"logger"`` entry.

        Model/abbreviation load failures are logged, not raised. If the model
        fails to load, the abbreviations are not loaded either.
        """
        params = {} if params is None else params
        self.LOGGER_Generic = params["logger"]
        self.categories = ["address", "non_address"]
        PYTHON_SITE_PACKAGES_DIR = getPythonSitePackagesDir()
        nltk_models_dir = os.path.join(
            os.getcwd(), Path("../ContentAnalyzer/bin/nltk_models/")
        )
        # Avoid appending a duplicate search path on every instantiation.
        if nltk_models_dir not in nltk.data.path:
            nltk.data.path.append(nltk_models_dir)
        IsAddressClassifier.most_common_words = set(words.words())
        try:
            model_path = os.path.join(
                PYTHON_SITE_PACKAGES_DIR, Path(IS_ADDRESS_CLASSIFIER_MODEL)
            )
            # pickle can execute arbitrary code: only load trusted, shipped files.
            with open(model_path, "rb") as model_file:
                self.classify = pickle.load(model_file, encoding="latin-1")
        except Exception as e:
            self.LOGGER_Generic.error(
                "Failed to load address classifier model. Exception {}".format(e)
            )
            return
        try:
            # NOTE(review): relative path, resolved against the process CWD.
            with open("location_abbr.list") as f:
                abbrs = f.read()
            # Fill from the pristine template: re-formatting an already built
            # pattern would misinterpret any braces that came from the file.
            # NOTE(review): a trailing newline yields an empty alternative that
            # matches at every word boundary; left as-is to keep features
            # identical to what the model was trained on -- confirm.
            IsAddressClassifier.abbreviations_regex = (
                IsAddressClassifier._abbreviations_template.format("|".join(abbrs.split("\n")))
            )
        except Exception as e:
            self.LOGGER_Generic.error(
                "Failed to load location abbreviations. Exception {}".format(e)
            )

    @staticmethod
    def merge_dict(train_dict=None, feature_dict=None):
        """Append each ``feature_dict`` value to the list under its key in ``train_dict``.

        Builds the column-of-lists layout ``pd.DataFrame.from_dict`` expects.
        Does nothing when ``feature_dict`` is ``None``. Mutates ``train_dict``
        in place and returns ``None``.
        """
        if feature_dict is None:
            return
        for key, value in feature_dict.items():
            train_dict.setdefault(key, []).append(value)

    @staticmethod
    def _is_processing_timeout(exc):
        """Return True if ``exc`` is (a subclass of) a class named ``ProcessingTimedOut``.

        Matched by name because ``CvEETimeout`` is not imported in this module;
        referencing it in an ``except`` clause raised ``NameError`` whenever any
        exception reached that clause.
        """
        return any(cls.__name__ == "ProcessingTimedOut" for cls in type(exc).__mro__)

    @staticmethod
    def _normalize_text(text):
        """Return ``text`` ASCII-only, punctuation space-padded, whitespace-collapsed, lower-cased.

        ``bytes`` input is decoded as UTF-8 first (the ``str`` regexes below
        would raise ``TypeError`` on ``bytes``).
        """
        if isinstance(text, bytes):
            text = text.decode("utf-8")
        text = re.sub(r"[^\x00-\x7F]+", "", text)
        text = re.sub(r"([^a-z0-9\s])", r" \1 ", text, flags=re.I)
        text = re.sub(r"\s+", " ", text)
        return text.strip().lower()

    @classmethod
    def get_features(cls, text):
        """Build the ordered feature dict for ``text`` (``str`` or UTF-8 ``bytes``).

        Queries ``SolrTaggerComponent`` with the normalized text.
        """
        text = cls._normalize_text(text)
        tagger_response = SolrTaggerComponent.get_processed_response(
            SolrTaggerComponent.get_tagger_matches(text)
        )

        features = {}
        features["GPE_IN"] = len(tagger_response["GEO_IN"])
        features["GPE_DEP"] = len(tagger_response["GEO_DEP"])
        features["GPE_OTHERS"] = len(tagger_response["GEO_OTHERS"])
        features["ENGLISH_WORDS"] = sum(
            1 for word in text.split() if word in IsAddressClassifier.most_common_words
        )
        features["NUM"] = len(re.findall(r"\b[0-9]+\b", text))
        features["abbreviations"] = len(
            re.findall(IsAddressClassifier.abbreviations_regex, text)
        )
        return features

    def filter_address(self, text):
        """Return True when the model predicts label 1 for ``text``.

        Also returns True if feature extraction or prediction fails; the
        failure is logged at debug level. Exceptions named
        ``ProcessingTimedOut`` are re-raised.
        """
        try:
            test_dict = {}
            self.merge_dict(test_dict, self.get_features(text))
            test_df = pd.DataFrame.from_dict(test_dict)
            return self.classify.predict(test_df)[0] == 1
        except Exception as e:
            if self._is_processing_timeout(e):
                raise
            self.LOGGER_Generic.debug(
                "Failed to classify with IsAddressClassifierModel. Exception {}".format(e)
            )
            return True


if __name__ == "__main__":
    # Manual smoke test: run PersonTechTermFilter over the command-line
    # arguments (or a built-in sample list), then IsAddressClassifier over a
    # few sample snippets, printing the results.
    import sys

    from CvCAGenericLogger import get_logger_handler

    logger_options = {"ROTATING_BACKUP_COUNT": 5, "ROTATING_MAX_BYTES": 5 * 1024 * 1024}
    LOGGER_Generic = get_logger_handler(
        "CvCIEntityExtractionCsApi.dll", "ContentAnalyzer", logger_options
    )
    params = {"logger": LOGGER_Generic}
    person_filter = PersonTechTermFilter(params)

    if len(sys.argv) > 1:
        texts = sys.argv[1:]
    else:
        texts = [
            "Webconsole",
            "photoshop",
            "afile",
            "niumber",
            "ida",
            "LRU files",
            "bloombergs",
            "Bloomberg",
            "Lucene",
            "metrowerks",
            "emulex",
            "inode",
            "Testcase",
            "diff",
            "singleton",
            "Batchlite",
            "Fiddler",
            "Agedbyindexing",
            "multistream",
            "Mongo",
            "fla",
            "solari",
            "Testcases",
            "ene",
            "solr",
            "debian",
            "len",
            "del",
            "diff",
            "mongo",
            "ignore",
            "putty",
            "Scala",
            "jenkins",
            "STRT RST",
            "NAS IDA",
        ]
    for text in texts:
        category = person_filter.filter_person(text.lower())
        # Only person hits are printed. get_features lower-cases its input, so
        # the category above is reused rather than classifying the
        # original-case text a second time.
        if category == "person":
            print(text, " : ", category)

    is_address_classifier = IsAddressClassifier(params)
    address_texts = [
        "Administrate ASIAINFO R&D center more than",
        "Application Developer Lehman Brothers , New York",
        "Web Engineer , iMotors , San Francisco",
    ]
    for text in address_texts:
        print(f"{text} : {is_address_classifier.filter_address(text)}")
