import os
import importlib.util
import pickle

from pathlib import Path
from urllib.request import url2pathname
from zipfile import ZipFile

import pandas as pd
import ujson

from mlflow.entities import ViewType
from mlflow.tracking.fluent import search_runs
from mlflow import (
    get_experiment_by_name,
    create_experiment,
    set_tracking_uri,
    start_run,
    get_artifact_uri,
)

from cvee_extract_features import get_features
from cvee_preprocess_data import DataPreprocessor
from CvCAPerfCounters import PerformanceCounter
from cvee_get_entities import get_entity_wrapper, ModelTrainingStatus
from CvEEConfigHelper import (
    CA_ERROR_CODES,
    loadRegValue,
    checkParentAndKill,
    get_dll_logger,
    getPythonSitePackagesDir,
    cleanRawTextForNER,
    load_model_config,
    generic_cache,
    ModelNotTrained,
    ModelDoesNotExist,
    EntityDoesNotExist,
)

# Module identifier; used as the prefix of log messages built in doProcessing().
module_name = "cvee_one_class_classifier"
# Only referenced by the commented-out load_models() in this file.
MODELS = None
PREPROCESSORS = None
# Sub-directory of ../ContentAnalyzer/bin/classifier_models that holds the
# bundled models of system-defined entities (see get_default_model_location).
SYSTEM_MODELS_DIR = "cv_doc_tag_models"

# Module-wide logger from the project helper. NOTE: this name shadows the
# stdlib ``logging`` module name throughout this file.
logging = get_dll_logger()


def get_default_model_location(entity_key):
    """Return the bundled model directory of a system-defined entity.

    The directory is resolved relative to the current working directory as
    ``../ContentAnalyzer/bin/classifier_models/<SYSTEM_MODELS_DIR>/<entity_key>/model``.

    Args:
        entity_key: key of the entity whose model directory is wanted.

    Returns:
        str: the model directory path.

    Raises:
        Exception: if the directory does not exist.
    """
    candidate = Path(
        "../ContentAnalyzer/bin/classifier_models", SYSTEM_MODELS_DIR, entity_key, "model"
    )
    if not candidate.exists():
        raise Exception(
            f"Model location {str(candidate)} does not exists for entity [{entity_key}]"
        )
    return str(candidate)


# def load_models():
#     global MODELS, PREPROCESSORS
#     try:
#         if MODELS is None:
#             classifier_models_dir = "../ContentAnalyzer/bin/classifier_models"
#             models_dir_path = Path(classifier_models_dir) / Path(MODELS_DIR)
#             MODELS = []
#             TAGS = loadRegValue("sDefaultCategories", "finance,legal,technical", type=str)
#             TAGS = [tag_.strip() for tag_ in TAGS.strip(",").split(",")]
#             for tag in TAGS:
#                 tag_model_path = models_dir_path / Path(tag + "/model")
#                 model = None
#                 transformer = None
#                 with open(tag_model_path / Path("model.pkl"), "rb") as model_file:
#                     model = pickle.load(model_file)
#                 with open(tag_model_path / Path("transformer.pkl"), "rb") as transformer_file:
#                     transformer = pickle.load(transformer_file)
#                 MODELS.append((tag, model, transformer))
#             PREPROCESSORS = DataPreprocessor()
#         # return MODELS, PREPROCESSORS
#     except Exception:
#         raise


