import sys
import ujson
import os
import nltk
import re2
import plac
import traceback
import pandas as pd
import unidecode
import time
import CvEETimeout
from threading import Timer
from CvEEConfigHelper import (
    CA_ERROR_CODES,
    NER_GLOBAL,
    SPACY_MODELS_MAP,
    SPACY_MODELS_REG_KEY,
    loadRegValue,
    getPythonSitePackagesDir,
    get_available_memory,
    removeLineBreaks,
    cleanRawTextForNER,
)
from nltk.tokenize import sent_tokenize
from nameparser import HumanName
from CvCAProximityConf import ProximityBasedConfidence
from CvSpacyCustomPipeline import LoadLanguageModel
from CvEEUtil import util_get_dict_attr
from io import StringIO
from spacy.tokens import Doc
from CvNERStructureDocuments import StructuredDocumentHandler
from CvEEDocumentClassifier import DocumentCategorizer
from pathlib import Path
import mkl

# Limit the number of threads MKL may use. The value is read from the
# "bMKLNumThreads" registry setting, with 1 passed to loadRegValue as the default.
# NOTE(review): this call passes ``key_type=int`` while every other loadRegValue
# call in this file passes ``type=int`` — confirm loadRegValue accepts this keyword.
MKL_NUM_THREADS = loadRegValue("bMKLNumThreads", 1, key_type=int)
mkl.set_num_threads(MKL_NUM_THREADS)


def preProcess(params=None):
    """Perform the one-time, expensive initialisation for the NER client.

    Registers the bundled nltk model directory with nltk and resolves the
    spaCy model install directory (registry value, defaulting to the
    ``en_core_web_lg`` directory under site-packages). When
    ``params["is_dummy_process"]`` is ``False`` and at least 3 GB of memory is
    available, it loads the spaCy model and builds a ``DocumentCategorizer``
    around it. The model is loaded through ``LoadLanguageModel`` when a custom
    pipeline is configured, otherwise with ``spacy.load``.

    Args:
        params: configuration dict. Reads ``is_dummy_process`` (required),
            ``sp_ner_usecustompipeline`` and ``sp_ner_custompipeline``. It is
            also passed on to ``LoadLanguageModel`` and ``DocumentCategorizer``.

    Returns:
        On success, a dict with ``nlp`` (``None`` when the model was not
        loaded), ``ErrorCode`` (success), ``extension_exclusion_list`` and
        ``document_categorizer`` (``None`` when the model was not loaded).
        On any exception, a dict with only ``ErrorCode``
        (``SpacyPreProcessingError``) and ``ErrorMessage``.
    """
    if params is None:
        # Fresh dict per call instead of a shared mutable default.
        params = {}
    try:
        nltk_models_dir = os.path.join(
            os.getcwd(), Path("../ContentAnalyzer/bin/nltk_models/")
        )
        # Do not grow nltk's search path every time preProcess is called.
        if nltk_models_dir not in nltk.data.path:
            nltk.data.path.append(nltk_models_dir)

        PYTHON_SITE_PACKAGES_DIR = getPythonSitePackagesDir()
        SPACY_MODELS_INSTALL_DIR = loadRegValue(
            SPACY_MODELS_REG_KEY,
            os.path.join(PYTHON_SITE_PACKAGES_DIR, Path(SPACY_MODELS_MAP["en_core_web_lg"])),
            type=str,
        )

        # Only load the spaCy model when more than 3 GB of memory is available;
        # the current model takes around 2.3 GB.
        # TODO: consider loading a smaller model when less memory is available.
        nlp = None
        document_categorizer = None
        if params["is_dummy_process"] is False and get_available_memory() >= 3:
            # Compared with ``== True`` on purpose (matches the original check):
            # True/1 enable the custom pipeline, other truthy values do not.
            if (
                params.get("sp_ner_usecustompipeline") == True  # noqa: E712
                and "sp_ner_custompipeline" in params
            ):
                nlp = LoadLanguageModel(
                    modelname=SPACY_MODELS_INSTALL_DIR,
                    pipeline=params["sp_ner_custompipeline"],
                    params=params,
                )
            else:
                import spacy

                nlp = spacy.load(SPACY_MODELS_INSTALL_DIR)
            # The document categorizer needs the spaCy model for
            # content-based categorization.
            document_categorizer = DocumentCategorizer(nlp, params)

        extension_exclusion_list = ["log", "bz2", "gz", "tar", "rar", "zip", "7z"]

        response = {
            "nlp": nlp,
            "ErrorCode": CA_ERROR_CODES["success"],
            "extension_exclusion_list": extension_exclusion_list,
            "document_categorizer": document_categorizer,
        }
    except Exception as e:
        response = {
            "ErrorCode": CA_ERROR_CODES["SpacyPreProcessingError"],
            "ErrorMessage": "Named entity recognition processing failure: {}".format(e),
        }
    return response


