import os
import pickle
import json
import ujson
import numpy as np
import scipy
import time

from urllib.request import url2pathname
from collections import defaultdict
from pathlib import Path


from mlflow import log_artifact, log_params, log_metrics, get_artifact_uri
from sklearn.svm import OneClassSVM
from sklearn.model_selection import GridSearchCV
from sklearn.metrics.scorer import make_scorer

# Module-level cache of unpickled models, keyed by the model's artifact
# location. It is read and filled by OneClassModel.load_model so that a model
# at a given location is only unpickled once per process.
MODELS_CACHED = {}


class NDArrayEncoder(json.JSONEncoder):
    """JSON encoder that understands NumPy arrays and NumPy scalars.

    Use it as ``json.dumps(obj, cls=NDArrayEncoder)``. Arrays are emitted as
    (nested) lists and NumPy scalars (``np.int64``, ``np.float32``,
    ``np.bool_``, ...) as the matching Python value. Anything else is passed
    to the base encoder, which raises ``TypeError`` for unsupported objects.
    """

    # https://discuss.mxnet.io/t/make-ndarray-json-serializable/1627
    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars such as np.int64 are not subclasses of the Python
            # builtins, so the stock encoder rejects them.
            return obj.item()
        return super().default(obj)


class OneClassModel:
    """Train, persist and apply a one-class SVM.

    In training mode (``is_test=False``) calling the instance fits a
    ``OneClassSVM`` (through a grid search when ``params["is_grid_search"]``
    is truthy), logs metrics and artifacts through MLflow and returns a stats
    dict. In test mode (``is_test=True``) a pickled model is loaded from the
    dataset's artifact location and calling the instance returns the
    prediction and raw score for the given features.
    """

    DEFAULT_MODEL_PARAMS = {"kernel": "rbf", "gamma": 0.1, "nu": 0.01, "is_grid_search": True}

    def __init__(
        self,
        logging,
        datasets_info_helper=None,
        dataset_info=None,
        is_test=False,
        params=None,
        is_mlflow_log_output=True,
    ):
        """Store the configuration and create or load the underlying model.

        Args:
            logging: logging object kept on the instance as ``self.logging``.
            datasets_info_helper: helper exposing ``query(params)``; used in
                test mode to look up the dataset's artifact location.
            dataset_info: mapping with at least a ``"name"`` key. When it is
                ``None`` no model is created or loaded.
            is_test: ``True`` to load a saved model, ``False`` to train.
            params: overrides merged on top of ``DEFAULT_MODEL_PARAMS``.
            is_mlflow_log_output: when true, training also saves the model
                and its word/score artifacts.
        """
        self.logging = logging
        self.is_test = is_test
        self.dataset_info = dataset_info
        self.datasets_info_helper = datasets_info_helper
        self.is_mlflow_log_output = is_mlflow_log_output
        # Copy so that per-instance overrides never leak into the class default.
        self.params = OneClassModel.DEFAULT_MODEL_PARAMS.copy()
        if params is not None:
            self.params.update(params)
        self.model = self.get_model()

    def __call__(self, features, column_name="vectorizer"):
        """Predict (test mode) or train and evaluate (training mode).

        Args:
            features: feature matrix with one row per sample.
            column_name: accepted for interface compatibility; not used here.

        Returns:
            Test mode: ``(-1, 0)`` for a sparse matrix without any stored
            non-zero entry, otherwise the result of :meth:`test`.
            Training mode: dict with ``accuracy``, ``training_samples_used``
            and ``validation_samples_used``.
        """
        if self.is_test:
            # Local import: the module only does ``import scipy``, which does
            # not guarantee that the ``scipy.sparse`` submodule is loaded.
            from scipy.sparse import issparse

            # An all-zero sparse input carries no information, so skip the
            # model and report it as an outlier with score 0.
            if issparse(features) and features.nnz == 0:
                return -1, 0
            return self.test(features)

        # Validation split: with more than 200 samples the first 10% (at
        # least 100 rows) are held out; otherwise the same data is used for
        # both training and validation.
        total_size = features.shape[0]
        train_data = validation_data = features
        validation_size = total_size
        if total_size > 200:
            validation_size = max(100, int(0.10 * total_size))
            validation_data = features[:validation_size]
            train_data = features[validation_size:]

        self.train(train_data)
        if self.is_mlflow_log_output:
            self.save_model(validation_data)
        accuracy = self.get_accuracy(validation_data)

        stats = {
            "accuracy": accuracy,
            "training_samples_used": train_data.shape[0],
            "validation_samples_used": validation_size,
        }
        log_metrics(stats)

        # TODO: parallel processing will be an issue here
        # The artifacts were written to the working directory before being
        # logged; remove whichever of them exist (best effort).
        scratch_files = [
            "model.pkl",
            "grid_search_cv_results.txt",
            "decision_function.txt",
            "model_score.json",
            "positive_words_weight.json",
            "negative_words_weight.json",
        ]
        for scratch_file in scratch_files:
            try:
                Path(scratch_file).unlink()
            except OSError:
                # Missing (never written) or otherwise undeletable file.
                pass
        return stats

    def get_accuracy(self, features):
        """Return the fraction of ``features`` predicted as inliers (+1).

        Also writes the decision-function values to ``decision_function.txt``
        and logs that file as an artifact. The result is rounded to two
        decimals.
        """
        predictions = np.asarray(self.model.predict(features))
        inliers = int(np.count_nonzero(predictions == 1))

        # decision function
        decision_function = self.model.decision_function(features)
        with open("decision_function.txt", "w") as f:
            f.write(ujson.dumps(decision_function.tolist()))
        log_artifact("decision_function.txt")

        return round(inliers / len(predictions), 2)

    @staticmethod
    def one_class_custom_error(y, y_pred):
        """Return the fraction of predictions equal to -1 (outliers).

        ``y`` is ignored; it is only present because scorers are called with
        ``(y_true, y_pred)``.
        """
        outliers = sum(1 for pred in y_pred if int(pred) == -1)
        return outliers / float(len(y_pred))

    def train(self, features):
        """Fit the model on ``features``.

        With ``params["is_grid_search"]`` truthy, a 5-fold grid search over
        kernel, ``nu`` and ``gamma`` selects the estimator with the lowest
        outlier fraction; the CV results and best parameters are logged.
        Otherwise the model created by :meth:`get_model` is fitted directly.
        """
        if not self.params.get("is_grid_search"):
            self.model.fit(features)
            return

        parameters = {
            "kernel": ("linear", "poly", "rbf", "sigmoid"),
            "nu": [
                0.5,
                0.2,
                0.1,
                0.01,
                0.001,
                0.0001,
            ],  # TODO : we can use clustering to get the training error size
            "gamma": [
                0.1,
                0.01,
                0.001,
                "auto",
            ],  # TODO: check again in case of sklearn upgrade to 0.22
        }
        # greater_is_better=False: the scorer is an error to be minimised.
        one_class_scorer = make_scorer(
            OneClassModel.one_class_custom_error, greater_is_better=False
        )
        clf = GridSearchCV(
            OneClassSVM(random_state=42),
            parameters,
            cv=5,
            scoring=one_class_scorer,
            refit=True,
            n_jobs=4,
        )
        # Every training sample is labelled as an inlier (+1).
        clf.fit(features, np.ones(features.shape[0], dtype=int))
        self.model = clf.best_estimator_

        with open("grid_search_cv_results.txt", "w") as f:
            f.write(json.dumps({"test": clf.cv_results_}, cls=NDArrayEncoder, indent=4))
        log_artifact("grid_search_cv_results.txt")
        log_params(clf.best_params_)

    def score_samples(self, features):
        """
            Raw scoring functions of the samples
            (decision function minus the model intercept).
            TODO: in sklearn 0.22 this will be inbuilt in OneClassSVM
            [score_samples](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/svm/_classes.py#L1272)
        """
        return self.model.decision_function(features) - self.model.intercept_

    def test(self, features):
        """Return ``(predictions, score)`` for ``features``.

        ``predictions`` is the model's full prediction output; ``score`` is
        the raw score of the first sample only.
        """
        output = self.model.predict(features)
        score = self.score_samples(features)
        return output, score.ravel()[0]

    def get_model(self):
        """Return the model for the current mode.

        Returns ``None`` when no ``dataset_info`` was given. Otherwise the
        saved model is loaded in test mode, or a fresh ``OneClassSVM`` is
        built from ``self.params`` in training mode.
        """
        if self.dataset_info is None:
            return None
        if self.is_test:
            return OneClassModel.load_model(self.get_model_location(), file_name="model.pkl")
        return OneClassSVM(
            nu=self.params["nu"],
            kernel=self.params["kernel"],
            gamma=self.params["gamma"],
            random_state=42,
        )

    @staticmethod
    def _word_weights(rows, feature_words):
        """Sum the strictly positive feature weights per word over ``rows``."""
        weights = defaultdict(float)
        for row in np.asarray(rows).tolist():
            for feature_idx, feature_weight in enumerate(row):
                if feature_weight > 0:
                    weights[feature_words[feature_idx]] += feature_weight
        return weights

    @staticmethod
    def _score_bounds(values):
        """Return ``{"min": ..., "max": ...}`` of ``values``; ``None`` if empty."""
        values = np.asarray(values)
        if values.size == 0:
            # np.min / np.max raise on empty input, e.g. when no validation
            # sample ends up on this side of the decision boundary.
            return {"min": None, "max": None}
        return {"min": float(np.min(values)), "max": float(np.max(values))}

    def save_model(self, features):
        """Pickle the trained model and log it, plus diagnostics if possible.

        Does nothing when ``dataset_info`` is ``None``. When a
        ``transformer.pkl`` exists in the current run's artifact directory,
        three extra JSON artifacts are produced from it:

        * ``positive_words_weight.json``: summed positive weights of the
          support vectors, per feature word;
        * ``negative_words_weight.json``: the same for the rows of
          ``features`` predicted as outliers (-1);
        * ``model_score.json``: min/max decision-function value on each side
          of zero (``null`` when a side is empty).

        Args:
            features: sparse feature matrix used for the diagnostics.
        """
        if self.dataset_info is None:
            return

        # get transformer of current run
        # [7:] drops a 7-character prefix from the artifact URI, presumably
        # the "file://" scheme, so only local artifact stores are expected.
        transformer_path = Path(url2pathname(get_artifact_uri()[7:])).relative_to(
            Path.cwd()
        ) / Path("transformer.pkl")

        if transformer_path.exists():
            # NOTE: pickle can execute arbitrary code while loading; the
            # artifact directory must be trusted.
            with open(transformer_path, "rb") as f:
                transformer = pickle.load(f)

            # save positive and negative words of the model
            feature_words = transformer.get_feature_names()
            positive_words_weight = OneClassModel._word_weights(
                self.model.support_vectors_.toarray(), feature_words
            )

            # get negative prediction words and their score
            all_predictions = self.model.predict(features)
            negative_words_weight = OneClassModel._word_weights(
                features.toarray()[all_predictions == -1], feature_words
            )

            # get max and min positive and negative scores
            score = self.model.decision_function(features)
            scores = {
                "positive": OneClassModel._score_bounds(score[score >= 0]),
                "negative": OneClassModel._score_bounds(score[score < 0]),
            }

            with open("model_score.json", "w") as f:
                ujson.dump(scores, f)
            with open("positive_words_weight.json", "w") as f:
                ujson.dump(positive_words_weight, f)
            with open("negative_words_weight.json", "w") as f:
                ujson.dump(negative_words_weight, f)
            log_artifact("model_score.json")
            log_artifact("positive_words_weight.json")
            log_artifact("negative_words_weight.json")

        with open("model.pkl", "wb") as f:
            pickle.dump(self.model, f)
        log_artifact("model.pkl")

    def get_model_location(self):
        """Return the artifact location recorded for the current dataset.

        The dataset document is fetched through :meth:`get_dataset_info`
        (a Solr query on the dataset name).

        Raises:
            ValueError: if the document has an empty ``artifact_location``.
        """
        dataset_name = self.dataset_info["name"]
        resp = OneClassModel.get_dataset_info(self.datasets_info_helper, dataset_name)
        if resp["artifact_location"]:
            return resp["artifact_location"]
        raise ValueError(f"Unable to get artifact location for dataset {dataset_name}")

    @staticmethod
    def load_model(model_location, file_name="model.pkl"):
        """Unpickle ``file_name`` from ``model_location``, caching the result.

        Subsequent calls with the same ``model_location`` return the cached
        object from ``MODELS_CACHED`` (the cache key ignores ``file_name``).
        """
        if model_location in MODELS_CACHED:
            return MODELS_CACHED[model_location]
        model_path = Path(model_location) / Path(file_name)
        # NOTE: pickle can execute arbitrary code while loading; only load
        # models from trusted locations.
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        MODELS_CACHED[model_location] = model
        return model

    @staticmethod
    def get_dataset_info(solr_helper, dataset_name):
        """Query ``solr_helper`` for ``dataset_id:<dataset_name>``.

        Returns the first document of the JSON response; an ``IndexError``
        propagates when the query matches no document.
        """
        params = {"q": f"dataset_id:{dataset_name}", "wt": "json"}
        resp_json = solr_helper.query(params).json()
        return resp_json["response"]["docs"][0]
