# import ptvsd
import base64
import gc
import os
import re
import threading
import time
from threading import Event, Thread, Timer
from functools import lru_cache
import psutil

import ujson
import CvEETimeout
from CvCAGenericLogger import get_logger_handler
from CvCAPerfCounters import PerformanceCounter
from CvEEConfigHelper import (
    CA_ERROR_CODES,
    CA_TASK_LIST,
    CA_TASK_MODULE_MAP,
    HELPER_DLL,
    checkParentAndKill,
    cleanRawTextForNER,
    getBaseDir,
    is_linux,
    killProcessAndChildren,
    loadEntities,
    setRegValue,
)
from CvEEMsgQueueHandler import GenericMsgQueueCommunicator
from cvee_classification_task import start_subtask_processes


# PID of this worker process; ProfilingClient attaches it to the metrics it sends.
PROCESS_ID = os.getpid()

# Queue that face-representation requests are sent to after a FACE_DETECTION
# task has produced results (see ActiveMQReader.handleDocument).
FRqueue = "FACE_REPRESENTATION"

# Per-task profiling data (numTasks, QWaitTime, ProcessingTime, FileSize,
# mimeType), filled in by ActiveMQReader.handleDocument when profiling is on.
PROFILER = {}
# Passed to module pre-processing as "clientId" for VIDEO_PREVIEW tasks.
CLIENT_ID = 1
STOP_PROCESSING = Event()
# cache token -> entity data returned by loadEntities().
ENTITIES_GLOBAL_CACHE = {}
# cache token -> time.time() at which the cache entry was last loaded or touched.
ENTITIES_GLOBAL_TIME_CACHE = {}
# task name -> file name of the document currently being handled
# ("Not Provided" when the message carries no FileName).
CURRENT_FILE_LIST = {}
# Multiplied by 60 before being handed to CvEETimeout.Timeout in
# GenericCA.doCA -- presumably a value in minutes; confirm against CvEETimeout.
ENTITY_EXTRACTION_TIMEOUT = 5
CURRENT_TASK = None
DEFAULT_CONFIG = None

# Tasks whose worker module needs a one-time preProcess() call.
# Each value starts out empty and is replaced by a copy of the preProcess()
# response after the first successful call; GenericCA.doCA then merges that
# cached response into the params of every request for the task.
CA_PRE_PROCESSING_LIST = {
    "SP_NER": {},
    "RER": {},
    "DE": {},
    "DOC_TAGGER": {},
    "TRAIN": {},
    "TEST": {},
    "EMAIL_TAGGER": {},
}  # "VIDEO_PREVIEW": {} is intentionally absent: that feature is no longer used.

def logFilenameOnTimeExceed(doc_id, timeoutDuration):
    """Log an error saying that a document has been processing for too long.

    GenericCA.doCA schedules this as a ``threading.Timer`` callback so that
    slow documents are reported while they are still being analysed.

    Args:
        doc_id: Identifier written into the log line (file name or content id).
        timeoutDuration: Threshold that was exceeded, in seconds; it is
            reported in minutes.
    """
    LOGGER_Generic.error(
        f"Document [{doc_id}] processing has exceeded {timeoutDuration / 60:.2f} mins."
    )


class GenericCA:
    """Generic class which can call respective python workers with
    params and input and get response"""

    def doCA(self, task, text, params=None):
        """Run content-analysis ``task`` on ``text`` and return the worker response.

        The worker module is looked up in ``CA_TASK_MODULE_MAP`` and imported.
        For tasks listed in ``CA_PRE_PROCESSING_LIST`` the module's
        ``preProcess`` is called once, its response is cached, and the cached
        response is merged into ``params`` on every call. The module's
        ``doAnalysis`` is then run inside a ``CvEETimeout.Timeout`` guard.

        Args:
            task: Key into ``CA_TASK_MODULE_MAP``.
            text: Input forwarded unchanged to the module's ``doAnalysis``.
            params: Request parameters; mutated in place. It must contain
                ``"logger"``, and ``"contentid"`` when ``"FileName"`` is
                absent. SP_NER pre-processing reads several further keys.

        Returns:
            Whatever ``doAnalysis`` returns, or a dict with ``"ErrorCode"`` and
            ``"ErrorMessage"`` when pre-processing fails, the guard reports a
            timeout / memory limit, or any other exception is raised.
        """
        # A fresh dict per call: this method writes into ``params``, so a shared
        # default instance must never be used.
        if params is None:
            params = {}

        try:
            # import module
            module = __import__(CA_TASK_MODULE_MAP[task])

            # First, check if pre processing of the module is done
            # If it is not done, pre process the module,
            # and add response to params
            pre_processing_args = {"logger": params["logger"]}
            pre_process_task = task
            if task == "DOC_TAGGER":
                # DOC_TAGGER caches its pre-processing per sub task.
                if "sub_task" not in params:
                    params["sub_task"] = "TEST"
                pre_process_task = params["sub_task"]

            if pre_process_task in CA_PRE_PROCESSING_LIST:

                # Preprocess, as this is the first job called
                # for this task
                if not CA_PRE_PROCESSING_LIST[pre_process_task]:
                    if task == "VIDEO_PREVIEW":
                        pre_processing_args["clientId"] = CLIENT_ID

                    if task == "SP_NER":
                        pre_processing_args["is_dummy_process"] = params["is_dummy_process"]
                        pre_processing_args["sp_ner_usecustompipeline"] = params[
                            "sp_ner_usecustompipeline"
                        ]
                        pre_processing_args["sp_ner_custompipeline"] = params[
                            "sp_ner_custompipeline"
                        ]
                        # These two settings are optional in the request.
                        for optional_key in (
                            "use_document_classifier",
                            "sp_ner_structure_documents_handling",
                        ):
                            if optional_key in params:
                                pre_processing_args[optional_key] = params[optional_key]
                        pre_processing_args["sp_ner_probcutoff"] = params["sp_ner_probcutoff"]
                        pre_processing_args["perfCounter"] = params["perfCounter"]

                    if task == "DOC_TAGGER":
                        pre_processing_args["sub_task"] = params["sub_task"]

                    response = module.preProcess(pre_processing_args)
                    if response["ErrorCode"] != CA_ERROR_CODES["success"]:
                        # Nothing is cached on failure, so the next request retries.
                        return response
                    CA_PRE_PROCESSING_LIST[pre_process_task] = response.copy()

                # Merge the cached pre-processing results into params for
                # every request.
                params.update(CA_PRE_PROCESSING_LIST[pre_process_task])

            # Identify the document in log lines by file name when available,
            # otherwise by content id.
            currentFileName = params.get("FileName", "Not Provided")
            doc_id = currentFileName
            if currentFileName == "Not Provided":
                doc_id = params["contentid"]

            fileProcessingTimeout = 30  # 30 seconds
            entityExtractionTimeout = ENTITY_EXTRACTION_TIMEOUT * 60
            logFileOnTimeExceedTimer = None
            try:
                with CvEETimeout.Timeout(entityExtractionTimeout):
                    # Separate, shorter timer that only logs slow documents.
                    logFileOnTimeExceedTimer = Timer(
                        fileProcessingTimeout,
                        logFilenameOnTimeExceed,
                        [doc_id, fileProcessingTimeout],
                    )
                    logFileOnTimeExceedTimer.start()
                    # Pass input to doAnalysis function of the module
                    results = module.doAnalysis(text, params)
            except CvEETimeout.Timeout.MemoryLimitReached:
                LOGGER_Generic.error(
                    f"Document [{doc_id}] processing failed because of high memory usage."
                )
                results = {
                    "ErrorCode": CA_ERROR_CODES["MemoryLimitReached"],
                    "ErrorMessage": f"High memory usage for {task} entity extraction.",
                }
            except CvEETimeout.Timeout.ProcessingTimedOut:
                LOGGER_Generic.error(f"Extraction timed out for document [{doc_id}] .")
                results = {
                    "ErrorCode": CA_ERROR_CODES[task + "TimedOut"],
                    "ErrorMessage": f"{task} entity extraction request timed out.",
                }
            finally:
                # Never leave the "slow document" timer armed once we are done.
                if logFileOnTimeExceedTimer is not None:
                    logFileOnTimeExceedTimer.cancel()

            # Return output to write to ActiveMQ
            return results
        except ImportError as e:
            # Flag the install for repair so a service restart can fix it.
            setRegValue("repairLibraryInstall", 1, type=int)
            return {
                "ErrorCode": CA_ERROR_CODES["moduleLoadError"],
                "ErrorMessage": "Failed to load module for task {}. Please restart the service to repair install. Exception {}".format(
                    task, e
                ),
            }
        except Exception as e:
            return {
                "ErrorCode": CA_ERROR_CODES["moduleLoadError"],
                "ErrorMessage": "Failed to process task {} : Exception {}".format(task, e),
            }

    def testClass(self):
        """Manual smoke test that prints configured patterns and a sample result.

        NOTE(review): this helper looks stale -- nothing in this class sets
        ``self.config``, and ``doCA`` is called without its required ``text``
        argument, so it cannot run as written. Confirm intent before use.
        """
        # Should print list of entities
        print(list(self.config["regpatterns"].keys()))
        # Should print dict of EE results with at least 1 item in email and 1 in phone
        print(self.doCA("qwefdf@gmail.com 8860854544"))