def get_text(processing_input, params):
    """Build the text to classify from a JSON-serialised document.

    ``processing_input`` is parsed with ``ujson``. For every entry of
    ``dataList.optTypedata`` that has both ``attrKey`` and ``attrValue`` and
    whose ``attrKey`` is one of the requested extraction fields, the value is
    appended on its own line. Fields other than ``content`` are preceded by a
    line holding the field name.

    Args:
        processing_input: JSON text of the document.
        params: request parameters. ``params["EntityExtractionFields"]`` may be
            a comma-separated string or a list and defaults to ``["content"]``.
            NOTE: ``params`` is updated in place so that this key holds a list
            afterwards (string values are split and whitespace-stripped).

    Returns:
        str: the collected text with surrounding whitespace removed; ``""``
        when no field matched.

    Raises:
        Exception: wrapping any parsing or lookup failure (original chained).
    """
    try:
        document = ujson.loads(processing_input)

        if "EntityExtractionFields" not in params:
            params["EntityExtractionFields"] = ["content"]
        elif isinstance(params["EntityExtractionFields"], str):
            # Strip each name so "content, title" also matches "title".
            params["EntityExtractionFields"] = [
                field.strip()
                for field in params["EntityExtractionFields"].strip(",").split(",")
            ]
        wanted_fields = params["EntityExtractionFields"]

        pieces = []
        if "dataList" in document and "optTypedata" in document["dataList"]:
            for item in document["dataList"]["optTypedata"]:
                if "attrKey" not in item or "attrValue" not in item:
                    continue
                key = item["attrKey"]
                if key not in wanted_fields:
                    continue
                if key != "content":
                    pieces.append(f"{key}{os.linesep}")
                pieces.append(f"{item['attrValue']}{os.linesep}")
        # join() avoids the quadratic cost of repeated string concatenation.
        return "".join(pieces).strip()
    except Exception as e:
        raise Exception(f"Failed to get the text for categorization. Exception {e}") from e


def is_trained(entityDetails):
    """Tell whether the entity's model can be used for prediction.

    System-defined entities (``isSystemDefinedEntity == 1``) always count as
    trained. Otherwise the classifier's ``trainingStatus`` must equal
    ``ModelTrainingStatus.TRAINING_COMPLETED``. Any error while reading or
    converting these attributes yields ``False``.
    """
    try:
        entity_xml = entityDetails.entityXML
        if int(entity_xml.isSystemDefinedEntity) == 1:
            return True
        status = int(entity_xml.classifierDetails.trainingStatus)
        return status == int(ModelTrainingStatus.TRAINING_COMPLETED.value)
    except Exception:
        return False


def create_new_run(experiment_id, entityDetails, params):
    """Unpack a synced model archive into the artifact directory of a new MLflow run.

    The archive path is read from ``params["download_location_<entityKey>"]``.
    A new run is started under ``experiment_id``, the archive is extracted into
    that run's artifact location, and the archive file is then deleted on a
    best-effort basis.

    Args:
        experiment_id: MLflow experiment to create the run in.
        entityDetails: entity wrapper; ``entityKey`` and ``entityName`` are read.
        params: request parameters holding the download location.

    Returns:
        pathlib.Path: the run's artifact directory, relative to the current
        working directory.

    Raises:
        ModelDoesNotExist: if the download location is missing from ``params``
            or does not exist on disk.
        ValueError: if the artifact directory is not under the current working
            directory (raised by ``Path.relative_to``).
        Exception: anything raised by MLflow or by the zip extraction.
    """
    location_key = f"download_location_{entityDetails.entityKey}"
    if location_key not in params:
        raise ModelDoesNotExist(
            f"Sync download path was not available for entity [{entityDetails.entityName}]."
        )
    download_path = Path(params[location_key])
    if not download_path.exists():
        raise ModelDoesNotExist(
            f"Sync download path does not exists for entity [{entityDetails.entityName}]. Download path {str(download_path)}"
        )
    with start_run(experiment_id=experiment_id):
        # [7:] drops a leading "file://" scheme; the tracking URI is set to a
        # "file://" URI in get_model_location.
        model_location = Path(url2pathname(get_artifact_uri()[7:])).relative_to(Path.cwd())
        with ZipFile(download_path, "r") as zip_:
            zip_.extractall(path=str(model_location))
        try:
            # TODO: in case of multiple DOC_TAGGER test process, this will be an issue
            download_path.unlink()
        except OSError as e:
            # Best-effort cleanup: the model is already extracted, so failing to
            # delete the archive (e.g. already removed) must not fail the request.
            logging.debug(
                f"Could not delete downloaded model archive {download_path}. Exception {e}"
            )
        return model_location


