import os
import traceback
from time import time
from urllib.request import url2pathname
from pathlib import Path
from collections import Counter
from zipfile import ZipFile

from mlflow import (
    end_run,
    active_run,
    create_experiment,
    get_artifact_uri,
    get_experiment_by_name,
    start_run,
    set_tracking_uri,
    log_artifact,
    log_param,
)

from cvee_extract_features import SklearnFeatureExtractor, DimensionalityReduction
from cvee_get_train_data import Dataset
from cvee_oneclass_svm import OneClassModel
from cvee_preprocess_data import DataPreprocessor
from cvee_solr_helper import SolrHelper
from CvEEConfigHelper import (
    CA_ERROR_CODES,
    SOLR_TAGGER_PORT,
    load_spacy_model,
    checkParentAndKill,
    get_dll_logger,
    loadRegValue,
    update_training_status,
    store_model_config,
    NotEnoughData,
)
from cvee_get_entities import (
    get_entity,
    ModelTrainingStatus,
)
from cvee_models import KMeansModel

# Prefix used when building the "<module>::<function>() - " tag passed to log calls.
module_name = "cvee_train_classifier"

# Project logger returned by get_dll_logger(). NOTE: this name is NOT the
# stdlib `logging` module (which is not imported in this file).
logging = get_dll_logger()

# Solr port read through loadRegValue; 22000 is passed as the second argument
# (presumably the fallback when the value is not configured - confirm in
# CvEEConfigHelper).
solrTaggerPort = loadRegValue(SOLR_TAGGER_PORT, 22000)
# Base URL of the Solr instance on localhost; used as the dataset "url" in train_classifier().
SOLR_URL = "http://localhost:{}/solr".format(solrTaggerPort)
# Declared `global` in doProcessing() but never assigned in the code visible here.
ENTITIES_ID_MAP = None
# Not referenced anywhere else in the code visible here.
CLEANUP_STALE_MODELS_TIME = 24 * 60 * 60  # 24 hours

# MLflow tracking directory: the current working directory with a relative
# "..\ContentAnalyzer\bin\classifier_models\custom_trained_models\mlruns"
# suffix appended by plain string concatenation.
DEFAULT_TRACKING_URI = str(Path.cwd()) + str(
    Path("\\..\\ContentAnalyzer\\bin\\classifier_models\\custom_trained_models\\mlruns")
)
# Normalise backslashes to forward slashes before building the URI.
DEFAULT_TRACKING_URI = DEFAULT_TRACKING_URI.replace("\\", "/")
# Import-time side effect: every MLflow call in this process uses this file store.
# Code in this file recovers the filesystem path from artifact URIs by dropping
# the first 7 characters (the "file://" prefix), so the scheme must stay as is.
set_tracking_uri("file://" + DEFAULT_TRACKING_URI)

# Cached spaCy model; train_classifier() loads it on first use for "email" datasets.
NLP = None
# train_classifier() raises NotEnoughData when the dataset has fewer documents than this.
MIN_DOCUMENTS_THRESHOLD = 50