class ActiveMQReader(Thread, GenericMsgQueueCommunicator):
    """Class to connect to ActiveMQ and read messages, then pass it to GenericClass for EE"""

    def __init__(self, task, stompPort=8055, profiling=False):
        """Store connection settings and per-reader state for ``task``.

        Args:
            task: Task name. It is used both as the queue to consume from and
                as the task handed to ``GenericCA.doCA``.
            stompPort: Port stored on the instance for the queue connection.
            profiling: When true, ``handleDocument`` records metrics in the
                module-level ``PROFILER`` dict.
        """
        Thread.__init__(self)
        GenericMsgQueueCommunicator.__init__(self)
        self.host = "127.0.0.1"
        self.stompPort = stompPort
        self.queue = task
        self.rQueue = "RESULT"
        # One client subscribed to the task queue, one used only for sending.
        self.clients = {
            self.queue: {"subscribe": True, "client": None},
            self.rQueue: {"subscribe": False, "client": None},
        }
        # Extra parameters merged into every request's params.
        self.specialArgs = {}
        self.ACTIVEMQ_WAIT_TIMEOUT = 0
        self._stopped = Event()
        self.task = task
        self.PROFILE_SEND_INTERVAL = 30
        self.profiling = profiling
        self.prefetchSize = str(8)
        # Documents handled since the last forced gc.collect() (SP_NER only).
        self.count_docs = 0
        self.pipe = False
        self.sp_ner_size_limit = 5 * 1024 * 1024
        self.sp_ner_pipe_min_size = 1 * 1024 * 1024
        self.sp_ner_use_nltk_tokenizer = True
        self.perfCounter = PerformanceCounter(LOGGER_Generic)
        self.perfPrintInterval = 300
        self.sendTimeout = 900000
        self.producerPrefetchSize = 1000
        self.consumerPrefetchSize = 1
        self.persistentEnabled = "false"
        # Periodic perf logging is opt-in: call startPerfStatPrinting().

    def refreshEntitiesCache(
        self,
        cacheToken,
        requiredEntities=None,
        dcPlanId=None,
        dcPolicyId=None,
        COMMSERVER_REACHABLE=None,
    ):
        """Make sure ``ENTITIES_GLOBAL_CACHE`` holds an entry for ``cacheToken``.

        On a cache hit only the entry's timestamp is refreshed. On a miss the
        entities are loaded through ``loadEntities`` and stored under the token.

        Args:
            cacheToken: Key of the cache entry.
            requiredEntities: Entity ids forwarded to ``loadEntities``; an
                empty list is passed when omitted.
            dcPlanId: Forwarded to ``loadEntities``.
            dcPolicyId: Forwarded to ``loadEntities``.
            COMMSERVER_REACHABLE: Forwarded to ``loadEntities``.

        Raises:
            Exception: Whatever ``loadEntities`` raises is propagated.
        """
        module_name = "CvCAGenericClient"
        func_name = "refreshEntitiesCache"
        func_string = "{}::{}() - ".format(module_name, func_name)

        if cacheToken in ENTITIES_GLOBAL_CACHE:
            ENTITIES_GLOBAL_TIME_CACHE[cacheToken] = time.time()
            LOGGER_Generic.debug(
                "Job Token is already present in cache. Updating its timestamp and not querying the DB",
                func_string,
            )
            return

        LOGGER_Generic.debug(
            "Job Token is not present in cache. About to query the DB", func_string
        )
        self.perfCounter.start_stopwatch("DBLoad")
        ENTITIES_GLOBAL_CACHE[cacheToken] = loadEntities(
            [] if requiredEntities is None else requiredEntities,
            dcPlanId=dcPlanId,
            dcPolicyId=dcPolicyId,
            logger=LOGGER_Generic,
            COMMSERVER_REACHABLE=COMMSERVER_REACHABLE,
            ee_cache_token=cacheToken,
        )
        self.perfCounter.stop_stopwatch("DBLoad")
        ENTITIES_GLOBAL_TIME_CACHE[cacheToken] = time.time()

    def run(self):
        """Thread body: poll the task queue and handle frames until told to stop.

        The loop ends when ``TASK_EVENT[self.task]`` is set. Any exception
        raised while fetching or handling a frame is logged and followed by a
        reconnect attempt.
        """
        module_name = "CvCAGenericClient.ActiveMQReader"
        func_name = "run"
        func_string = "{}::{}() - ".format(module_name, func_name)
        try:
            self.connectToQueue()
        except Exception as e:
            LOGGER_Generic.error(f"Error while connecting to ActiveMQ. Exception {e}", func_string)
        else:
            # Only claim success when the connect call did not raise.
            LOGGER_Generic.info(
                "Connected to tcp://{}:{} for task {}".format(
                    self.host, self.stompPort, self.task
                ),
                func_string,
            )

        while not TASK_EVENT[self.task].isSet():
            try:
                self.perfCounter.start_stopwatch(self.task + "Dequeue")
                frame = self.getFrame(self.clients[self.queue]["client"])
                if frame is None:
                    continue
                self.perfCounter.start_stopwatch(self.task + "_Time")
                frame_body = ujson.loads(frame)

                # Log only a small, identifying subset of the frame.
                log_obj = {}
                log_obj["taskKey"] = frame_body["taskKey"]
                requiredKeys = ["entitiesToExtract", "contentid"]
                if "reqParamMap" in frame_body:
                    for reqObj in frame_body["reqParamMap"]:
                        if "attrKey" in reqObj and "attrValue" in reqObj:
                            if reqObj["attrKey"] in requiredKeys:
                                log_obj[reqObj["attrKey"]] = reqObj["attrValue"]
                LOGGER_Generic.debug(
                    "Received frame {} from {} Queue".format(log_obj, self.task), func_string
                )
                self.perfCounter.stop_stopwatch(self.task + "Dequeue")
                self.handleDocument(frame_body)
                self.perfCounter.stop_stopwatch(self.task + "_Time")
            except Exception as e:
                LOGGER_Generic.error(f"Connection Closed. Exception {e}", func_string)
                self.connected = False
                self.connectToQueue()
        LOGGER_Generic.info("Finished polling for analysis tasks. Exiting...", func_string)
        self.disconnectClients()

    def startPerfStatPrinting(self):
        """Schedule the next ``printPerfStats`` run after ``perfPrintInterval`` seconds."""
        Timer(self.perfPrintInterval, self.printPerfStats).start()

    def printPerfStats(self):
        """Log the performance counters, then re-arm the timer for the next run."""
        self.perfCounter.logStats()
        self.startPerfStatPrinting()

    def handleDocument(self, message):
        """Process one decoded queue message and forward the result.

        Steps, in order:
          1. Build ``params`` from ``message["reqParamMap"]`` and ``specialArgs``.
          2. For entity tasks, make sure the entity cache is loaded.
          3. Resolve the text to analyse (file, inline content or local path).
          4. Run the task through ``GenericCA.doCA``.
          5. Send the outcome to every queue listed for this task in
             ``message["taskChain"]``.

        All exceptions are caught and logged; nothing is raised to the caller.

        Args:
            message: Decoded frame body (dict). Keys read here: ``taskInput``,
                ``outputFilePath``, ``timestamp``, ``reqParamMap``, ``taskKey``,
                ``localFilePath`` and ``taskChain``.
        """
        module_name = "CvCAGenericClient"
        func_name = "handleDocument"
        func_string = "{}::{}() - ".format(module_name, func_name)
        try:
            pickedUpTime = time.time() * 1000
            extractor = GenericCA()
            results = ""
            params = {}
            # Bound up front because it is read unconditionally at the end of
            # this method, while only the entity-task branch ever sets it.
            kill_current_process = False

            params["logger"] = LOGGER_Generic
            params["perfCounter"] = self.perfCounter

            taskOutput = {}
            timestamp = None

            if "outputFilePath" in message:
                params["outputFilePath"] = message["outputFilePath"]
            if "timestamp" in message:
                timestamp = message["timestamp"]

            if "reqParamMap" in message:
                for reqParamItem in message["reqParamMap"]:
                    if "attrKey" in reqParamItem and "attrValue" in reqParamItem:
                        params[reqParamItem["attrKey"]] = reqParamItem["attrValue"]

            # Get plan id and policy id

            dcPlanId = params.get("dcplanid", None)
            dcPolicyId = params.get("dcpolicyid", None)
            is_commserver_reachable = True

            params.update(self.specialArgs)

            if self.task in ["RER", "SP_NER", "DE", "DOC_TAGGER"]:
                # Union of the entity ids requested for any of the entity tasks.
                requiredEntities = set()
                for task_key in ["RER", "SP_NER", "DE", "DOC_TAGGER", "ML"]:
                    if ("entitiesToExtract" + task_key) in params:
                        entities_list = [
                            x
                            for x in params["entitiesToExtract" + task_key].strip(",").split(",")
                            if len(x) > 0
                        ]
                        requiredEntities = requiredEntities.union(entities_list)
                requiredEntities = [int(x) for x in list(requiredEntities)]

                try:
                    if "eeCacheToken" not in params:
                        params["eeCacheToken"] = "default"
                        self.refreshEntitiesCache(
                            params["eeCacheToken"],
                            dcPlanId=dcPlanId,
                            dcPolicyId=dcPolicyId,
                            COMMSERVER_REACHABLE=params["commserver_reachable"],
                        )
                    else:
                        self.refreshEntitiesCache(
                            params["eeCacheToken"],
                            requiredEntities,
                            dcPlanId=dcPlanId,
                            dcPolicyId=dcPolicyId,
                            COMMSERVER_REACHABLE=params["commserver_reachable"],
                        )

                    params["entities_attributes"] = ENTITIES_GLOBAL_CACHE[
                        params["eeCacheToken"]
                    ]
                except Exception:
                    # Any failure to obtain the entities is reported downstream
                    # as "Commserver not reachable".
                    is_commserver_reachable = False

            taskOutput["opType"] = self.task
            taskOutput["opTypeName"] = self.task
            taskOutput["dataList"] = {}
            taskOutput["dataList"]["optTypedata"] = []

            if "taskKey" in message:
                params["taskKey"] = message["taskKey"]
            if "contentid" not in params:
                params["contentid"] = ""

            # Flatten taskInput.dataList.optTypedata into a plain key -> value dict.
            dataList = {}
            if (
                "taskInput" in message
                and "dataList" in message["taskInput"]
                and "optTypedata" in message["taskInput"]["dataList"]
            ):
                for dataListEntry in message["taskInput"]["dataList"]["optTypedata"]:
                    if "attrKey" in dataListEntry and "attrValue" in dataListEntry:
                        dataList[dataListEntry["attrKey"]] = dataListEntry["attrValue"]

            text_parsing_tasks = ["RER", "SP_NER", "DE", "DOC_TAGGER", "EMAIL_TAGGER"]
            parent_file = None

            CURRENT_FILE_LIST[self.task] = "Not Provided"
            if "FileName" in dataList:
                params["FileName"] = dataList["FileName"]
                CURRENT_FILE_LIST[self.task] = dataList["FileName"]

            if "FilePathForProcessing" in dataList and self.task in text_parsing_tasks:
                try:
                    if is_linux() == False and len(dataList["FilePathForProcessing"]) > 250:
                        # appending prefix \\?\ and making it unicode to allow long path which are greater than 256 characters
                        # there can be performance issue with this so avoiding it for shorter paths
                        # http://mynthon.net/howto/-/python/python%20-%20playground%20-%20playing%20with%20long%20paths%20on%20windows.txt
                        long_processing_path = r"\\?" + "\\" + dataList["FilePathForProcessing"]
                    else:
                        long_processing_path = dataList["FilePathForProcessing"]

                    params["FilePathForProcessing"] = long_processing_path
                    processingInput = ""
                    # Plain "r": text mode already translates newlines, and the
                    # old "U" flag is rejected by current Python versions.
                    with open(long_processing_path.encode("utf-8"), "r", encoding="utf-8") as fp:
                        if self.task in ["SP_NER", "DE", "DOC_TAGGER", "EMAIL_TAGGER"]:
                            processingInput = cleanRawTextForNER(fp.read())
                        elif self.task == "RER":
                            processingInput = fp.read()
                    parent_file = dataList["FilePathForProcessing"]
                except Exception:
                    LOGGER_Generic.exception(
                        "Process was unable to open file {} for contentid {} for task {}".format(
                            dataList["FilePathForProcessing"], params["contentid"], self.task
                        ),
                        func_string,
                    )
                    processingInput = ""
            elif "content" in dataList:
                processingInput = dataList["content"]
            elif "localFilePath" in message:
                processingInput = message["localFilePath"]
            else:
                processingInput = ""
            file_size = len(processingInput)

            # Keep the (large) entity attributes out of the debug log.
            printable_params = params.copy()
            non_printable_keys = ["entities_attributes"]
            for npk in non_printable_keys:
                if npk in printable_params:
                    del printable_params[npk]

            LOGGER_Generic.debug("Params is {}".format(printable_params), func_string)
            # perf_counter replaces time.clock, which no longer exists in
            # current Python versions.
            startTime = time.perf_counter()
            if is_commserver_reachable and processingInput != "":
                results = extractor.doCA(self.task, processingInput, params)
            elif not is_commserver_reachable:
                results = {
                    "ErrorCode": CA_ERROR_CODES["CommserverNotReachable"],
                    "ErrorMessage": "Failed to get the entities from database. Please verify that Commserver is reachable.",
                }
            else:
                results = {
                    "ErrorCode": CA_ERROR_CODES["NoTextToProcess"],
                    "ErrorMessage": f"Missing search text json to process task {self.task}",
                }
            endTime = time.perf_counter()

            if self.profiling:
                if self.task not in PROFILER:
                    PROFILER[self.task] = {}
                    PROFILER[self.task]["numTasks"] = 0
                    PROFILER[self.task]["QWaitTime"] = []
                    PROFILER[self.task]["ProcessingTime"] = []
                    PROFILER[self.task]["FileSize"] = []
                    PROFILER[self.task]["mimeType"] = ""
                PROFILER[self.task]["numTasks"] += 1
                PROFILER[self.task]["ProcessingTime"].append(endTime - startTime)
                if self.task == "RER":
                    PROFILER[self.task]["FileSize"].append(file_size)
                    PROFILER[self.task]["mimeType"] = "text/plain"
                if timestamp is not None:
                    # Both values are epoch milliseconds; store seconds.
                    qWaitTime = (pickedUpTime - timestamp) / 1000
                    PROFILER[self.task]["QWaitTime"].append(qWaitTime)

            if results is not None:
                taskResponse = {}
                taskOutputForResultQueue = taskOutput.copy()

                taskResponse["attrKey"] = self.task + "Results"
                taskResponse["attrValue"] = ujson.dumps(results)

                if self.task in ["RER", "SP_NER", "DE", "DOC_TAGGER"]:

                    # This code is to only send those entities to the result queue
                    # which are selected by the user in the UI, but the others Should
                    # be sent to the next task (like DE) so that if there is any dependent
                    # entity to be captured, then it is not missed
                    if "ErrorCode" in results and results["ErrorCode"] in (
                        CA_ERROR_CODES["MemoryLimitReached"],
                        CA_ERROR_CODES[self.task + "TimedOut"],
                    ):
                        # Ask the polling loop to stop after this document.
                        kill_current_process = True
                    selectedResults = {}
                    keysToIgnore = ["ErrorCode", "ErrorMessage"]
                    for key in list(results.keys()):
                        if (
                            (
                                "entities_attributes" in params
                                and params["entities_attributes"]["entities_selected"].get(key)
                                == True
                            )
                            or key.startswith("dt_")
                            or key.startswith("doc_tag")
                            or key == "skipFields"
                        ):
                            selectedResults[key] = results[key]

                    # Error fields are always carried over (None when absent).
                    for key in keysToIgnore:
                        selectedResults[key] = results.get(key)

                    taskResponseForResultQueue = {}
                    taskResponseForResultQueue["attrKey"] = self.task + "Results"
                    taskResponseForResultQueue["attrValue"] = ujson.dumps(selectedResults)
                    # NOTE(review): taskOutput.copy() above is shallow, so this
                    # append is also visible through taskOutput -- confirm that
                    # is intended.
                    taskOutputForResultQueue["dataList"]["optTypedata"].append(
                        taskResponseForResultQueue
                    )

                else:
                    taskOutput["dataList"]["optTypedata"].append(taskResponse)
                    taskOutputForResultQueue = taskOutput.copy()

            contentid = ""
            if "contentid" in params:
                contentid = params["contentid"]
            task = self.task
            if task == "RER" and CA_TASK_MODULE_MAP[task] == "CvCABTRERClient":
                task = "Bitext Extraction"
            if (
                "ErrorCode" in results
                and results["ErrorCode"] != CA_ERROR_CODES["success"]
                and "ErrorMessage" in results
            ):
                LOGGER_Generic.exception(
                    "There was an error while performing task {} for contentid {} : {}".format(
                        task, contentid, results["ErrorMessage"]
                    ),
                    func_string,
                )
            else:
                LOGGER_Generic.debug(
                    "Finished task {} for contentid {}".format(task, contentid), func_string
                )

            if "taskChain" in message:
                taskChain = ujson.loads(message["taskChain"])
                # Value comparison, not identity: "is not" against a literal
                # only works by accident of string interning.
                if self.task in taskChain and taskChain[self.task] != "":
                    for q in taskChain[self.task]:
                        new_msg = message.copy()
                        new_msg["timestamp"] = int(time.time() * 1000)
                        if q == "RESULT":
                            new_msg["taskOutput"] = taskOutputForResultQueue
                        elif is_commserver_reachable:
                            if "ExtractorTempLocation" not in params:
                                LOGGER_Generic.error(
                                    "Unable to find ExtractorTempLocation for contentid {} for task {}. Not writing result to a file.".format(
                                        params["contentid"], self.task
                                    ),
                                    func_string,
                                )
                                new_msg["taskInput"] = taskOutput
                            else:
                                try:
                                    # Hand the full results to the next task via
                                    # a temp file instead of the message body.
                                    output_file = "CvAnalyzer_{}.tmp".format(time.time())
                                    if (
                                        is_linux() == False
                                        and len(params["ExtractorTempLocation"]) > 200
                                    ):
                                        # appending prefix \\?\ and making it unicode to allow long path which are greater than 256 characters
                                        # there can be performance issue with this so avoiding it for shorter paths
                                        # http://mynthon.net/howto/-/python/python%20-%20playground%20-%20playing%20with%20long%20paths%20on%20windows.txt
                                        long_processing_path = (
                                            r"\\?" + "\\" + params["ExtractorTempLocation"]
                                        )
                                    else:
                                        long_processing_path = params["ExtractorTempLocation"]
                                    output_file = os.path.join(long_processing_path, output_file)
                                    with open(output_file, "wb") as fp:
                                        fp.write(ujson.dumps(results).encode("utf-8"))

                                    # Point FilePathForProcessing at the new file,
                                    # replacing an existing entry or adding one.
                                    found = False
                                    if (
                                        "taskInput" in new_msg
                                        and "dataList" in new_msg["taskInput"]
                                        and "optTypedata" in new_msg["taskInput"]["dataList"]
                                    ):
                                        for reqParamItem in new_msg["taskInput"]["dataList"][
                                            "optTypedata"
                                        ]:
                                            if (
                                                "attrKey" in reqParamItem
                                                and reqParamItem["attrKey"]
                                                == "FilePathForProcessing"
                                            ):
                                                reqParamItem["attrValue"] = output_file
                                                found = True
                                                break
                                    if not found:
                                        if "taskInput" not in new_msg:
                                            new_msg["taskInput"] = {}
                                        if "dataList" not in new_msg["taskInput"]:
                                            new_msg["taskInput"]["dataList"] = {}
                                        if "optTypedata" not in new_msg["taskInput"]["dataList"]:
                                            new_msg["taskInput"]["dataList"]["optTypedata"] = []
                                        new_msg["taskInput"]["dataList"]["optTypedata"].append(
                                            {
                                                "attrKey": "FilePathForProcessing",
                                                "attrValue": output_file,
                                            }
                                        )
                                    if q == "DE":
                                        # DE also gets the original text: by file
                                        # path when it came from a file, otherwise
                                        # inline.
                                        if parent_file is not None:
                                            new_msg["reqParamMap"].append(
                                                {
                                                    "attrKey": "parent_tmp_file",
                                                    "attrValue": parent_file,
                                                }
                                            )
                                        else:
                                            new_msg["reqParamMap"].append(
                                                {"attrKey": "content", "attrValue": processingInput}
                                            )
                                except Exception:
                                    LOGGER_Generic.exception(
                                        "Unable to write results to file. Sending it in response object",
                                        func_string,
                                    )
                                    new_msg["taskOutput"] = taskOutput
                        # Sent request to face representation queue after getting face detection response
                        if self.task == "FACE_DETECTION":
                            frRequest = {}
                            from CvCAFaceDetect import getMongoConnStr

                            mongoConStr = getMongoConnStr()
                            if mongoConStr is not None:
                                # NOTE(review): b64encode needs a bytes-like
                                # value -- confirm what getMongoConnStr returns.
                                encodedStr = base64.b64encode(mongoConStr)
                                frRequest["mongoConnection"] = encodedStr
                            frRequest["text"] = results
                            self.sendOnConnect(
                                FRqueue,
                                ujson.dumps(frRequest),
                                client=self.clients[self.rQueue]["client"],
                                headers={
                                    "persistent": self.persistentEnabled,
                                    "correlation-id": new_msg["taskKey"],
                                    "sendTimeout": self.sendTimeout,
                                    "prefetchSize": self.producerPrefetchSize,
                                },
                            )
                            LOGGER_Generic.info(
                                "Face representation request sent to queue: {}".format(FRqueue),
                                func_string,
                            )
                        self.sendOnConnect(
                            q.encode(),
                            ujson.dumps(new_msg).encode(),
                            client=self.clients[self.rQueue]["client"],
                            headers={
                                "persistent": self.persistentEnabled,
                                "correlation-id": new_msg["taskKey"],
                                "sendTimeout": self.sendTimeout,
                                "prefetchSize": self.producerPrefetchSize,
                            },
                        )

            if self.task == "SP_NER":
                # Force a garbage collection after more than 1000 documents.
                self.count_docs += 1
                if self.count_docs > 1000:
                    gc.collect()
                    self.count_docs = 0

            del processingInput

            if kill_current_process is True:
                # Stops the polling loop in run() for this task.
                TASK_EVENT[self.task].set()

        except Exception as e:
            LOGGER_Generic.exception(f"Couldn't send response to ActiveMQ. Exception {e}", func_string)

    def testClass(self):
        """Manual test hook: runs the polling loop in the calling thread."""
        # Should print EE results for text passed as argument
        # self.handleDocument("Hello 473848738 293892")
        # Should publish 3 messages to Results Queue in ActiveMQ
        self.run()