def get_model_location(entityDetails, params):
    """Return the directory holding the trained model of an entity.

    Resolution order:

    1. System-defined entities use the bundled model returned by
       ``get_default_model_location``.
    2. If the classifier's ``modelURI`` does not contain ``"mlruns"``, it is
       treated as a plain local path and returned when it exists.
    3. Otherwise the local MLflow store at
       ``<cwd>/../ContentAnalyzer/bin/classifier_models/custom_trained_models/mlruns``
       is used. The experiment ``<entityKey>_<entityId>`` is looked up (created
       if missing) and its run with the latest end time is inspected:

       * it is reused when it FINISHED and its end time plus a 10 second
         margin is not older than ``CAUsedInTraining.lastModelTrainTime``;
       * otherwise ``create_new_run`` extracts the synced model archive into a
         new run.

    Args:
        entityDetails: entity wrapper (``entityXML``, ``entityKey``,
            ``entityId``, ``entityName``).
        params: request parameters, forwarded to ``create_new_run``.

    Returns:
        ``str`` for bundled, plain-local or reused-run models; the
        ``pathlib.Path`` returned by ``create_new_run`` when a new run is made.

    Raises:
        ModelDoesNotExist: on any failure; the original exception is logged
            and chained as ``__cause__``.
    """
    try:
        entity_xml = entityDetails.entityXML
        logging.debug(
            f"Entity [{entityDetails.entityKey}] isSystemDefinedEntity {entity_xml.isSystemDefinedEntity}"
        )
        if int(entity_xml.isSystemDefinedEntity) == 1:
            return get_default_model_location(entityDetails.entityKey)

        model_location = entity_xml.classifierDetails.modelURI
        # Models not synced through MLflow are referenced by a plain local path.
        if "mlruns" not in model_location:
            if Path(model_location).exists():
                return model_location
            raise Exception(
                f"Failed to find trained model locally for entity [{entityDetails.entityName}]"
            )

        tracking_dir = (
            Path.cwd()
            / ".."
            / "ContentAnalyzer"
            / "bin"
            / "classifier_models"
            / "custom_trained_models"
            / "mlruns"
        )
        set_tracking_uri("file://" + str(tracking_dir).replace("\\", "/"))

        experiment_name = f"{entityDetails.entityKey}_{entityDetails.entityId}"
        experiment = get_experiment_by_name(experiment_name)
        if experiment is None or isinstance(experiment, str):
            logging.debug(
                f"Trained model for entity [{entityDetails.entityKey}] is not available locally."
            )
            experiment_id = create_experiment(experiment_name)
        else:
            logging.debug(f"Existing experiment found for dataset {experiment_name}")
            experiment_id = experiment.experiment_id

        latest_run = search_runs(
            experiment_ids=[str(experiment_id)],
            filter_string="",
            max_results=1,
            order_by=["attributes.end_time DESC"],
        )
        if len(latest_run) == 0:
            return create_new_run(experiment_id, entityDetails, params)

        # A run exists: reuse it only if it finished and is recent enough.
        latest_run_dict = latest_run.iloc[0].to_dict()
        if latest_run_dict["status"] != "FINISHED":
            return create_new_run(experiment_id, entityDetails, params)

        current_run_epoch = (
            latest_run_dict["end_time"] - pd.Timestamp("1970-01-01", tz="UTC")
        ).total_seconds()
        trained_model_epoch = entity_xml.classifierDetails.CAUsedInTraining.lastModelTrainTime
        # 10 second margin in case of models trained on this same machine.
        if int(current_run_epoch) + 10 >= int(trained_model_epoch):
            # [7:] drops the leading "file://" scheme of the artifact URI.
            return str(
                Path(url2pathname(latest_run_dict["artifact_uri"][7:])).relative_to(Path.cwd())
            )
        return create_new_run(experiment_id, entityDetails, params)
    except Exception as e:
        error_message = f"Failed to find trained model for entity [{entityDetails.entityName}]."
        # Log the underlying cause; the user-facing message stays generic.
        logging.exception(f"{error_message} Exception {e}")
        raise ModelDoesNotExist(error_message) from e