def process_entity(entity, entities):
    """Clean one entity's text and file it in *entities* under its lower-cased label.

    *entity* is either an object exposing ``label_`` and ``text`` attributes
    (such as a spaCy span) or a mapping with ``"label_"`` and ``"text"`` keys.
    The text is cut to its first 32000 characters and stripped. For the
    ``PERSON`` and ``LOCATION`` labels, any leading and trailing run of
    characters outside ``a-zA-Z`` is also removed.

    ``entities[label.lower()]`` is created (as an empty list) if missing. The
    cleaned text is appended to it only when it is non-empty and not already
    present, so the list stays free of blanks and duplicates.
    """
    max_chars = 32000
    try:
        label = entity.label_
        text = entity.text[:max_chars].strip()
    except AttributeError:
        # Dict-shaped entity (e.g. produced by the structured-document handler).
        label = entity["label_"]
        text = entity["text"][:max_chars].strip()

    if label in ("PERSON", "LOCATION"):
        non_letters = "[^a-zA-Z]+"
        # First trim the leading run, then the trailing run.
        for pattern in ("^" + non_letters, non_letters + "$"):
            text = re2.sub(pattern, "", text)

    bucket = entities.setdefault(label.lower(), [])
    text = text.strip()
    if text != "" and text not in bucket:
        bucket.append(text)


def stitchResults(
    content, keys, results, entities, childDocs, entities_keys, entities_keywords, logger
):
    """Collect the entities in *results* and keep only the requested labels.

    Args:
        content: unused; kept for interface compatibility.
        keys: requested entity names (compared case-insensitively).
            ``address`` is always accepted. Requesting ``organization`` also
            accepts spaCy's ``org`` label, and requesting ``contextualdate``
            also accepts ``date``.
        results: a spaCy ``Doc`` (uses ``.ents`` and, if registered, the
            ``address_entities`` extension) or a dict with an ``"ents"`` list.
        entities: accumulator dict (label -> list of texts). It is updated in
            place via ``process_entity`` and may be the return value of a
            previous call when a document is processed in chunks.
        childDocs, entities_keys, entities_keywords, logger: unused; kept for
            interface compatibility.

    Returns:
        A new dict holding only the accepted labels. ``date`` is folded into
        ``contextualdate`` and ``org`` into ``organization``. Values already
        stored under the target key are kept and merged without duplicates,
        so results from earlier chunks are not lost.
    """
    address_entities = []
    try:
        doc_entities = results.ents
        if results.has_extension("address_entities"):
            address_entities = results._.address_entities
    except AttributeError:
        doc_entities = results["ents"]

    acceptable_labels = {key.lower() for key in keys}
    acceptable_labels.add("address")
    # Map spaCy's ORG label to the organization entity.
    if "organization" in acceptable_labels:
        acceptable_labels.add("org")
    # Map spaCy's DATE label to the contextualdate entity.
    if "contextualdate" in acceptable_labels:
        acceptable_labels.add("date")

    for entity in doc_entities:
        process_entity(entity, entities)
    for entity in address_entities:
        process_entity(entity, entities)

    final_entities = {
        key: entities[key] for key in acceptable_labels.intersection(entities)
    }

    # Fold spaCy labels into the product's entity names. Merge rather than
    # overwrite: the target key may already hold values from an earlier chunk.
    for spacy_label, entity_name in (("date", "contextualdate"), ("org", "organization")):
        if spacy_label in final_entities:
            merged = list(final_entities.get(entity_name, []))
            for value in final_entities.pop(spacy_label):
                if value not in merged:
                    merged.append(value)
            final_entities[entity_name] = merged

    return final_entities