class ProfilingClient(Thread, GenericMsgQueueCommunicator):
    """Thread that periodically publishes aggregated profiling data.

    Each round drains the module-level ``PROFILER`` dict through
    :meth:`profileTasks` and sends one JSON message per task to the
    ``PROFILING`` queue, then pauses ``PROFILE_SEND_INTERVAL`` seconds.
    """

    def __init__(self, stompPort=61650):
        """Initialise both base classes and the connection/send settings.

        Args:
            stompPort: port number stored on the instance as ``stompPort``.
        """
        Thread.__init__(self)
        GenericMsgQueueCommunicator.__init__(self)
        self.host = "127.0.0.1"
        self.stompPort = stompPort
        self.queue = "PROFILING"
        # Timeout for the stop-event check at the top of every round;
        # 0 means the check never blocks.
        self.ACTIVEMQ_WAIT_TIMEOUT = 0
        # Seconds to pause between two publishing rounds.
        self.PROFILE_SEND_INTERVAL = 30
        # Sent verbatim as the "sendTimeout" header.
        # NOTE(review): presumably milliseconds - confirm against the broker client.
        self.ACTIVEMQ_SEND_TIMEOUT = 900000
        self.ACTIVEMQ_PERSISTENT_ENABLED = "false"
        self.producerPrefetchSize = 1000
        self.consumerPrefetchSize = 1
        self.clients = {self.queue: {"subscribe": False, "client": None}}
        self._stopped = Event()

    def run(self):
        """Publish profiling metrics until the ``_stopped`` event is set.

        A failure while collecting or sending is logged, the connection is
        marked as closed and a reconnect is attempted. A failing reconnect is
        logged as well and retried after ``PROFILE_SEND_INTERVAL`` seconds
        instead of terminating the thread.
        """
        module_name = "CvCAGenericClient.ProfilingClient"
        func_name = "run"
        func_string = "{}::{}() - ".format(module_name, func_name)
        try:
            self.connectToQueue()
        except Exception as e:
            LOGGER_Generic.error(f"Error while connecting to ActiveMQ. Exception {e}", func_string)
        while not self._stopped.wait(self.ACTIVEMQ_WAIT_TIMEOUT):
            try:
                for entry in self.profileTasks():
                    taskMetrics = entry.copy()
                    taskMetrics["source"] = "Python"
                    taskMetrics["processId"] = str(PROCESS_ID)

                    self.sendOnConnect(
                        self.queue,
                        ujson.dumps(taskMetrics),
                        headers={
                            "persistent": self.ACTIVEMQ_PERSISTENT_ENABLED,
                            "sendTimeout": self.ACTIVEMQ_SEND_TIMEOUT,
                            "prefetchSize": self.producerPrefetchSize,
                        },
                    )
                    LOGGER_Generic.verbose(
                        "Message {} sent to Profiling queue: {}".format(
                            taskMetrics, self.queue
                        ),
                        func_string,
                    )

                # Waiting on the stop event (rather than time.sleep) lets a
                # stop request interrupt the pause between rounds.
                self._stopped.wait(self.PROFILE_SEND_INTERVAL)
            except Exception as e:
                LOGGER_Generic.error(f"Connection Closed. Exception {e}", func_string)
                self.connected = False
                try:
                    self.connectToQueue()
                except Exception as reconnect_error:
                    # An exception escaping from here would end the thread,
                    # so log it and back off before the next attempt.
                    LOGGER_Generic.error(
                        f"Error while connecting to ActiveMQ. Exception {reconnect_error}",
                        func_string,
                    )
                    self._stopped.wait(self.PROFILE_SEND_INTERVAL)
        LOGGER_Generic.info("Finished polling for analysis tasks. Exiting...", func_string)

    def profileTasks(self):
        """Drain ``PROFILER`` and return one aggregated dict per active task.

        For every task whose ``numTasks`` counter is positive, the first
        ``numTasks`` samples of each known metric list are removed from
        ``PROFILER`` and averaged, and the counter is reduced accordingly.

        Returns:
            list of dicts with ``opTypeName``, ``numTasks``, ``mimeType`` and,
            when samples were present, ``taskAverageProcessingTime``,
            ``averageFileSize`` and ``qAverageWaitTime``. Errors are logged
            and whatever was collected so far is returned.
        """
        module_name = "CvCAGenericClient"
        func_name = "profileTasks"
        func_string = "{}::{}() - ".format(module_name, func_name)

        response = []
        # PROFILER list name -> key used in the outgoing message.
        avgAttributes = {
            "ProcessingTime": "taskAverageProcessingTime",
            "FileSize": "averageFileSize",
            "QWaitTime": "qAverageWaitTime",
        }
        try:
            # Iterate over a snapshot of the keys: PROFILER is a shared
            # module-level dict, and adding a key from another thread while
            # iterating it directly would raise RuntimeError.
            for task in list(PROFILER):
                stats = PROFILER[task]
                # Read the counter once so the slice sizes and the final
                # decrement all use the same value.
                numTasks = stats["numTasks"]
                if numTasks <= 0:
                    continue
                taskObj = {
                    "opTypeName": task,
                    "numTasks": numTasks,
                    "mimeType": stats["mimeType"],
                }

                # Take the first n values from each list, then delete them
                # from the global list so nothing is reported twice.
                for key, val in avgAttributes.items():
                    if key in stats:
                        values = stats[key][:numTasks]
                        del stats[key][:numTasks]
                        if values:
                            taskObj[val] = sum(values) / float(len(values))

                response.append(taskObj)
                stats["numTasks"] -= numTasks
        except Exception as e:
            LOGGER_Generic.error(
                f"There was an error while sending profiled messages to ActiveMQ. Exception {e}",
                func_string,
            )
        return response