@generic_cache(parent_cache_key="model_details")
def get_model_methods(entityDetails, params, logging, cache_token=None):
    """Load an entity's model package and return its prediction entry point.

    The model directory comes from ``get_model_location`` and its configuration
    from ``load_model_config``. The process working directory is changed to
    the model directory (callers restore it). The Python file named by
    ``entry_points.file_name`` is then executed as a module. If
    ``entry_points.load_model.method`` is not ``None``, that function is called
    and its returned mapping is merged into the model parameters.

    Results are memoised by ``generic_cache`` under ``"model_details"``
    (presumably keyed on ``cache_token`` -- confirm in CvEEConfigHelper).

    Args:
        entityDetails: entity wrapper passed to ``get_model_location``.
        params: request parameters passed to ``get_model_location``.
        logging: logger handed to the model as ``model_params["logger"]``.
        cache_token: not used in the body; presumably consumed by the decorator.

    Returns:
        tuple: ``(model_params, predict_method)``. ``model_params`` holds
        ``"logger"``, optionally ``"spacy_model_location"``, plus the load
        method's output.

    Raises:
        ModelDoesNotExist: propagated from ``get_model_location``.
        ImportError: if no loader can be created for the entry-point file.
        Exception: errors from the config, the model code, or missing methods.
    """
    model_location = get_model_location(entityDetails, params)
    model_config = load_model_config(model_location)
    entry_points = model_config["entry_points"]
    # The entry-point file name (and presumably files the model code opens)
    # is resolved relative to the model directory.
    os.chdir(model_location)
    # SECURITY: this executes Python code shipped inside the model package;
    # only trusted model archives must ever reach this point.
    spec = importlib.util.spec_from_file_location("module", entry_points["file_name"])
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load model entry point {entry_points['file_name']} from {model_location}"
        )
    model_test_code = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model_test_code)

    model_params = {"logger": logging}
    if "spacy_model_location" in entry_points:
        model_params["spacy_model_location"] = entry_points["spacy_model_location"]
    load_model_method_name = entry_points["load_model"]["method"]
    if load_model_method_name is not None:
        load_model_output = getattr(model_test_code, load_model_method_name)()
        model_params.update(load_model_output)

    predict_method = getattr(model_test_code, entry_points["predict"]["method"])
    return model_params, predict_method


