import os
import queue
import threading

from multiprocessing import Queue, Process
from pathlib import Path

from CvEEConfigHelper import (
    CA_SUBTASK_MODULE_MAP,
    CA_ERROR_CODES,
    checkParentAndKill,
    delete_dict_values,
    loadRegValue,
    update_training_status,
)

# pid of this (parent) process; handed to every child as "parent_pid"
PROCESS_ID = os.getpid()

# key into CA_SUBTASK_MODULE_MAP that selects the sub-task -> module table
PARENT_TASK = "DOC_TAGGER"
# seconds between runs of the checkAndRestartChildProcesses() watchdog timer
GENERIC_CLIENTS_CHECK_TIMER = 1 * 60
# prefix used to build the "module::function() - " log context strings
module_name = "cvee_classification_task"
# per-sub-task multiprocessing Queues (parent -> child), indexed by process index;
# allocated in start_subtask_processes()
child_queues = None
# process index -> sub-task name, and the reverse lookup; filled by spawn_child_process()
process_index_task_mapping = {}
process_task_index_mapping = {}
# multiprocessing.Process objects, indexed the same way as child_queues
sub_processes = None
# single queue handed to every child as "parent_queue"; preProcess()/doAnalysis()
# block on it waiting for responses
parent_queue = Queue()
# logger injected by start_subtask_processes(); called e.g. as info(message, func_str)
LOGGER_Generic = None
# entity id strings recorded by cancel_training_opertaion()
training_cancelled = set()
# entityId of the most recent TRAIN request dispatched by doAnalysis()
last_train_entity = None


def spawn_child_process(process_index, sub_task=None):
    """
        Start (or restart) the child process that handles ``sub_task``.

        A fresh multiprocessing Queue is created for the child and stored in
        ``child_queues[process_index]``. The child's ``doProcessing`` entry point
        receives that queue, the shared ``parent_queue``, a 1-based ``id`` and
        this process's pid. Both index <-> task mappings are updated, and every
        sub task except TRAIN is started as a daemon process.

        :param process_index: slot in ``child_queues`` / ``sub_processes``
        :param sub_task: key into ``CA_SUBTASK_MODULE_MAP[PARENT_TASK]``;
                         when None, nothing is started
        :raises Exception: any failure is logged and then re-raised
    """
    import importlib

    func_name = "spawn_child_process"
    func_str = f"{module_name}::{func_name}() - "

    if sub_task is None:
        return
    try:
        # import_module returns the named (leaf) module even for dotted names,
        # whereas __import__("pkg.mod") would return the top-level package "pkg".
        module = importlib.import_module(CA_SUBTASK_MODULE_MAP[PARENT_TASK][sub_task])
        child_queues[process_index] = Queue()
        params = {
            "child_queue": child_queues[process_index],
            "parent_queue": parent_queue,
            "id": process_index + 1,
            "parent_pid": PROCESS_ID,
        }
        process_index_task_mapping[process_index] = sub_task
        process_task_index_mapping[sub_task] = process_index
        sub_processes[process_index] = Process(target=module.doProcessing, kwargs=params)
        # TRAIN is deliberately left non-daemonic; all other sub tasks are daemons
        if sub_task != "TRAIN":
            sub_processes[process_index].daemon = True
        sub_processes[process_index].start()
        LOGGER_Generic.info(
            f"Started sub task {sub_task} with process id {sub_processes[process_index].pid}",
            func_str,
        )
    except Exception as e:
        LOGGER_Generic.exception(f"Exception occured. {e}", func_str)
        raise




def start_subtask_processes(logger=None):
    """
        Spawn the TRAIN and TEST child processes for the DOC_TAGGER task.

        Only meaningful for DOC_TAGGER: the calling process acts as the parent that
        receives messages from the queue, while the spawned children do the actual
        processing. After spawning, checkParentAndKill() is called with the parent
        and own pids (presumably a parent-liveness watcher - confirm), and the
        child watchdog checkAndRestartChildProcesses() is started.

        :param logger: stored in ``LOGGER_Generic``; must support
                       ``info(message, func_str)`` / ``exception(message, func_str)``
    """
    global child_queues, sub_processes, LOGGER_Generic

    func_str = "{}::{}() - ".format(module_name, "start_subtask_processes")

    LOGGER_Generic = logger
    LOGGER_Generic.info("Starting sub tasks for DOC_TAGGER", func_str)

    sub_tasks = ("TRAIN", "TEST")
    child_queues = [None for _ in sub_tasks]
    sub_processes = [None for _ in sub_tasks]

    try:
        for slot, task_name in enumerate(sub_tasks):
            spawn_child_process(slot, task_name)
        checkParentAndKill(os.getppid(), os.getpid())
        checkAndRestartChildProcesses()
        # start mlflow server in an additional process
        # start_mlflow_server()
    except Exception as err:
        LOGGER_Generic.exception(
            f"Failed to spawn child processes for DOC_TAGGER task. Exception {err}", func_str
        )