# Set by timedMemoryCheck() when its memory estimate exceeds MEMORY_LIMIT,
# and by kill_process().
MEMORY_CHECK_EVENT = Event()
# One Event per task name. The consumer code sets a task's event when its
# kill_current_process flag is true; checkAndRestartConsumers() clears it
# after restarting that task's consumer.
TASK_EVENT = { "RER" : Event(), "SP_NER" : Event(), "DOC_TAGGER": Event(), "DE": Event(), "EMAIL_TAGGER": Event()}
# Task name -> consumer thread, filled in by start_task_consumer().
TASK_CONSUMERS = {}
# Seconds timedMemoryCheck() waits before calling killProcessAndChildren().
KILL_TIMEOUT = 120
# Seconds between two timedMemoryCheck() runs.
MEMORY_TIME_CHECK = 30
# Threshold compared against the memory estimate; logged as GB.
MEMORY_LIMIT = 3.0
# Seconds between two timedGC() runs.
GC_TIMER = 120
# Seconds timed_recycle_process() waits before kill_process() fires (20 min).
# KILL_TIMEOUT, MEMORY_TIME_CHECK, MEMORY_LIMIT and GC_TIMER are overwritten
# from the configuration by get_default_config().
PROCESS_RESTART_TIME = 20 * 60