def get_document_category(processing_input, params):
    """Run every requested one-class entity model against a document.

    ``params["entitiesToExtractML"]`` is a comma-separated list of entity ids.
    Blank items are ignored. Each entity whose model predicts ``1``
    contributes its ``entityKey`` to ``results["doc_tags"]``. Entities that
    raise ``EntityDoesNotExist`` are skipped.

    A failing entity does not stop the remaining ones. After the loop, a single
    failure is reported with its specific error code. More than one failure
    is reported as ``ClassificationFailed``.

    Args:
        processing_input: JSON text of the document (see ``get_text``).
        params: request parameters. ``eeCacheToken`` is set to ``"default"``
            when missing (the dict is updated in place).

    Returns:
        dict: always has ``doc_tags`` (tags found even if some entities
        failed). On failure it also has ``ErrorCode`` and ``ErrorMessage``.
    """
    results = {"doc_tags": []}
    try:
        raw_entities = params.get("entitiesToExtractML")
        if raw_entities is None or not raw_entities.strip():
            return results
        # Tolerate whitespace and empty items such as "1, 2," in the id list.
        model_entities = [entity.strip() for entity in raw_entities.split(",") if entity.strip()]
        if not model_entities:
            return results

        clean_input = get_text(processing_input, params)
        current_working_dir = os.getcwd()
        params.setdefault("eeCacheToken", "default")

        doc_tags = []
        model_does_not_exist = []
        model_not_trained = []
        classification_failed = []
        for model_entity in model_entities:
            entity_name = model_entity
            cache_token = f"{params['eeCacheToken']}_{model_entity}"
            try:
                logging.debug(cache_token)
                try:
                    entityDetails = get_entity_wrapper(int(model_entity), cache_token=cache_token)
                except EntityDoesNotExist:
                    # entity might be disabled or deleted, skip it
                    continue
                entity_name = entityDetails.entityName
                if not is_trained(entityDetails):
                    error_message = f"Model not trained for entity [{entity_name}]"
                    logging.error(error_message)
                    raise ModelNotTrained(error_message)
                model_params, predict_method = get_model_methods(
                    entityDetails,
                    params,
                    logging,
                    cache_token=cache_token,
                )
                output = predict_method(clean_input, model_params)
                logging.debug(f"Model {entityDetails.entityKey}. Output {output}")
                if int(output) == 1:
                    doc_tags.append(entityDetails.entityKey)
            except ModelDoesNotExist as e:
                model_does_not_exist.append(e)
            except ModelNotTrained as e:
                model_not_trained.append(e)
            except Exception as e:
                logging.exception(
                    f"Failed to ran categorizer for entity [{entity_name}]. Exception {e}"
                )
                classification_failed.append(e)
            finally:
                # get_model_methods() chdirs into the model directory.
                os.chdir(current_working_dir)

        results["doc_tags"] = doc_tags
        failures = model_does_not_exist + model_not_trained + classification_failed
        if len(failures) > 1:
            # Individual entity failures cannot be reported yet; summarise them.
            raise Exception(f"Categorization failed for {len(failures)} entities.")
        # Re-raise a single failure as-is so its traceback is preserved.
        if model_does_not_exist:
            raise model_does_not_exist[0]
        if model_not_trained:
            raise model_not_trained[0]
        if classification_failed:
            raise classification_failed[0]
    except ModelDoesNotExist as e:
        results["ErrorCode"] = CA_ERROR_CODES["ModelDoesNotExist"]
        results["ErrorMessage"] = str(e)
        logging.error(e)
    except ModelNotTrained as e:
        results["ErrorCode"] = CA_ERROR_CODES["ModelNotTrained"]
        results["ErrorMessage"] = str(e)
        logging.error(e)
    except Exception as e:
        results["ErrorCode"] = CA_ERROR_CODES["ClassificationFailed"]
        results["ErrorMessage"] = "Failed to classify document. Please check logs for more details."
        logging.exception(f"Failed during document categorization. Exception {e}")
    # Returned outside any finally so that BaseException (e.g. KeyboardInterrupt)
    # and errors raised inside the handlers are no longer silently swallowed.
    return results


def preProcess(params={}):
    """Handle the "preProcess" operation.

    Nothing needs preparing here, so ``params`` is ignored and a success
    response is always returned.
    """
    return {"ErrorCode": CA_ERROR_CODES["success"]}