def escapeSpecialChars(entity):
    """Return *entity* with every regex metacharacter prefixed by a backslash.

    The escaped characters are ``[ ] \\ ^ $ . | ? * + ( ) { }``; all other
    characters are copied unchanged.
    """
    metacharacters = "[]\\^$.|?*+(){}"
    return "".join(
        "\\" + ch if ch in metacharacters else ch
        for ch in entity
    )


def manipulateContentForJsonParse(text):
    """Double every backslash in *text*, then collapse each run of four into two.

    The net effect is that a lone backslash becomes two, while a backslash that
    was already doubled stays doubled.
    """
    one_slash = "\\"
    two_slashes = one_slash * 2
    four_slashes = one_slash * 4
    return text.replace(one_slash, two_slashes).replace(four_slashes, two_slashes)


def doAnalysis(processing_input, params=None):
    """Run spaCy named entity recognition on one document and return its entities.

    Args:
        processing_input: JSON text (str or bytes) shaped like
            ``{"dataList": {"optTypedata": [{"attrKey": ..., "attrValue": ...}]}}``.
            The ``attrValue`` of every item whose ``attrKey`` is listed in
            ``params["EntityExtractionFields"]`` (default ``["content"]``) is
            analysed. Fields other than ``content`` are preceded by their
            name on a separate line.
        params: processing context. Required keys: ``nlp``, ``logger``,
            ``entities_attributes``, ``sp_ner_size_limit`` and
            ``extension_exclusion_list``. Optional keys: ``FileName``,
            ``contentid``, ``perfCounter``, ``childDocs``,
            ``EntityExtractionFields``, ``entitiesToExtractNER``,
            ``sp_ner_structure_documents_handling``,
            ``use_document_classifier``, ``document_categorizer``, ``pipe``,
            ``sp_ner_pipe_min_size``, ``sp_ner_use_nltk_tokenizer`` and
            ``displacyfilename``. ``params["EntityExtractionFields"]`` is
            normalised in place to a list of field names, so the same
            ``params`` can be reused for later calls.

    Returns:
        A dict mapping entity names to lists of extracted strings, plus
        ``ErrorCode`` and ``ErrorMessage``. On failure or skip, only the
        ``ErrorCode``/``ErrorMessage`` keys are present.

    Raises:
        CvEETimeout.Timeout.ProcessingTimedOut: re-raised unchanged.
    """
    if params is None:
        params = {}

    entities = {}
    # Resolve identifiers and the start time before the main try block, so
    # the error handler and the final timing log can always use them.
    contentid = params.get("contentid", "")
    currentFileName = "Not Provided"
    doc_id = contentid
    start_time = time.time()

    try:
        currentFileName = params.get("FileName", "Not Provided")
        extension = ""
        if currentFileName != "Not Provided":
            doc_id = currentFileName
            extension = os.path.splitext(currentFileName)[-1][1:]

        # Skip the document and mark it as failed if
        # 1. spaCy was not loaded due to memory or CPU constraints, or
        # 2. the system doesn't have enough memory left for this document.
        nlp = params["nlp"]
        if nlp is None or get_available_memory() < 1:
            params["logger"].error(
                "Failed to process document {} because of insufficient system resources.".format(
                    doc_id
                )
            )
            entities["ErrorCode"] = CA_ERROR_CODES["InsufficientResourcesForNER"]
            entities[
                "ErrorMessage"
            ] = "Insufficient system resources available for named entity recognition."
            return entities

        params["logger"].debug("Processing started for File {}".format(doc_id))

        perfCounter = params.get("perfCounter")
        if perfCounter is not None:
            perfCounter.start_stopwatch("SP_NER")

        childDocs = params.get("childDocs", False)
        entities_attributes = params["entities_attributes"]
        entities_names = list(entities_attributes["entities_names"].keys())
        entities_keys = entities_attributes["entities_keys"]
        entities_keywords = entities_attributes["entities_keywords"]

        try:
            # Accept either the raw comma-separated string or the list stored
            # back into params by a previous call.
            extraction_fields = params.get("EntityExtractionFields")
            if extraction_fields is None:
                extraction_fields = ["content"]
            elif isinstance(extraction_fields, str):
                extraction_fields = extraction_fields.strip(",").split(",")
            params["EntityExtractionFields"] = extraction_fields

            jsonified_input = ujson.loads(processing_input)
            text_parts = []
            for item in jsonified_input["dataList"]["optTypedata"]:
                if (
                    "attrKey" in item
                    and "attrValue" in item
                    and item["attrKey"] in extraction_fields
                ):
                    if item["attrKey"] != "content":
                        text_parts.append(f"{item['attrKey']}{os.linesep}")
                    text_parts.append(f"{item['attrValue']}{os.linesep}")
            # Stripping turns input whose fields are all empty into "".
            cleaned_input = "".join(text_parts).strip()

            if not cleaned_input:
                return {
                    "ErrorCode": CA_ERROR_CODES["NoTextToProcess"],
                    "ErrorMessage": "Empty content is provided so skipping SP_NER.",
                }

            if len(cleaned_input) > params["sp_ner_size_limit"]:
                return {
                    "ErrorCode": CA_ERROR_CODES["SpacySizeExceeded"],
                    "ErrorMessage": "Content size for SP_NER greater than what the current model can support.",
                }

            if extension in params["extension_exclusion_list"]:
                return {
                    "ErrorCode": CA_ERROR_CODES["SpacyFileNotSupported"],
                    "ErrorMessage": "File type is not supported by SP_NER.",
                }

        except CvEETimeout.Timeout.ProcessingTimedOut:
            # Same policy as the outer handler: timeouts are never swallowed.
            raise
        except Exception as e:
            errorMessage = "Skipping file {} as could not parse json for temp file of doc {}. Exception {}".format(
                currentFileName, contentid, e
            )
            if "ExtractorTempLocation" in params:
                errorMessage += " Temp file is located at {}".format(
                    params["ExtractorTempLocation"]
                )
            params["logger"].error(errorMessage)
            return {
                "ErrorCode": CA_ERROR_CODES["ContentFetchFailed"],
                "ErrorMessage": "Error fetching content field for analysis.",
            }

        NER_GLOBAL["is_structure"] = False

        if "entitiesToExtractNER" in params:
            # Restrict extraction to the comma-separated entity ids in the
            # request, translated to names through entities_keys.
            entities_names = []
            for entity_id in params["entitiesToExtractNER"].strip(",").split(","):
                if entity_id.strip() == "":
                    continue
                entity_id = int(entity_id)
                if entity_id in entities_keys:
                    entities_names.append(entities_keys[entity_id])

        # The address component only needs to run when Address was requested.
        disabled_pipes = [] if "Address" in entities_names else ["address"]

        if params.get("sp_ner_structure_documents_handling"):
            structure_document = StructuredDocumentHandler(cleaned_input, Doc)
            if structure_document.type is not None or extension in (
                "csv",
                "tsv",
                "xls",
                "xlsx",
            ):
                NER_GLOBAL["is_structure"] = True
                results = structure_document.perform_extraction(nlp, extension)
                entities = stitchResults(
                    cleaned_input,
                    entities_names,
                    results,
                    entities,
                    childDocs,
                    entities_keys,
                    entities_keywords,
                    params["logger"],
                )

        if not NER_GLOBAL["is_structure"]:
            categorizer = params.get("document_categorizer")
            if params.get("use_document_classifier") and categorizer is not None:
                document_category = categorizer.classify(cleaned_input)
                if document_category in ["logs", "source_codes"]:
                    # Add to (rather than replace) the list, so "address"
                    # stays disabled when it was not requested.
                    disabled_pipes = disabled_pipes + ["ner", "entity", "beamparser"]

            # spaCy behaves differently depending on the newline separator: when
            # line endings don't match the OS it misses some entities. Spaces are
            # added around the newline because it affected solr tagger matching.
            cleaned_input = removeLineBreaks(cleaned_input, line_sep=" " + os.linesep + " ")
            # Surround runs of quotes, commas and hyphens with spaces.
            cleaned_input = re2.sub("([',\"\\-]+)", r" \1 ", cleaned_input)
            if len(cleaned_input) > params["sp_ner_size_limit"]:
                return {
                    "ErrorCode": CA_ERROR_CODES["SpacySizeExceeded"],
                    "ErrorMessage": "Content size for SP_NER greater than what the current model can support.",
                }

            disabled_pipeline = [
                pipe_name for pipe_name in disabled_pipes if pipe_name in nlp.pipe_names
            ]

            if (
                params.get("pipe")
                and "sp_ner_pipe_min_size" in params
                and len(cleaned_input) > params["sp_ner_pipe_min_size"]
            ):
                # Large input: run the model over pieces of the text.
                if params.get("sp_ner_use_nltk_tokenizer"):
                    chunks = sent_tokenize(cleaned_input)
                    # nltk performs really badly on csv files; when it finds few
                    # sentences, simply split by newlines to avoid a timeout.
                    if len(chunks) < 100:
                        chunks = cleaned_input.splitlines()
                else:
                    chunks = cleaned_input.splitlines()

                # TODO: calculate based on average number of characters in each line
                if len(chunks) > 10000:
                    lines_per_chunk = 100
                    chunks = [
                        os.linesep.join(chunks[start : start + lines_per_chunk])
                        for start in range(0, len(chunks), lines_per_chunk)
                    ]

                with nlp.disable_pipes(*disabled_pipeline):
                    for text in chunks:
                        results = nlp(text)
                        entities = stitchResults(
                            text,
                            entities_names,
                            results,
                            entities,
                            childDocs,
                            entities_keys,
                            entities_keywords,
                            params["logger"],
                        )
            else:
                with nlp.disable_pipes(*disabled_pipeline):
                    results = nlp(cleaned_input)

                if "displacyfilename" in params and results:
                    from spacy import displacy

                    html = displacy.render(results, style="ent", page=True)
                    # displacy.render returns str; write it through a UTF-8
                    # text stream instead of writing bytes to a text file.
                    with open(params["displacyfilename"], "w", encoding="utf-8") as file_:
                        file_.write(html)

                entities = stitchResults(
                    cleaned_input,
                    entities_names,
                    results,
                    entities,
                    childDocs,
                    entities_keys,
                    entities_keywords,
                    params["logger"],
                )

        # Single stop for every successful extraction path (structured, piped
        # and whole-document).
        if perfCounter is not None:
            perfCounter.stop_stopwatch("SP_NER", len(processing_input))

        entities["ErrorCode"] = CA_ERROR_CODES["success"]
        entities["ErrorMessage"] = None

    except CvEETimeout.Timeout.ProcessingTimedOut:
        raise

    except Exception as e:
        tb = traceback.format_exc()
        params["logger"].error(
            "Error in SP_NER for attributes {} on file {}. Exception : {}{}{}".format(
                params.get("entities_attributes"), currentFileName, e, os.linesep, tb
            )
        )
        entities = {
            "ErrorCode": CA_ERROR_CODES["SpacyError"],
            "ErrorMessage": "SP_NER processing failure for contentid {}: {}".format(contentid, e),
        }

    end_time = time.time()
    params["logger"].debug(
        "Processing completed for File {}, Time taken {}".format(doc_id, end_time - start_time)
    )

    return entities