def timedMemoryCheck():
    """Check this process's memory and arm a kill timer when it is too high.

    While the estimate stays at or below ``MEMORY_LIMIT`` the function
    re-schedules itself every ``MEMORY_TIME_CHECK`` seconds. Once the limit is
    exceeded it sets ``MEMORY_CHECK_EVENT``, schedules
    ``killProcessAndChildren(pid)`` after ``KILL_TIMEOUT`` seconds and stops
    re-scheduling itself.
    """
    module_name = "CvCAGenericClient"
    func_name = "timedMemoryCheck"
    func_string = "{}::{}() - ".format(module_name, func_name)
    pid = os.getpid()
    # First field of psutil's memory_info() (RSS according to the psutil docs).
    used_bytes = psutil.Process(pid).memory_info()[0]
    # bytes / 2**40 * 1000 is roughly gigabytes (about bytes / 1.1e9);
    # a fixed 0.2 margin is added on top.
    memory = ((used_bytes / float(2 ** 40)) * 1000) + 0.2
    if memory > MEMORY_LIMIT:
        LOGGER_Generic.info(
            "Memory for process has exceeded {}GB. Setting kill signal for process. If processing is not finished by {} seconds, process will be killed. List of files being processed currently are: {}".format(
                MEMORY_LIMIT, KILL_TIMEOUT, CURRENT_FILE_LIST
            ),
            # Pass the function prefix like the other log calls in this module.
            func_string,
        )
        MEMORY_CHECK_EVENT.set()
        Timer(KILL_TIMEOUT, killProcessAndChildren, (pid,)).start()
        return
    Timer(MEMORY_TIME_CHECK, timedMemoryCheck).start()