class Classifier:
    """Ordered pipeline of named, callable processing steps.

    Steps are registered with :meth:`add_pipe` and executed in registration
    order when the instance is called. A step registered under the name
    ``"dataset"`` is special: it is called with no arguments to produce the
    initial data. Every other step is called as ``step(dataset_df, "text")``
    and its return value becomes the input of the next step.
    """

    def __init__(self):
        self.pipeline = []  # step names, in execution order
        self.pipeline_function_map = {}  # step name -> callable

    def __call__(self, dataset_name=None, dataset_df=None, is_test=False):
        """Run every registered step and return the output of the last one.

        Args:
            dataset_name: Unused; kept for backward compatibility with callers.
            dataset_df: Initial data. Replaced by the output of the
                ``"dataset"`` step when such a step is registered.
            is_test: When true, skips progress logging and the MLflow
                ``current_step`` parameter updates.

        Returns:
            Whatever the last executed step returned (or ``dataset_df``
            unchanged when no step ran).

        Raises:
            NotEnoughData: Propagated unchanged from any step so callers can
                tell it apart from other failures.
            Exception: Any other step failure, wrapped with the step name and
                chained to the original exception.
        """
        func_name = "Classifier"
        func_str = f"{module_name}::{func_name}() - "
        total_train_start_time = time()
        try:
            if "dataset" in self.pipeline:
                pipe_name = "dataset"
                dataset_df = self.pipeline_function_map[pipe_name]()
                # The loader is one-shot: it is dropped from the pipeline so
                # the loop below (and any later call) does not run it again.
                self.pipeline.remove(pipe_name)
        except NotEnoughData:
            raise
        except Exception as e:
            raise Exception(f"Failed to process pipeline get_data. Exception [{e}]") from e

        for pipe_name in self.pipeline:
            try:
                if not is_test:
                    logging.info(f"Processing Pipeline [{pipe_name}]", func_str)
                    log_param("current_step", pipe_name)
                start_time = time()
                dataset_df = self.pipeline_function_map[pipe_name](dataset_df, "text")
                end_time = time()
                if not is_test:
                    logging.info(
                        f"Time taken to process [{pipe_name}] [{end_time - start_time:.2f}]",
                        func_str,
                    )
            except NotEnoughData:
                # Same contract as the dataset stage above: keep the type so
                # callers can report the specific "not enough data" error.
                raise
            except Exception as e:
                # Chain the cause so the original traceback is not lost.
                raise Exception(
                    f"Failed to process pipeline [{pipe_name}]. Exception [{e}]"
                ) from e
        total_train_end_time = time()
        if not is_test:
            log_param("current_step", "completed")
            logging.info(
                f"Total time taken to train the model is [{total_train_end_time - total_train_start_time:.2f}]",
                func_str,
            )

        return dataset_df

    def add_pipe(self, pipe_name, function_):
        """Register a processing step (cleaning, feature extraction, model, ...).

        A name that is already registered is left untouched: the new callable
        is ignored and the existing one is kept.

        Args:
            pipe_name: Unique name of the step.
            function_: Callable implementing the step.

        Returns:
            ``pipe_name``, whether or not the step was newly added.
        """
        if pipe_name not in self.pipeline:
            self.pipeline.append(pipe_name)
            self.pipeline_function_map[pipe_name] = function_
        return pipe_name


def update_latest_artifact(datasets_info_helper, dataset_name, entity_id, additional_req_params):
    """Record the latest training state of a dataset in Solr and zip its artifacts.

    The existing Solr entry for the dataset (if the helper's default query
    finds any) is deleted and replaced by a new document. When the reported
    ``training_status`` is ``"completed"`` the document also carries the active
    MLflow run id and artifact location, and all non-zip files in the artifact
    directory are bundled into ``<run_id>.zip`` inside that directory.

    This function is best-effort: failures are logged and never raised.

    Args:
        datasets_info_helper: Solr helper for the ``datasets_info`` collection.
            ``None`` is tolerated (nothing is updated) because failure paths
            may run before the helper has been created.
        dataset_name: Value stored as ``dataset_id``.
        entity_id: Value stored as ``entity_id``.
        additional_req_params: Optional dict of extra fields merged into the
            Solr document; its ``"training_status"`` entry drives the
            artifact handling described above.
    """
    func_name = "update_latest_artifact"
    func_str = f"{module_name}::{func_name}() - "

    if datasets_info_helper is None:
        logging.error("Solr helper is not available. Skipping artifact update.", func_str)
        return

    # Treat a missing dict / missing key as "not completed" instead of failing.
    extra_params = additional_req_params or {}
    is_completed = extra_params.get("training_status") == "completed"

    # update artifact location in solr
    try:
        artifact_path = ""
        model_id = ""
        if is_completed:
            # [7:] drops the leading "file://" of the artifact URI.
            artifact_path = Path(url2pathname(get_artifact_uri()[7:])).relative_to(Path.cwd())
            model_id = active_run().info.run_id

        resp_json = datasets_info_helper.default_query_resp().json()
        request_data = [
            {
                # TODO: remove this once schema is corrected
                "FileName": "a",
                "dataset_id": dataset_name,
                "entity_id": entity_id,
                "model_id": model_id,
                "artifact_location": str(artifact_path),
            }
        ]

        request_data[0].update(extra_params)

        if int(resp_json["response"]["numFound"]) > 0:
            # delete existing entry before pushing new
            datasets_info_helper.delete(f"dataset_id:{dataset_name}")
        datasets_info_helper.update(request_data)

        # zip all the artifact files with model Id
        if is_completed:
            try:
                if artifact_path.exists():
                    with ZipFile(artifact_path / Path(f"{model_id}.zip"), "w") as zip_:
                        for file_ in artifact_path.iterdir():
                            # Skip archives, including the one being written.
                            if file_.is_file() and file_.suffix != ".zip":
                                zip_.write(file_, file_.name)
            except Exception as e:
                logging.exception(f"Failed to zip files. Exception [{e}]", func_str)
    except Exception as e:
        logging.exception(f"Failed to update latest ran in solr. Exception [{e}]", func_str)