def checkAndRestartChildProcesses():
    """
        Watchdog: respawn every child process slot that is empty or no longer alive,
        then schedule itself to run again after GENERIC_CLIENTS_CHECK_TIMER seconds.

        Respawn failures are logged per slot instead of raised. This keeps one
        broken sub task from blocking the others, and always reschedules the
        timer. When this runs on a threading.Timer thread, an escaping exception
        would only be printed and the periodic check would stop for good.
        Does nothing (and does not reschedule) if the child list was never allocated.
    """
    func_name = "checkAndRestartChildProcesses"
    func_str = f"{module_name}::{func_name}() - "

    if sub_processes is None:
        return

    try:
        for process_index, sub_process in enumerate(sub_processes):
            # A slot is None when its initial spawn failed before the Process
            # object was created; treat it exactly like a dead process.
            if sub_process is not None and sub_process.is_alive():
                continue
            # .get(): the index -> task mapping is only filled once a spawn got
            # past the module import; spawn_child_process() ignores a None task.
            sub_task = process_index_task_mapping.get(process_index)
            try:
                spawn_child_process(process_index, sub_task)
            except Exception as e:
                # spawn_child_process() already logged the traceback
                LOGGER_Generic.error(
                    f"Failed to restart sub task {sub_task}. Exception {e}", func_str
                )
    finally:
        threading.Timer(GENERIC_CLIENTS_CHECK_TIMER, checkAndRestartChildProcesses).start()


def preProcess(params=None):
    """
        Forward a preProcess request to the child process owning ``params["sub_task"]``
        and block until a response arrives on ``parent_queue``.

        :param params: request dict. ``sub_task`` selects the child. ``logger`` and
                       ``perfCounter`` are left out of the queued copy (they are not
                       serializable); the caller's dict is not modified.
        :returns: the child's response; ``{"ErrorCode": success}`` when ``params`` is
                  None/empty, has no (or a blank) ``sub_task``, or the sub task is
                  CANCEL_TRAINING; None if an exception was raised and logged.
    """
    func_name = "preProcess"
    func_str = f"{module_name}::{func_name}() - "

    try:
        default_response = {"ErrorCode": CA_ERROR_CODES["success"]}
        if not params or "sub_task" not in params:
            return default_response
        sub_task = params["sub_task"]
        if sub_task.upper() == "CANCEL_TRAINING" or sub_task.strip() == "":
            return default_response

        process_index = process_task_index_mapping[sub_task]
        # Build a filtered copy instead of deleting from the caller's dict.
        queue_params = {
            key: value
            for key, value in params.items()
            if key not in ("logger", "perfCounter")
        }
        queue_command = {"operation": "preProcess", "params": queue_params}
        LOGGER_Generic.debug(
            f"Sending queue command {queue_command['operation']} for sub task {sub_task}",
            func_str,
        )
        child_queues[process_index].put(queue_command)

        # NOTE(review): unlike doAnalysis() there is no timeout here, so this blocks
        # indefinitely if the child never replies - confirm this is intended.
        return parent_queue.get(block=True)
    except Exception as e:
        LOGGER_Generic.exception(f"Exception occured {e}", func_str)