def timedGC():
    """Run a garbage collection now and re-arm itself in ``GC_TIMER`` seconds."""
    gc.collect()
    next_run = Timer(GC_TIMER, timedGC)
    next_run.start()


def timed_recycle_process():
    """Schedule one ``kill_process`` call ``PROCESS_RESTART_TIME`` seconds from now."""
    recycle_timer = Timer(PROCESS_RESTART_TIME, kill_process)
    recycle_timer.start()


def kill_process():
    """Set the module-level ``MEMORY_CHECK_EVENT`` flag.

    NOTE(review): the code reacting to this event is outside this function;
    presumably it makes the process wind down - confirm with the consumers.
    """
    # Only a method call on the existing Event object, so no ``global``
    # declaration is required.
    MEMORY_CHECK_EVENT.set()


def timedEntitiesCachePrune(cacheCleanInterval):
    """Drop entity-cache entries that were not used recently, then re-arm.

    Every token whose timestamp in ``ENTITIES_GLOBAL_TIME_CACHE`` is older than
    ``cacheCleanInterval`` seconds is removed from both
    ``ENTITIES_GLOBAL_TIME_CACHE`` and ``ENTITIES_GLOBAL_CACHE``. The function
    then schedules itself to run again after ``cacheCleanInterval`` seconds.

    Args:
        cacheCleanInterval: maximum idle time of a cache entry and delay
            between two prune runs, in seconds.
    """
    module_name = "CvCAGenericClient"
    func_name = "timedEntitiesCachePrune"
    func_string = "{}::{}() - ".format(module_name, func_name)
    # Work on a snapshot so entries can be removed while looping.
    for cacheToken, last_used in list(ENTITIES_GLOBAL_TIME_CACHE.items()):
        if (time.time() - last_used) > cacheCleanInterval:
            LOGGER_Generic.debug(
                "Job token {} has not been used for a long time. Clearing cache for it".format(
                    cacheToken
                ),
                func_string,
            )
            # pop() with a default tolerates a token that is missing from one
            # of the two dicts. A KeyError here would skip the re-arm below
            # and stop pruning for the rest of the process lifetime.
            ENTITIES_GLOBAL_TIME_CACHE.pop(cacheToken, None)
            ENTITIES_GLOBAL_CACHE.pop(cacheToken, None)
    Timer(cacheCleanInterval, timedEntitiesCachePrune, (cacheCleanInterval,)).start()