def end_run_wrapper():
    """End the active MLflow run, logging (not raising) any failure."""
    try:
        end_run()
    except Exception as run_error:
        logging.error(f"Exception occured {run_error}")


def _blank_training_attributes():
    """Return the entity training attributes with every field cleared to ""."""
    return {
        "classifierAccuracy": "",
        "validationSamplesUsed": "",
        "trainingSamplesUsed": "",
        "totalSamples": "",
        "modelGUID": "",
        "errLogMessage": "",
        "errorCode": "",
    }


def _low_accuracy_diagnosis(total_samples, num_documents):
    """Return ``(error_message, error_code)`` for a model below the accuracy threshold.

    With at least 200 documents, of which more than 40% fell outside the
    cluster used for training, the advice is to add *similar* documents; in
    every other case it is to add *more* documents.
    """
    # The size check comes first so the division is skipped for small datasets.
    if total_samples >= 200 and (total_samples - num_documents) / total_samples > 0.4:
        return (
            "Accuarcy is too low. Please try adding similar documents in the training set.",
            CA_ERROR_CODES["LowAccuracyDissimlarData"],
        )
    return (
        "Accuarcy is too low. Please try adding more documents in the training set.",
        CA_ERROR_CODES["LowAccuracyLessData"],
    )


def train_classifier(entityId, dataset_name):
    """Train a one-class classifier for an entity and publish the outcome.

    Steps, all inside one MLflow run of the experiment
    ``"<dataset_name>_<entityId>"`` (created when missing):

    1. Reset the entity's training attributes and mark training as started.
    2. Load the dataset and split it into two k-means clusters; only the
       larger cluster is kept for training.
    3. Train the one-class pipeline (preprocess, feature extraction, model)
       on the kept documents.
    4. Store the model config, then publish status/accuracy to Solr and to
       the entity. A model below the accuracy threshold is published as
       "trained but not usable" with a suggestion in the error message.

    Any failure is logged, reported to Solr and the entity as a failed
    training, and turned into a ``False`` return value; nothing is raised.

    Args:
        entityId: Id of the entity the classifier is trained for.
        dataset_name: Name of the dataset; also stored as ``dataset_id`` in Solr.

    Returns:
        The response of the one-class pipeline on success (a mapping that
        includes ``"accuracy"``, ``"validation_samples_used"`` and
        ``"training_samples_used"``), or ``False`` on failure.
    """
    func_name = "train_classifier"
    func_str = f"{module_name}::{func_name}() - "
    global NLP  # for spacy model
    model_location = ""
    total_samples = 0
    datasets_info_helper = None
    try:
        dataset_info = {"name": dataset_name, "url": SOLR_URL}
        experiment_name = f"{dataset_name}_{entityId}"
        experiment = get_experiment_by_name(experiment_name)
        if experiment is None or isinstance(experiment, str):
            logging.debug(f"Creating new experiment with name [{experiment_name}]", func_str)
            experiment_id = create_experiment(experiment_name)
        else:
            logging.debug(f"Existing experiment found for dataset [{experiment_name}]", func_str)
            experiment_id = experiment.experiment_id

        with start_run(experiment_id=experiment_id) as run:
            try:
                # clear previous values
                update_training_status(
                    entityId,
                    "",
                    training_status=ModelTrainingStatus.TRAINING_STARTED.value,
                    additional_attributes=_blank_training_attributes(),
                )

                # [7:] drops the leading "file://" of the artifact URI.
                model_location = Path(url2pathname(get_artifact_uri()[7:])).relative_to(Path.cwd())
                logging.debug(f"Starting experiment {experiment_id}", func_str)
                start_time = time()
                solr_url = dataset_info["url"]
                default_query = f"entity_id:{entityId}"
                training_data_helper = SolrHelper(
                    logging, solr_url, "datasets", default_query=default_query
                )
                datasets_info_helper = SolrHelper(
                    logging, solr_url, "datasets_info", default_query=default_query
                )
                dataset = Dataset(logging, training_data_helper)
                preprocessor = DataPreprocessor()

                # check if the task was spawned by case manager, create the extractor based on the input
                transformer_type = "tfidf"  # by default use tfidf
                entity_details = get_entity(entityId)
                dataset_type = entity_details.entityXML.classifierDetails.datasetType
                if dataset_type is not None and dataset_type == "email":
                    # use embedding model for feature extraction
                    if NLP is None:
                        # TODO: load lg spacy model for now, evaluate to use smaller models
                        NLP = load_spacy_model()
                    transformer_type = "embedding"

                # use kmeans clustering to find bigger cluster
                clustering_pipeline = Classifier()
                clustering_pipeline.add_pipe("dataset", dataset)
                clustering_pipeline.add_pipe("preprocessor_kmeans", preprocessor)
                clustering_feature_extractor = SklearnFeatureExtractor(
                    logging,
                    datasets_info_helper=datasets_info_helper,
                    dataset_info=dataset_info,
                    vocab_size=None,
                    is_mlflow_log_output=False,
                    transformer_type=transformer_type,
                    nlp=NLP,
                )
                clustering_pipeline.add_pipe(
                    "feature_extractor_kmeans", clustering_feature_extractor
                )
                feature_reduction = DimensionalityReduction(
                    logging, n_components=20, model_type="truncatedsvd"
                )
                clustering_pipeline.add_pipe("feature_extractor_kmeans_reduced", feature_reduction)
                n_clusters = 2
                kmeans_model = KMeansModel(logging, n_clusters=n_clusters)
                clustering_pipeline.add_pipe("train_model_kmeans", kmeans_model)
                output_clusters = clustering_pipeline()

                # pick the cluster holding the most documents (first one wins a tie)
                max_cluster_size = 0
                max_cluster_index = -1
                for cluster_index in range(n_clusters):
                    cluster_size = sum(output_clusters == cluster_index)
                    if max_cluster_size < cluster_size:
                        max_cluster_size = cluster_size
                        max_cluster_index = cluster_index

                # validation_samples_size is only reported in the log lines below
                validation_samples_size = max_cluster_size
                if max_cluster_size > 200:
                    validation_samples_size = max(100, int(0.10 * max_cluster_size))
                    logging.info(
                        f"Total Documents [{dataset.dataset_df.shape[0]}] Documents used for training [{max_cluster_size}] Validation Documents (Hold out) [{validation_samples_size}]"
                    )
                else:
                    logging.info(
                        f"Total Documents [{dataset.dataset_df.shape[0]}] Documents used for training [{max_cluster_size}] Validation Documents (entire training set) [{validation_samples_size}]"
                    )
                total_samples = dataset.dataset_df.shape[0]
                if total_samples < MIN_DOCUMENTS_THRESHOLD:
                    raise NotEnoughData(
                        f"Please provide at least [{MIN_DOCUMENTS_THRESHOLD}] documents for training."
                    )
                # filter out documents of bigger cluster
                filtered_df = dataset.dataset_df.iloc[output_clusters == max_cluster_index]

                # shuffle dataframe
                filtered_df = filtered_df.sample(frac=1, random_state=42)

                # get total unique words
                unique_words = set()
                for current_words in filtered_df["text"].str.split().values:
                    unique_words.update(current_words)
                total_unique_words = len(unique_words)
                num_documents = filtered_df.shape[0]

                # clear up previous outputs to release memory before the second training stage
                del output_clusters
                del dataset
                del feature_reduction
                del clustering_feature_extractor
                del kmeans_model
                del clustering_pipeline
                del unique_words
                # use filtered df for one class classification
                classifier = Classifier()

                # update training_status to started
                additional_req_params = {"training_status": "started"}
                update_latest_artifact(
                    datasets_info_helper, dataset_name, entityId, additional_req_params
                )

                # adding classifier pipelines
                classifier.add_pipe("preprocessor_ocsvm", preprocessor)

                feature_extractor = SklearnFeatureExtractor(
                    logging,
                    datasets_info_helper=datasets_info_helper,
                    dataset_info=dataset_info,
                    vocab_size=max(int(total_unique_words / num_documents), 100),
                    transformer_type=transformer_type,
                    nlp=NLP,
                )
                classifier.add_pipe("feature_extractor_ocsvm", feature_extractor)

                one_class_model = OneClassModel(
                    logging, datasets_info_helper=datasets_info_helper, dataset_info=dataset_info
                )

                classifier.add_pipe("train_model_ocsvm", one_class_model)

                # start executing the pipeline
                response = classifier(dataset_df=filtered_df)

                end_time = time()
                training_time = end_time - start_time

                # A model whose accuracy is below the threshold is still
                # published, but flagged as not usable together with a
                # suggestion on how to improve the training set.
                # (later the threshold could be registry configurable or entity based)
                THRESHOLD = 0.7
                accuracy = float(response["accuracy"])
                training_status_solr = "completed"
                training_status = ModelTrainingStatus.TRAINING_COMPLETED.value
                error_message = ""
                error_code = 0
                if accuracy < THRESHOLD:
                    error_message, error_code = _low_accuracy_diagnosis(
                        total_samples, num_documents
                    )
                    # NOTE(review): value reads "usable" although the model is flagged
                    # NOT usable; kept verbatim since Solr consumers may match on it.
                    training_status_solr = "trained_but_usable"
                    training_status = ModelTrainingStatus.TRAINED_NOT_USABLE.value

                # update training_status to completed
                additional_req_params = {
                    "training_status": training_status_solr,
                    "accuracy": str(response["accuracy"]),
                    "training_time": f"{training_time:.2f}",
                }

                try:
                    store_model_config(
                        model_location,
                        test_file="test_one_class_svm.py",
                        dependent_files=["cvee_stopwords.py", "lemmatization-en.txt"],
                        predict_method="predict",
                        load_model_method="load_model",
                    )
                except Exception as e:
                    logging.exception(f"Failed to store config file. Exception [{e}]")
                    # Re-raise so the failure handlers below mark the entity and
                    # Solr as failed; returning here would leave the entity in
                    # the "training started" state. The function still returns False.
                    raise

                try:
                    update_latest_artifact(
                        datasets_info_helper, dataset_name, entityId, additional_req_params
                    )
                except Exception as e:
                    logging.exception(
                        f"Failed to update classifier details in solr. Exception [{e}]"
                    )

                try:
                    additional_attributes = {
                        "classifierAccuracy": str(response["accuracy"]),
                        "validationSamplesUsed": response["validation_samples_used"],
                        "trainingSamplesUsed": response["training_samples_used"],
                        "totalSamples": total_samples,
                        "modelGUID": run.info.run_id,
                        "errLogMessage": error_message,
                        "errorCode": error_code,
                    }
                    update_training_status(
                        entityId,
                        model_location,
                        training_status=training_status,
                        additional_attributes=additional_attributes,
                    )
                    logging.info(
                        f"Successfully updated training status for the entity [{entityId}]."
                    )
                except Exception as e:
                    logging.exception(
                        f"Failed to update entity details in Database. Exception [{e}]"
                    )
                    return False

                return response
            except NotEnoughData:
                raise
            except Exception:
                # Best effort: attach the traceback to the MLflow run for debugging.
                failure_log = Path(f"train_failed_{entityId}.txt")
                try:
                    failure_log.write_text(traceback.format_exc())
                    try:
                        log_artifact(str(failure_log))
                    finally:
                        # Never leave the temporary file behind in the working directory.
                        failure_log.unlink()
                except Exception as log_err:
                    logging.error(
                        f"Failed to log traceback in mlflow. Exception [{log_err}]", func_str
                    )
                raise
            finally:
                end_run_wrapper()
    except Exception as e:
        try:
            # update training_status to failed
            additional_req_params = {"training_status": "failed"}
            update_latest_artifact(
                datasets_info_helper, dataset_name, entityId, additional_req_params
            )
        except Exception as solr_err:
            logging.error(
                f"Failed to update failed status in solr. Exception [{solr_err}]", func_str
            )
        logging.error(f"Failed to train. Exception occured [{e}].", func_str)
        try:
            additional_attributes = _blank_training_attributes()
            additional_attributes[
                "errLogMessage"
            ] = "Failed to train classifier. Please check ContentAnalyzer.log for more details."
            additional_attributes["errorCode"] = CA_ERROR_CODES["ModelTrainingFailed"]
            if isinstance(e, NotEnoughData):
                additional_attributes[
                    "errLogMessage"
                ] = f"Please provide at least [{MIN_DOCUMENTS_THRESHOLD}] documents for training."
                additional_attributes["errorCode"] = CA_ERROR_CODES["NotEnoughData"]
            # 0 when the failure happened before the dataset was loaded
            additional_attributes["totalSamples"] = total_samples
            update_training_status(
                entityId,
                model_location,
                training_status=ModelTrainingStatus.TRAINING_FAILED.value,
                additional_attributes=additional_attributes,
            )
        except Exception as status_err:
            logging.exception(
                f"Failed to update entity details in Database. Exception [{status_err}]"
            )
        return False
    finally:
        # Safety net in case the run is still active (e.g. a failure outside the inner try).
        end_run_wrapper()