def cancel_training_opertaion(params):
    """
    Handle a CANCEL_TRAINING request for ``params["entityId"]``.

    - If the entity is the one most recently dispatched to the TRAIN child
      (``last_train_entity``), the TRAIN process is killed. It is not restarted
      here; checkAndRestartChildProcesses() respawns dead children on its next run.
      ``last_train_entity`` is cleared so that a repeated cancel for the same entity
      does not kill the respawned process again.
    - Else, if the entity is already in ``training_cancelled``, it is removed.
    - Otherwise, it is added to ``training_cancelled``.

    In each of these cases the entity's training status is then set to
    TRAINING_COMPLETED_WITH_ERROR with the TrainCancelByUser error code.

    NOTE(review): nothing in this module reads ``training_cancelled`` before
    dispatching TRAIN requests - confirm where the cancelled set is enforced.

    :param params: request dict that must contain ``entityId``
    :returns: ``{"ErrorCode": ..., "ErrorMessage": ...}``; ErrorCode is
              ``CancelTrainingFailed`` when entityId is missing or any step fails.
    """
    global last_train_entity

    from cvee_get_entities import ModelTrainingStatus

    response = {"ErrorCode": CA_ERROR_CODES["success"], "ErrorMessage": ""}
    try:
        entity_id_str = str(params["entityId"])
    except (KeyError, TypeError):
        # KeyError: no entityId; TypeError: params is None / not a mapping
        error_message = "Entity Id is missing in message request for cancel training."
        response["ErrorCode"] = CA_ERROR_CODES["CancelTrainingFailed"]
        response["ErrorMessage"] = error_message
        LOGGER_Generic.error(error_message)
        return response
    try:
        # Guard against the initial None matching an entityId of "None".
        if last_train_entity is not None and entity_id_str == str(last_train_entity):
            sub_processes[process_task_index_mapping["TRAIN"]].kill()
            last_train_entity = None
        elif entity_id_str in training_cancelled:
            training_cancelled.remove(entity_id_str)
        else:
            training_cancelled.add(entity_id_str)
        additional_attributes = {
            "errLogMessage": "Classifier training was cancelled by user.",
            "errorCode": CA_ERROR_CODES["TrainCancelByUser"],
        }
        update_training_status(
            entity_id_str,
            "",
            training_status=ModelTrainingStatus.TRAINING_COMPLETED_WITH_ERROR.value,
            additional_attributes=additional_attributes,
        )
        LOGGER_Generic.info(f"Training cancelled for entity {entity_id_str}.")
        return response
    except Exception as e:
        error_message = f"Cancel training failed for entity [{params['entityId']}]."
        LOGGER_Generic.exception(f"{error_message}. Exception {e}")
        response["ErrorCode"] = CA_ERROR_CODES["CancelTrainingFailed"]
        response["ErrorMessage"] = error_message
        return response


def doAnalysis(processing_input, params=None):
    """
        Dispatch a doAnalysis request to the child process owning ``params["sub_task"]``.

        - CANCEL_TRAINING is handled locally by cancel_training_opertaion().
        - TRAIN is fire-and-forget: the command is queued, ``last_train_entity`` is
          recorded (so a later cancel can find it) and a success response is
          returned without waiting.
        - Any other sub task waits up to ``response_timeout`` seconds for a reply
          on ``parent_queue``.

        :param processing_input: unused here; kept for interface compatibility
        :param params: request dict; None is treated as an empty dict. ``logger``,
                       ``perfCounter``, ``commserver_reachable`` and
                       ``entities_attributes`` are stripped from the queued copy.
        :returns: the child's response, or
                  ``{"ErrorCode": success, "ErrorMessage": None}``; None if an
                  exception was raised and logged.
    """
    global last_train_entity

    func_name = "doAnalysis"
    func_str = f"{module_name}::{func_name}() - "

    # None instead of a shared mutable {} default.
    if params is None:
        params = {}

    # Seconds to wait for a non-TRAIN reply: 1000 s is about 16.7 minutes (an
    # earlier inline comment called this "10 minutes") - confirm the intended value.
    response_timeout = 1000

    try:
        default_response = {"ErrorCode": CA_ERROR_CODES["success"], "ErrorMessage": None}
        if "sub_task" not in params or params["sub_task"].strip() == "":
            return default_response

        sub_task = params["sub_task"]
        if sub_task.upper() == "CANCEL_TRAINING":
            return cancel_training_opertaion(params)

        process_index = process_task_index_mapping[sub_task]

        """
            logger is having a cpp dll reference that we can't serialize to send over the queue
            perfCounter is having logger instance in its class
            commserver_reachable is an synchronized event value that is not serializable
            entities_attributes is really large with all the regex of the entities, avoiding it for the queue
        """
        params_copy = params.copy()
        delete_dict_values(
            params_copy, ["logger", "perfCounter", "commserver_reachable", "entities_attributes"]
        )
        queue_command = {"operation": "doAnalysis", "params": params_copy}

        LOGGER_Generic.debug(
            f"Sending queue command {queue_command['operation']} for sub task {sub_task}",
            func_str,
        )
        child_queues[process_index].put(queue_command)

        if sub_task == "TRAIN":
            # .get(): the command is already dispatched, so a missing entityId
            # must not turn this into a reported failure.
            last_train_entity = params.get("entityId")
            return default_response

        try:
            return parent_queue.get(block=True, timeout=response_timeout)
        except queue.Empty:
            LOGGER_Generic.error(
                f"Timed out waiting for sub task {sub_task} response", func_str
            )
            # NOTE(review): a timeout is reported with the success ErrorCode, and a
            # late reply stays on parent_queue for the next caller - confirm intended.
            return default_response
    except AttributeError as e:
        LOGGER_Generic.exception(f"Failed to pickle {e}.", func_str)
    except Exception as e:
        LOGGER_Generic.exception(f"Exception occured {e}.", func_str)