@lru_cache()
def get_default_config():
    """Derive the runtime parameters from ``DEFAULT_CONFIG``.

    Reads ``conf[key]["value"]`` for every required setting, updates the
    module-level tuning globals (log rotation, memory limits, timers, entity
    extraction timeout), switches the RER task module to ``CvCABTRERClient``
    when ``use_btrer`` is enabled, and returns the values consumers need.

    The result is memoised by ``lru_cache``: the first call decides the value
    for the rest of the process, so ``DEFAULT_CONFIG`` must be set beforehand.

    Returns:
        dict with the keys ``conf``, ``stompPort``, ``profiling``,
        ``childDocs``, ``sendTimeout``, ``producerPrefetchSize``,
        ``consumerPrefetchSize``, ``persistentEnabled``, ``pre_process_text``,
        ``low_perf`` and ``cacheCleanInterval`` (seconds); ``None`` when
        ``DEFAULT_CONFIG`` is ``None``.

    Raises:
        KeyError: if a required setting is missing from the configuration.
    """
    global ROTATING_MAX_BYTES, ROTATING_BACKUP_COUNT
    global MEMORY_TIME_CHECK, KILL_TIMEOUT, MEMORY_LIMIT, GC_TIMER, ENTITY_EXTRACTION_TIMEOUT

    conf = DEFAULT_CONFIG
    if conf is None:
        return None

    def setting(key):
        # Every configuration entry is a dict carrying its payload in "value".
        return conf[key]["value"]

    # CLIENT_ID is deliberately not assigned here: no client id is in scope,
    # and a bare ``id`` would be the builtin function.
    stompPort = setting("stompPort")
    ROTATING_MAX_BYTES = setting("log_max_bytes")
    ROTATING_BACKUP_COUNT = setting("log_backup_count")
    profiling = setting("profiling")
    pre_process_text = setting("pre_process_text")
    low_perf = setting("low_perf")
    MEMORY_LIMIT = setting("generic_memory_limit")
    MEMORY_TIME_CHECK = setting("memory_time_check")
    KILL_TIMEOUT = setting("generic_task_kill_timeout")
    GC_TIMER = setting("generic_task_gc_timer")
    childDocs = setting("child_docs")
    # The configured interval is multiplied by 60 before being used as a
    # number of seconds.
    cacheCleanInterval = int(setting("cache_clean_interval")) * 60
    sendTimeout = setting("activemq_send_timeout")
    persistentEnabled = setting("activemq_persistent_enabled")
    producerPrefetchSize = setting("activemq_producer_prefetch_size")
    consumerPrefetchSize = setting("activemq_consumer_prefetch_size")
    ENTITY_EXTRACTION_TIMEOUT = setting("entity_extraction_timeout")
    if setting("use_btrer"):
        CA_TASK_MODULE_MAP["RER"] = "CvCABTRERClient"

    return {
        "conf": conf,
        "stompPort": stompPort,
        "profiling": profiling,
        "childDocs": childDocs,
        "sendTimeout": sendTimeout,
        "producerPrefetchSize": producerPrefetchSize,
        "consumerPrefetchSize": consumerPrefetchSize,
        "persistentEnabled": persistentEnabled,
        "pre_process_text": pre_process_text,
        "low_perf": low_perf,
        "cacheCleanInterval": cacheCleanInterval,
    }
    
def checkAndRestartConsumers():
    """Restart every registered consumer thread that is no longer alive.

    Walks ``TASK_CONSUMERS``; for each task whose thread is ``None`` or has
    stopped, a new consumer is started through ``start_task_consumer`` and the
    task's entry in ``TASK_EVENT`` (if it has one) is cleared. The check then
    re-schedules itself to run again in 20 seconds.
    """
    # Iterate over a snapshot because start_task_consumer() writes to
    # TASK_CONSUMERS while this loop is running.
    for task, thread_ in list(TASK_CONSUMERS.items()):
        if thread_ is None or not thread_.is_alive():
            LOGGER_Generic.error(f"Starting consumer task {task}")
            start_task_consumer(task)
            # Not every task has an event in TASK_EVENT. Indexing directly
            # would raise KeyError and end this watchdog's timer chain.
            task_event = TASK_EVENT.get(task)
            if task_event is not None:
                task_event.clear()
    Timer(20, checkAndRestartConsumers).start()

