import ujson
import os
import pickle
import warnings

import numpy as np
import scipy

import spacy

# import matplotlib.pyplot as plt


def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``: accept anything, do nothing."""
    return None


# Replace the module attribute so later ``warnings.warn(...)`` lookups hit the no-op.
warnings.warn = warn

from pathlib import Path
from collections import OrderedDict
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.pipeline import FeatureUnion
from sklearn.decomposition import PCA, TruncatedSVD
from mlflow import log_artifact

# from mlflow.sklearn import log_model, SERIALIZATION_FORMAT_CLOUDPICKLE

from cvee_stopwords import stopwords
from cvee_preprocess_data import DataPreprocessor
from cvee_solr_helper import SolrHelper

# Name of this module; used to build the prefix of log messages.
module_name = "cvee_extract_features"

# Module-level cache of transformers already unpickled by
# SklearnFeatureExtractor.load_transformer, so repeated loads skip the disk read.
TRANFORMER_CACHED = {}


class SklearnFeatureExtractor:
    """Build, apply and persist the text vectorizer used for a dataset.

    In training mode (``is_test=False``) a new transformer is created from
    ``transformer_type`` and fitted when the instance is called. In test mode
    the pickled transformer is loaded from the dataset's artifact location.

    Args:
        logging: logger-like object; only ``debug`` is called by this class.
        transformer_type: ``"tfidf"``, ``"count"`` or ``"embedding"``; any
            other value selects the word/char TF-IDF feature union.
        datasets_info_helper: object whose ``default_query_resp()`` returns
            the dataset-info query response (used in test mode).
        dataset_info: dataset description; when ``None`` no transformer is
            created and nothing is saved.
        is_test: load a stored transformer instead of fitting a new one.
        vocab_size: ``max_features`` passed to the TF-IDF vectorizer.
        is_mlflow_log_output: log the fitted transformer to MLflow after fit.
        nlp: model exposing ``.vocab``; required for ``"embedding"``.
    """

    def __init__(
        self,
        logging,
        transformer_type="tfidf",
        datasets_info_helper=None,
        dataset_info=None,
        is_test=False,
        vocab_size=1000,
        is_mlflow_log_output=True,
        nlp=None
    ):
        self.logging = logging
        self.transformer_type = transformer_type
        self.datasets_info_helper = datasets_info_helper
        self.dataset_info = dataset_info
        self.is_test = is_test
        self.vocab_size = vocab_size
        self.is_mlflow_log_output = is_mlflow_log_output
        self.nlp = nlp
        self.transformer = self.get_transformer()

    def __del__(self):
        # __init__ can raise before ``transformer`` is assigned; a plain
        # ``del self.transformer`` would then fail inside the finalizer.
        self.__dict__.pop("transformer", None)

    def __call__(self, data_frame, column_name="text"):
        """Vectorize ``data_frame[column_name]``.

        Test mode only transforms. Training mode fits and transforms, then
        (when ``is_mlflow_log_output`` is set) saves the fitted transformer.

        Returns:
            Whatever the underlying transformer returns for the column.
        """
        if self.is_test:
            return self.transform_data(data_frame, column_name)

        extracted_features = self.fit_transform_data(data_frame, column_name)
        # delete stop_words to save space
        self.transformer.stop_words_ = None
        if self.is_mlflow_log_output:
            self.save_transformer()
        return extracted_features

    def get_transformer(self):
        """Return the transformer for the current mode.

        Returns:
            ``None`` when ``dataset_info`` is ``None``; the stored transformer
            in test mode; otherwise a new, unfitted transformer chosen by
            ``transformer_type``.

        Raises:
            Exception: ``"embedding"`` was requested without an ``nlp`` model.
        """
        if self.dataset_info is None:
            return None

        if self.is_test:
            transformer_location = self.get_transformer_location()
            self.logging.debug(f"Feature extractor location {transformer_location}")
            return SklearnFeatureExtractor.load_transformer(
                transformer_location, file_name="transformer.pkl"
            )

        func_name = "get_transformer"
        func_str = f"{module_name}::{func_name}() - "
        # The prefix belongs inside the message: handed to debug() as a second
        # positional argument it is treated as a %-format argument for a
        # message that has no placeholders.
        self.logging.debug(
            f"{func_str}{self.transformer_type} feature extractor getting used for training"
        )

        if self.transformer_type == "tfidf":
            return self.tfidf_vectorizer(self.vocab_size)
        if self.transformer_type == "count":
            return self.count_vectorizer()
        if self.transformer_type == "embedding":
            if not self.nlp:
                # embedding vectorizer selected but no model was passed, raise error
                raise Exception("Can't use 'embedding' transformer, missing word vectors model")
            return self.embedding_vectorizer(self.nlp)
        return self.feature_union_tfidf_word_char()

    def _dataset_name(self):
        """Return a printable dataset identifier for error messages."""
        if isinstance(self.dataset_info, dict):
            return self.dataset_info.get("name", self.dataset_info)
        return self.dataset_info

    def get_transformer_location(self):
        """
            do a solr query to get the dataset info,
            http://localhost:22000/solr/datasets_info/select?q=datasetid:{dataset_name}

            Returns the ``artifact_location`` of the first returned document.

            Raises:
                Exception: the response contains no documents.
                ValueError: the first document has no artifact location.
        """
        response = self.datasets_info_helper.default_query_resp()
        result = response.json()["response"]
        if "docs" not in result or len(result["docs"]) == 0:
            raise Exception("Failed to get dataset info.")

        first_doc = result["docs"][0]
        if "artifact_location" in first_doc and first_doc["artifact_location"]:
            return first_doc["artifact_location"]
        raise ValueError(
            f"Unable to get artifact location for dataset {self._dataset_name()}"
        )

    @staticmethod
    def load_transformer(transformer_location, file_name="transformer.pkl"):
        """Unpickle the transformer stored under ``transformer_location``.

        ``file_name`` is tried first, then ``transformer.pkl``, then
        ``vectorizer.pkl``. Loaded transformers are cached in
        ``TRANFORMER_CACHED``.

        Raises:
            FileNotFoundError: none of the candidate files exists.
        """
        # Default-name loads keep the plain location as cache key; a custom
        # file name gets its own entry so it cannot be served another file.
        cache_key = (
            transformer_location
            if file_name == "transformer.pkl"
            else (transformer_location, file_name)
        )
        if cache_key in TRANFORMER_CACHED:
            return TRANFORMER_CACHED[cache_key]

        base_dir = Path(transformer_location)
        candidates = [base_dir / name for name in (file_name, "transformer.pkl", "vectorizer.pkl")]
        # If nothing exists, open the last candidate so the error names a real path.
        transformer_path = next((p for p in candidates if p.exists()), candidates[-1])

        # NOTE(security): pickle executes arbitrary code while loading; the
        # artifact location must only ever point at trusted files.
        with open(transformer_path, "rb") as f:
            transformer = pickle.load(f)
        TRANFORMER_CACHED[cache_key] = transformer
        return transformer

    def save_transformer(self):
        """Log the fitted transformer and its description as MLflow artifacts.

        Does nothing when ``dataset_info`` is ``None`` or an ``nlp`` model was
        supplied. Logs ``transformer.pkl`` and ``vectorizer_info.txt``, plus
        ``feature_scores.txt`` when the transformer exposes ``idf_``.
        """
        if self.dataset_info is None or self.nlp is not None:
            return

        import tempfile  # local import keeps this block self-contained

        # A private scratch directory keeps concurrent runs from overwriting
        # each other's files and is removed even when logging fails.
        with tempfile.TemporaryDirectory() as scratch:
            scratch_dir = Path(scratch)

            pickle_path = scratch_dir / "transformer.pkl"
            with open(pickle_path, "wb") as f:
                pickle.dump(self.transformer, f)
            log_artifact(str(pickle_path))

            # get_feature_names() is absent from newer scikit-learn releases;
            # prefer get_feature_names_out() when the transformer has it.
            names_getter = getattr(self.transformer, "get_feature_names_out", None)
            if names_getter is None:
                names_getter = self.transformer.get_feature_names
            feature_names = [str(name) for name in names_getter()]

            info_path = scratch_dir / "vectorizer_info.txt"
            with open(info_path, "w") as f:
                f.write(f"Transformer Used \n {self.transformer}\n")
                f.write(f"Features {feature_names}")
            log_artifact(str(info_path))

            # feature scores: only transformers exposing idf_ have them
            scores = getattr(self.transformer, "idf_", None)
            if scores is not None:
                feature_scores = [
                    (name, float(score)) for name, score in zip(feature_names, scores)
                ]
                scores_path = scratch_dir / "feature_scores.txt"
                with open(scores_path, "w") as f:
                    f.write(ujson.dumps(feature_scores))
                log_artifact(str(scores_path))

    def feature_union_tfidf_word_char(self):
        """
            not currently in use

            Returns a FeatureUnion of a word-level and a char-level TF-IDF
            vectorizer.
        """
        word_vectorizer = TfidfVectorizer(
            sublinear_tf=True,
            stop_words="english",
            strip_accents="unicode",
            analyzer="word",
            token_pattern=r"\w{1,}",
            ngram_range=(1, 3),
            dtype=np.float32,
            max_features=6000,
        )
        char_vectorizer = TfidfVectorizer(
            sublinear_tf=True,
            stop_words="english",
            strip_accents="unicode",
            analyzer="char",
            ngram_range=(1, 3),
            dtype=np.float32,
            max_features=8000,
        )
        return FeatureUnion(
            [
                ("word_vectorizer", word_vectorizer),
                ("char_vectorizer", char_vectorizer),
            ]
        )

    def count_vectorizer(self):
        """
            not currently in use
            TODO: might try to combine these together in feature_union later
        """
        return CountVectorizer(
            token_pattern=r"(?u)\b\w*[a-zA-Z]\w*\b", stop_words=stopwords
        )

    def tfidf_vectorizer(self, max_features=None):
        """
            in use feature extractor as of now
        """
        return TfidfVectorizer(
            token_pattern=r"\w{3,}",
            stop_words=stopwords,
            ngram_range=(1, 3),
            max_features=max_features,
            min_df=2,
            # max_df=0.8,
        )

    def embedding_vectorizer(self, nlp):
        """
            in use when datasetType = 'email'
        """
        return EmbeddingVectorizer(nlp, stop_words=stopwords)

    def dimension_reduction(self, features, n_components=100):
        """
            Using reduced features along with KMeans to find clusters

            Sparse input is converted to a dense ndarray before PCA is fitted.
        """
        if scipy.sparse.issparse(features):
            # The dense copy is what must reach PCA; toarray() yields a plain
            # ndarray rather than the np.matrix returned by todense().
            features = features.toarray()
        return PCA(n_components=n_components).fit_transform(features)

    def transform_data(self, data_frame, column_name):
        """Transform the stringified column with the already-fitted transformer."""
        return self.transformer.transform(data_frame[column_name].astype(str))

    def fit_transform_data(self, data_frame, column_name):
        """
            Fit the transformer on the stringified column and return the
            transformed features.
        """
        return self.transformer.fit_transform(data_frame[column_name].astype(str))


class DimensionalityReduction:
    """Reduce a feature matrix with PCA or TruncatedSVD.

    Args:
        logging: logger-like object; ``exception`` is called on failures.
        n_components: number of components to keep.
        model_type: ``"pca"`` or ``"truncatedsvd"``.
        **hyper_params: extra keyword arguments forwarded to the model
            constructor.

    Raises:
        ValueError: ``model_type`` is not one of the supported values.
    """

    def __init__(self, logging, n_components=100, model_type="pca", **hyper_params):
        self.n_components = n_components
        self.logging = logging
        self.model_type = model_type
        self.hyper_params = hyper_params
        self.model = self.get_model()

    def __call__(self, features, column_name="text"):
        """Fit the model on ``features`` and return the reduced matrix.

        ``column_name`` is accepted but not used by this method.
        """
        try:
            if self.model_type == "pca" and scipy.sparse.issparse(features):
                # toarray() gives a plain ndarray; todense() would give an
                # np.matrix, which current scikit-learn estimators reject.
                features = features.toarray()
            return self.model.fit_transform(features)
        except Exception as e:
            self.logging.exception(
                f"Failed to reduce features using {self.model_type}. Exception {e}"
            )
            raise

    def get_model(self):
        """Instantiate the estimator selected by ``model_type`` (seeded with 42)."""
        try:
            if self.model_type == "pca":
                return PCA(n_components=self.n_components, random_state=42, **self.hyper_params)
            if self.model_type == "truncatedsvd":
                # hyper params can include number of iterations
                return TruncatedSVD(
                    n_components=self.n_components, random_state=42, **self.hyper_params
                )
            raise ValueError(
                f"Models supported are pca and truncatedsvd. Input provided [{self.model_type}]"
            )
        except Exception as e:
            self.logging.exception(f"Failed to initialize {self.model_type}. Exception {e}")
            raise


class EmbeddingVectorizer:
    """
        Transform texts to word2vec features, using Spacy NLP model vocabulary

        Each document becomes the mean of the vectors of its whitespace-split
        tokens that are not stop words and are present in ``vocab.strings``;
        a document with no such token becomes a zero vector.

        Args:
            nlp: model exposing ``.vocab`` (with ``.vectors`` and ``.strings``).
            stop_words: tokens to skip; ``None`` means no stop words.
    """

    def __init__(self, nlp, stop_words):
        self.vocab = nlp.vocab
        self.stop_words = stop_words
        self.dim = self.vocab.vectors.shape[1]

    def fit(self, X, y=None):
        """Nothing to learn; returns ``self`` (``y`` is optional and ignored)."""
        return self

    def transform(self, X):
        """Return an array with one averaged word vector per document in ``X``."""
        stop_words = self.stop_words
        if stop_words is None:
            stop_words = ()
        elif isinstance(stop_words, (list, tuple)):
            # hash lookup instead of a linear scan for every token
            stop_words = frozenset(stop_words)

        vocab = self.vocab
        known_strings = vocab.strings
        vecs = []
        for doc in X:
            word_vectors = [
                vocab[token].vector
                for token in doc.split()
                if token not in stop_words and token in known_strings
            ]
            if word_vectors:
                vecs.append(np.mean(word_vectors, axis=0))
            else:
                vecs.append(np.zeros(self.dim))

        if not vecs:
            # keep the result two-dimensional even for empty input
            return np.zeros((0, self.dim))
        return np.asarray(vecs)

    def fit_transform(self, X, y=None):
        """Same as ``transform``; ``y`` is optional and ignored."""
        return self.transform(X)

    def fit_transform_data(self, X):
        """Alias of ``transform``."""
        return self.transform(X)

def get_features(df, transformer):
    """
        Run ``transformer.transform`` over the ``text`` column of ``df``,
        cast to ``str``, and return the result.
    """
    text_column = df.text.astype(str)
    return transformer.transform(text_column)


if __name__ == "__main__":
    # Manual smoke run: preprocess a tiny frame and fit the default extractor.
    import pandas as pd

    from CvEEConfigHelper import get_local_logger

    logging = get_local_logger("CVEEGetData.log")

    # DataFrame.append no longer exists in pandas 2.x, so the frame is built
    # from records. Two rows are used because the default TF-IDF vectorizer
    # is configured with min_df=2 and cannot be fitted on a single document.
    df = pd.DataFrame(
        [
            {
                "file_name": "test.txt",
                "text": "hey there this seems a nice test for puc,tuation , lemmatizations and we are lot of\n\n\n stopwords",
            },
            {
                "file_name": "test2.txt",
                "text": "hey there this seems another nice test for lemmatizations and stopwords",
            },
        ]
    )
    print(f"{'='*30} Dataframe {'='*30}")
    print(df.head(2))
    preprocessor = DataPreprocessor()
    df = preprocessor(df, column_name="text")
    dataset_info = {"name": "technical"}
    feature_extractor = SklearnFeatureExtractor(logging, dataset_info=dataset_info)
    X = feature_extractor(df, column_name="text")
    print(f"{'='*30} Shape of Dataframe {'='*30}")
    print(f"Shape of transformed data {X.shape}")