def printEnts(docs):
    """Print one pandas table of entities per document in *docs*.

    Each row holds an entity's ``text`` (used as the index), ``label_``,
    ``start_char``, ``end_char`` and every custom extension value. Documents
    without entities are skipped.

    NOTE(review): the extension names are taken from
    ``ent._.span_extensions`` of the first entity in each doc, as before.
    Confirm the custom pipeline registers that attribute.
    """
    print("Entities:")
    base_attrs = ["text", "label_", "start_char", "end_char"]
    for doc in docs:
        ents = list(doc.ents)
        if not ents:
            continue
        ext_names = list(ents[0]._.span_extensions.keys())
        columns = ["ent." + attr for attr in base_attrs] + [
            "ent._." + ext for ext in ext_names
        ]
        # Read the attributes directly instead of exec() (which cannot bind
        # locals in Python 3). Build the frame directly instead of a CSV
        # round-trip, which broke on entity text containing commas.
        rows = [
            [getattr(ent, attr) for attr in base_attrs]
            + [getattr(ent._, ext) for ext in ext_names]
            for ent in ents
        ]
        frame = pd.DataFrame(rows, columns=columns).set_index("ent.text")
        print(frame)


@plac.annotations(
    filename=plac.Annotation(
        help="path to the text file to run", kind="option", abbrev="fn", type=str
    ),
    displacyfilename=plac.Annotation(
        help="path to the display html file to write output into",
        kind="option",
        abbrev="dfn",
        type=str,
    ),
)
def main(filename="", displacyfilename=""):
    """Manual test harness that runs doAnalysis and prints the result.

    With no *filename*, a built-in sample sentence is analysed. With a file
    path, that file is read as UTF-8 text. With a directory, every ``.txt``
    file under it is analysed, and a ``<file>_displacy.html`` rendering is
    written next to each one.
    """
    # Unit Tests
    from CvEEConfigHelper import REG_CONF
    import warnings

    warnings.filterwarnings("ignore")

    # import glob2
    params = {
        "sp_ner_usecustompipeline": True,
        "use_document_classifier": False,
        "sp_ner_custompipeline": "sent,set_entity_label,surname_attr,persongz,cc_cvterms_cvterm_test,filtergz,cc_cvterms_cvterm_train",
        "sp_ner_probcutoff": 0.0,
        # 'sEEDocumentClassifyModelsInstallDir': '[{"model_type" : "content_based","model_library" : "spacy","model_path" : "C:\\Anaconda\\Lib\\site-packages\\document_classify_v1.1"}]',
        # 'sEEModelsInstallDir':'C:\Anaconda\lib\site-packages',
        "sp_ner_pipe": 1,
    }

    # load the defaults - comment this code out and set specific values needed
    # for debugging if different from default values
    for key, conf in REG_CONF.items():
        params[key] = conf["value"]

    logger_options = {"ROTATING_BACKUP_COUNT": 5, "ROTATING_MAX_BYTES": 5 * 1024 * 1024}
    from CvCAGenericLogger import get_logger_handler

    Logger_ = get_logger_handler("CvCIEntityExtractionCsApi.dll", "ContentAnalyzer", logger_options)
    params["logger"] = Logger_
    from CvCAPerfCounters import PerformanceCounter

    params["perfCounter"] = PerformanceCounter(params["logger"])
    params["is_dummy_process"] = False

    # preProcess returns a separate response dict (model, categorizer,
    # exclusion list, error code). Merge it into params instead of replacing
    # params, which would drop the configuration above (e.g. the perf counter).
    init_response = preProcess(params)
    if init_response.get("ErrorCode") != CA_ERROR_CODES["success"]:
        print(init_response)
        return
    params.update(init_response)

    if len(displacyfilename) > 0:
        params["displacyfilename"] = displacyfilename
    params["sp_ner_size_limit"] = 5 * 1024 * 1024
    params["use_document_classifier"] = True
    params["entities_attributes"] = {
        "entities_names": {"PERSON": "", "ORG": "", "LOCATION": "", "DATE": "", "Address": ""},
        "entities_keys": {"person": ""},
        "entities_keywords": [],
    }

    if len(filename) > 0:
        if os.path.isdir(filename):
            for file_path in recurseDir(filename):
                _, extension = os.path.splitext(file_path)
                if os.path.isfile(file_path) and extension == ".txt":
                    print("Processing file {}".format(file_path))
                    with open(file_path, "rb") as f:
                        raw_content = f.read()
                    # Decode the bytes first, then clean, the same way as the
                    # single-file branch below.
                    contentstr = cleanRawTextForNER(
                        raw_content.decode("utf-8", errors="replace")
                    )
                    if isinstance(contentstr, bytes):
                        contentstr = contentstr.decode("utf-8")
                    inputjson = {
                        "dataList": {
                            "optTypedata": [{"attrKey": "content", "attrValue": contentstr}]
                        }
                    }
                    params["displacyfilename"] = file_path + "_displacy.html"
                    print(doAnalysis(ujson.dumps(inputjson).encode("utf-8"), params))
            return
        with open(filename, encoding="utf-8", errors="replace") as f:
            contentstr = cleanRawTextForNER(f.read())
        if isinstance(contentstr, bytes):
            contentstr = contentstr.decode("utf-8")
    else:
        contentstr = "My name is Alex and I work in Commvault Systems."
    inputjson = {"dataList": {"optTypedata": [{"attrKey": "content", "attrValue": contentstr}]}}
    print(doAnalysis(ujson.dumps(inputjson).encode("utf-8"), params))


def recurseDir(path):
    """List every non-directory path under *path*, depth-first in ``os.listdir`` order.

    When *path* itself is not a directory, it is returned as the only element.
    """
    assert len(path) > 0
    if not os.path.isdir(path):
        return [path]
    found = []
    for entry in os.listdir(path):
        # Non-directories come back as a one-element list; directories expand.
        found.extend(recurseDir(os.path.join(path, entry)))
    return found


if __name__ == "__main__":
    # Command-line entry point: plac parses the options declared on main().
    plac.call(main)