def doProcessing(
    cmdQ=None,
    conf=None,
    id=1,
    parent_pid=0,
    is_dummy_process=False,
    shared_tasks=None,
    commserver_reachable=None,
):
    """Set up this worker process and start its task consumers.

    Publishes ``conf`` as ``DEFAULT_CONFIG``, loads the derived parameters,
    initialises the logger, starts the periodic GC and entity-cache pruning
    timers, optionally starts the profiling thread, starts one consumer per
    task in ``CA_TASK_LIST`` that is also in ``shared_tasks``, and then blocks
    on ``cmdQ`` until a ``"stop"`` command arrives.

    Args:
        cmdQ: object with a blocking ``get()``; when ``None`` the function
            returns right after the consumers have been started.
        conf: configuration dict of ``{"value": ...}`` entries. It is mutated:
            ``commserver_reachable`` and ``is_dummy_process`` are stored in it.
        id: client identifier, stored in the module-level ``CLIENT_ID``.
        parent_pid: forwarded to ``checkParentAndKill`` together with this
            process's pid.
        is_dummy_process: stored in ``conf`` under ``"is_dummy_process"``.
        shared_tasks: container of task names this process should serve.
        commserver_reachable: stored in ``conf`` under
            ``"commserver_reachable"``.
    """
    global DEFAULT_CONFIG, LOGGER_Generic, CLIENT_ID
    module_name = "CvCAGenericClient"
    func_name = "doProcessing"
    func_string = "{}::{}() - ".format(module_name, func_name)
    DEFAULT_CONFIG = conf
    DEFAULT_CONFIG["commserver_reachable"] = commserver_reachable
    DEFAULT_CONFIG["is_dummy_process"] = is_dummy_process
    params = get_default_config()
    # Record the id passed by the caller. This is assigned after the
    # configuration has been loaded so nothing done during loading can
    # overwrite it.
    CLIENT_ID = id
    try:
        logger_options = {
            "ROTATING_BACKUP_COUNT": ROTATING_BACKUP_COUNT,
            "ROTATING_MAX_BYTES": ROTATING_MAX_BYTES,
        }

        LOGGER_Generic = get_logger_handler(
            os.path.join(getBaseDir(), HELPER_DLL), "ContentAnalyzer", logger_options
        )

    except Exception as e:
        print("Error while initialising: {}".format(e))

    checkParentAndKill(parent_pid, PROCESS_ID)
    # timedMemoryCheck()
    # no need to do memory check from here as we have moved it to NER process itself
    # where process will send the proper status to the result queue first and then will kill itself
    # timed_recycle_process()
    timedGC()
    timedEntitiesCachePrune(params["cacheCleanInterval"])

    if params["profiling"]:
        profiler = ProfilingClient(params["stompPort"])
        profiler.ACTIVEMQ_SEND_TIMEOUT = params["sendTimeout"]
        profiler.ACTIVEMQ_PERSISTENT_ENABLED = "true" if params["persistentEnabled"] else "false"
        profiler.start()

    try:
        for task in CA_TASK_LIST:
            if task in shared_tasks:
                start_task_consumer(task)
        checkAndRestartConsumers()
        # Without a command queue there is nothing to wait for.
        while cmdQ is not None:
            command = cmdQ.get()
            if command == "stop":
                LOGGER_Generic.info("Stop command received. Exiting.", func_string)
                break
    except Exception as e:
        LOGGER_Generic.error(f"Could not start reading ActiveMQ. Exception {e}", func_string)

# Optional settings copied into a consumer's specialArgs as
# (configuration key, specialArgs key) pairs.
_RER_OPTIONAL_ARGS = (
    ("use_re2", "use_re2"),
    ("deep_validate", "deep_validate"),
    ("extract_all_date_formats", "extract_all_date_formats"),
    ("phone_entity_leniency", "phone_entity_leniency"),
)
_SP_NER_OPTIONAL_ARGS = (
    ("sp_ner_pipe", "pipe"),
    ("sp_ner_size_limit", "sp_ner_size_limit"),
    ("sp_ner_pipe_min_size", "sp_ner_pipe_min_size"),
    ("sp_ner_use_nltk_tokenizer", "sp_ner_use_nltk_tokenizer"),
    ("sp_ner_usecustompipeline", "sp_ner_usecustompipeline"),
    ("sp_ner_custompipeline", "sp_ner_custompipeline"),
    ("sp_ner_probcutoff", "sp_ner_probcutoff"),
    ("use_document_classifier", "use_document_classifier"),
    ("sp_ner_structure_documents_handling", "sp_ner_structure_documents_handling"),
)


def _copy_optional_conf_values(conf, special_args, key_pairs):
    """Copy ``conf[key]["value"]`` into ``special_args`` for each key present in ``conf``."""
    for conf_key, arg_key in key_pairs:
        if conf_key in conf:
            special_args[arg_key] = conf[conf_key]["value"]


def start_task_consumer(task):
    """Create, configure and start the queue consumer thread for ``task``.

    The new ``ActiveMQReader`` is registered in ``TASK_CONSUMERS`` and the
    task's entry in ``CURRENT_FILE_LIST`` is reset before the thread is
    configured from the cached parameters and the raw configuration. Any
    exception is logged and swallowed, so a failed start leaves a
    never-started thread registered for the task.

    Args:
        task: name of the task to consume (for example ``"RER"`` or
            ``"SP_NER"``).
    """
    try:
        params = get_default_config()
        conf = params["conf"]

        CURRENT_FILE_LIST[task] = ""
        TASK_CONSUMERS[task] = task_thread = ActiveMQReader(
            task, params["stompPort"], profiling=params["profiling"]
        )
        special_args = task_thread.specialArgs
        special_args["childDocs"] = params["childDocs"]
        special_args["commserver_reachable"] = conf["commserver_reachable"]
        task_thread.sendTimeout = params["sendTimeout"]
        task_thread.producerPrefetchSize = params["producerPrefetchSize"]
        task_thread.prefetchSize = str(params["consumerPrefetchSize"])
        task_thread.persistentEnabled = "true" if params["persistentEnabled"] else "false"
        if "generic_perf_stat_print_interval" in conf:
            task_thread.perfPrintInterval = conf["generic_perf_stat_print_interval"]["value"]
        if "generic_perf_stat" in conf and conf["generic_perf_stat"]["value"]:
            task_thread.startPerfStatPrinting()

        # Task names are compared with ``==``: ``is`` tests object identity,
        # which only matches string literals when the interpreter happens to
        # intern them.
        if task == "OCR" and "ocr_threshold" in conf:
            special_args["OCRThreshold"] = conf["ocr_threshold"]["value"]
        if task == "RER":
            special_args["pre_process_text"] = params["pre_process_text"]
            special_args["low_perf"] = params["low_perf"]
            _copy_optional_conf_values(conf, special_args, _RER_OPTIONAL_ARGS)
        if task in ("RER", "DE") and "proximity_conf" in conf:
            # The proximity configuration is stored as a JSON string.
            special_args["proximity_conf"] = ujson.loads(conf["proximity_conf"]["value"])
        if task == "DOC_TAGGER":
            start_subtask_processes(logger=LOGGER_Generic)
        if task == "SP_NER":
            special_args["is_dummy_process"] = conf["is_dummy_process"]
            _copy_optional_conf_values(conf, special_args, _SP_NER_OPTIONAL_ARGS)
        task_thread.start()
    except Exception as e:
        LOGGER_Generic.error(f"Failed to start task {task}. Exception {e}")

if __name__ == "__main__":
    # Stand-alone launch: start only the EMAIL_TAGGER consumer with REG_CONF
    # from CvEEConfigHelper. No command queue is passed, so doProcessing()
    # returns once the consumer has been started.
    from multiprocessing import Value

    from CvEEConfigHelper import REG_CONF

    commserver_reachable = Value("i", 1)
    doProcessing(
        commserver_reachable=commserver_reachable,
        shared_tasks=["EMAIL_TAGGER"],
        conf=REG_CONF,
    )