def doAnalysis(processing_input, params=None):
    """Handle the "doAnalysis" operation for one document.

    Runs ``get_document_category`` and converts its output into the response
    sent back to the parent process.

    Args:
        processing_input: JSON text of the document.
        params: request parameters (updated in place downstream). Defaults to
            a fresh empty dict rather than a shared mutable default.

    Returns:
        dict: ``ErrorCode`` and ``ErrorMessage`` (copied from the categorizer
        when it reported an error). When at least one tag matched, it also has
        ``doc_tags``, ``skipFields`` (``["doc_tags"]``) and one
        ``<tag>: ['1']`` entry per tag. If an unexpected exception escapes, a
        ``ClassificationFailed`` response is returned instead of ``None``, so
        the parent never receives an empty response.
    """
    if params is None:
        params = {}
    try:
        entity_model_results = get_document_category(processing_input, params)

        results = {"ErrorCode": CA_ERROR_CODES["success"], "ErrorMessage": None}
        if "ErrorCode" in entity_model_results:
            results["ErrorCode"] = entity_model_results["ErrorCode"]
            if "ErrorMessage" in entity_model_results:
                results["ErrorMessage"] = entity_model_results["ErrorMessage"]
        tags_found = entity_model_results["doc_tags"]
        if tags_found:
            results["doc_tags"] = tags_found
            results["skipFields"] = ["doc_tags"]
            for tag in tags_found:
                results[tag] = ["1"]
        return results
    except Exception as e:
        logging.exception(f"Failed to do the categorization. Exception {e}")
        return {
            "ErrorCode": CA_ERROR_CODES["ClassificationFailed"],
            "ErrorMessage": "Failed to classify document. Please check logs for more details.",
        }


def doProcessing(child_queue, parent_queue, id, parent_pid):
    """Worker loop: serve requests from ``child_queue`` until told to stop.

    Each non-None message is a dict with an ``operation`` key:

    * ``"stop"``: leave the loop (no response is sent).
    * ``"preProcess"``: reply with ``preProcess(msg["params"])``.
    * ``"doAnalysis"``: read ``params["FilePathForProcessing"]`` as UTF-8
      text (cleaned with ``cleanRawTextForNER``) when present, then reply with
      ``doAnalysis``. Timing is recorded under ``"DOC_TAGGER_TEST"``.

    Any other operation is logged and answered with the default success
    response. ``None`` messages are ignored. If handling a message raises, the
    default response is sent.

    Args:
        child_queue: queue the requests are read from.
        parent_queue: queue the responses are written to.
        id: worker id (not used by the loop).
        parent_pid: passed to ``checkParentAndKill`` with this process id
            (presumably to stop this worker when the parent exits -- confirm).
    """
    func_name = "doProcessing"
    func_str = f"{module_name}::{func_name}() - "

    checkParentAndKill(parent_pid, os.getpid())

    perf_counter = PerformanceCounter(logging)
    perf_counter.periodicLogTimer()
    while True:
        try:
            default_response = {"ErrorCode": CA_ERROR_CODES["success"], "ErrorMessage": None}
            msg = child_queue.get()
            if msg is None:
                continue
            operation = msg["operation"]
            if operation == "stop":
                break
            # Reset per message so an unsupported operation can never re-send
            # the previous message's response.
            response = default_response
            if operation == "preProcess":
                response = preProcess(msg["params"])
            elif operation == "doAnalysis":
                params = msg["params"]
                file_size = 0
                try:
                    perf_counter.start_stopwatch("DOC_TAGGER_TEST")
                    processing_input = ""
                    if "FilePathForProcessing" in params:
                        # Text mode already applies universal newlines; the old
                        # "rU" mode was removed in Python 3.11 (ValueError).
                        with open(
                            params["FilePathForProcessing"], "r", encoding="utf-8"
                        ) as f:
                            processing_input = cleanRawTextForNER(f.read())
                        file_size = len(processing_input)
                    response = doAnalysis(processing_input, params)
                    del processing_input
                except Exception as e:
                    logging.exception(f"Failed to categorize. Exception {e}", func_str)
                    # NOTE(review): a failed analysis is reported with the
                    # *success* default response -- confirm this is intended.
                    response = default_response
                finally:
                    perf_counter.stop_stopwatch("DOC_TAGGER_TEST", file_size)
            else:
                logging.error(f"{func_str}Unsupported operation [{operation}]")
            parent_queue.put(response)
        except Exception as e:
            logging.exception(f"Exception occured {e}.", func_str)
            parent_queue.put(default_response)