def preProcess(params=None):
    """Return a success response; no pre-processing work is currently done.

    Args:
        params: Ignored. Defaults to ``None`` rather than a shared mutable dict.

    Returns:
        dict: ``{"ErrorCode": CA_ERROR_CODES["success"]}``.
    """
    return {"ErrorCode": CA_ERROR_CODES["success"]}


def doProcessing(child_queue, parent_queue, id, parent_pid):
    """Worker loop: serve operations read from ``child_queue`` until told to stop.

    Supported operations (``message["operation"]``):

    * ``"stop"``       - leave the loop.
    * ``"preProcess"`` - put the result of ``preProcess(params)`` on ``parent_queue``.
    * ``"doAnalysis"`` - run ``train_classifier`` for ``params["entityId"]`` /
      ``params["datasetId"]``; asynchronous, so no response is queued.

    ``None`` messages and unknown operations are ignored. Any exception raised
    while handling a message is logged and the loop continues.

    Args:
        child_queue: Queue the requests are read from.
        parent_queue: Queue the ``preProcess`` responses are written to.
        id: Unused in this function.
        parent_pid: Passed to ``checkParentAndKill`` with this process's pid.
    """
    func_str = f"{module_name}::doProcessing() - "

    checkParentAndKill(parent_pid, os.getpid())

    while True:
        try:
            message = child_queue.get()
            if message is None:
                continue

            operation = message["operation"]
            if operation == "stop":
                break

            if operation == "preProcess":
                parent_queue.put(preProcess(message["params"]))
            elif operation == "doAnalysis":
                # an async operation, no need to send a response immediately
                request = message["params"]
                train_classifier(request["entityId"], request["datasetId"])

            # drop the reference before blocking on the next get()
            del message
        except Exception as err:
            logging.exception(f"Exception occured [{err}].", func_str)


if __name__ == "__main__":
    # Manual smoke run: train entity 113 on the "lessdocs" dataset using the
    # module-level SOLR_URL.
    dataset_name = "lessdocs"

    train_classifier(113, dataset_name)
