import re
import os
import pickle
import pandas as pd
from pathlib import Path
from collections import defaultdict

from nltk.stem.porter import PorterStemmer

from cvee_stopwords import stopwords, english_words

# TODO: receive a logger as an argument in every entry point (today only predict gets one, via params["logger"]).

# Pickle file names; relative, so resolved against the current working directory.
transformer_file, model_file = "transformer.pkl", "model.pkl"

# spaCy pipeline cached by transform_w2v; stays None until its first call.
NLP = None

def transform_w2v(docs, spacy_model_location):
    """Embed each document as the mean of its tokens' spaCy vectors.

    The spaCy pipeline is loaded with ``spacy.load(spacy_model_location)`` on
    the first call and cached in the module-level ``NLP``. Later calls reuse
    that cached pipeline and do not consult ``spacy_model_location`` again.

    Args:
        docs: Iterable of strings; each one is split on whitespace into tokens.
        spacy_model_location: Name or path handed to ``spacy.load``.

    Returns:
        numpy array holding one mean vector per document. A document with no
        token present in ``vocab.strings`` maps to a zero vector of the
        vocabulary's vector width.

    NOTE(review): membership is tested against ``vocab.strings``, not against
    the vector table, so a known string without a vector presumably
    contributes a zero vector to the mean. Confirm this matches training
    before changing it.
    """
    import numpy as np
    import spacy

    global NLP
    if NLP is None:
        NLP = spacy.load(spacy_model_location)

    vocab = NLP.vocab
    zero_fallback = [np.zeros(vocab.vectors.shape[1])]

    def _embed(doc):
        # Vectors for tokens the string store knows; otherwise the zero vector.
        known = [vocab[word].vector for word in doc.split() if word in vocab.strings]
        return np.mean(known or zero_fallback, axis=0)

    return np.asarray([_embed(doc) for doc in docs])

def load_binary_file(path_):
    """Unpickle and return the object stored at *path_*.

    Args:
        path_: ``pathlib.Path`` (a plain ``str`` is also accepted) pointing at
            a pickle file.

    Returns:
        The deserialized object.

    Raises:
        FileNotFoundError: if *path_* does not exist. This is a subclass of
            ``Exception``, so callers catching ``Exception`` keep working.

    Security: ``pickle`` can execute arbitrary code while loading. Only call
    this on files from a trusted source.
    """
    path_ = Path(path_)
    # EAFP instead of exists()-then-open(): no race between check and use.
    try:
        f = open(path_, "rb")
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"File [{Path.cwd() / path_}] doesn't exists."
        ) from err
    with f:
        # Stream from the file object rather than reading it all into memory first.
        return pickle.load(f)


def get_dict_lemmatizer(path="lemmatization-en.txt"):
    """Build an inflected-form -> lemma lookup from a lemma list file.

    Each non-comment line is read as ``<lemma> <inflected form> ...``.
    Columns beyond the second are ignored. Lines with fewer than two fields,
    and lines starting with ``#``, are skipped. When a form appears more than
    once, the last line wins.

    Args:
        path: File to read; defaults to ``lemmatization-en.txt`` relative to
            the current working directory.

    Returns:
        ``defaultdict(str)`` mapping inflected form -> lemma.
    """
    lemmas_ = defaultdict(str)
    # Explicit encoding: the platform default is locale-dependent. "utf-8-sig"
    # also strips a leading BOM, which would otherwise hide a first "#" line.
    with open(path, encoding="utf-8-sig") as f:
        for line in f:
            if line.startswith("#"):
                continue
            words = line.split()
            if len(words) > 1:
                root, word = words[0], words[1]
                lemmas_[word] = root
    return lemmas_


english_words_set = set(english_words)
# clean_data drops any token in this set: generic stopwords plus every entry
# of the english_words list.
stopwords_set = set(stopwords) | english_words_set
stemmer = PorterStemmer()
# Built once at import; reads the lemma list via get_dict_lemmatizer.
lemmatizer_dict = get_dict_lemmatizer()

def clean_data(df):
    """Normalize the ``text`` column of *df* in place and return *df*.

    Each value goes through these steps, in order:

    1. Lower-case it.
    2. Replace runs of ``: - , ? '`` with a space.
    3. Split on whitespace.
    4. Keep only purely alphabetic tokens.
    5. Drop tokens in ``stopwords_set``.
    6. Map tokens through ``lemmatizer_dict``.
    7. Porter-stem them.
    8. Re-join with single spaces.

    NOTE(review): only ``: - , ? '`` become separators, so a token that still
    carries other punctuation (e.g. ``"end."`` or ``"(see"``) fails the
    alphabetic check and is dropped entirely. Presumably training used this
    same pipeline, so changing it would shift features. Confirm before
    altering.
    """
    def _normalize(text):
        separated = re.sub(r"[:\-,?']+", " ", text.lower())
        kept = []
        for word in separated.split():
            if not word.isalpha() or word in stopwords_set:
                continue
            # .get leaves the defaultdict untouched for unknown words.
            kept.append(stemmer.stem(lemmatizer_dict.get(word, word)))
        return " ".join(kept)

    df["text"] = df["text"].map(_normalize)
    return df



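# Load the transformer and model pickles; locations come from the module-level transformer_file / model_file constants (TODO: read them from a YAML config).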
def load_model(transformer_path=None, model_path=None):
    """Load the pickled feature transformer and classifier.

    Args:
        transformer_path: Location of the transformer pickle; defaults to the
            module-level ``transformer_file``. Optional: when the file is
            absent the returned transformer is ``None`` (``predict`` then
            falls back to word2vec features).
        model_path: Location of the model pickle; defaults to the module-level
            ``model_file``. Required.

    Returns:
        ``{"transformer": <object or None>, "model": <object>}``. These are
        the keys ``predict`` reads from its ``params``.

    Raises:
        Exception: ``"Failed to load model. Exception ..."`` wrapping any
            underlying error, including a missing model file. The original
            error is chained as ``__cause__``.
    """
    try:
        transformer_path = Path(
            transformer_file if transformer_path is None else transformer_path
        )
        model_path = Path(model_file if model_path is None else model_path)
        transformer = None
        if transformer_path.exists():
            transformer = load_binary_file(transformer_path)
        if not model_path.exists():
            raise Exception(f"Model Path [{str(model_path)}] does not exists.")
        model = load_binary_file(model_path)
        return {"transformer": transformer, "model": model}
    except Exception as e:
        # Chain so the real traceback (e.g. an unpickling error) is not lost.
        raise Exception(f"Failed to load model. Exception {e}") from e

def predict(text, params=None, min_length=100):
    """Classify *text*; return ``True`` iff the model predicts class 1.

    Args:
        text: Raw document string. Returns ``False`` without running the model
            when *text* is ``None`` or shorter than *min_length* characters
            after stripping.
        params: Mapping with ``"model"`` and ``"transformer"`` (the keys
            ``load_model`` returns), an optional ``"logger"``, and
            ``"spacy_model_location"`` when ``"transformer"`` is ``None``
            (word2vec features are used instead). ``None`` yields ``False``.
        min_length: Minimum stripped length needed to attempt a prediction
            (default 100, the previously hard-coded threshold).

    Returns:
        bool. Errors raised while classifying are logged and ``False`` is
        returned instead of propagating.
    """
    import logging

    import numpy as np

    output = False
    if params is None:
        # No model or transformer to classify with: treat as negative.
        return output
    try:
        logger = params["logger"]
    except (KeyError, TypeError):
        logger = None
    if logger is None:
        # Fall back so failures below are still reported somewhere.
        logger = logging.getLogger(__name__)

    try:
        if text is None:
            return output
        text = text.strip()
        if len(text) < min_length:
            return output
        # Build the one-row frame directly: DataFrame.append was removed in
        # pandas 2.0, which made every call fail and silently return False.
        dataset_df = clean_data(pd.DataFrame({"text": [text]}))
        texts = dataset_df["text"].astype(str)

        transformer = params["transformer"]
        model = params["model"]
        if transformer is not None:
            features = transformer.transform(texts)
            # An all-zero sparse row (nnz == 0) skips the model. Dense outputs
            # have no ``nnz`` and go straight on to the model.
            if getattr(features, "nnz", None) == 0:
                return output
        else:
            logger.debug("Using word2vec as feature extractor")
            spacy_model_location = params["spacy_model_location"]
            features = transform_w2v(texts, spacy_model_location)
        prediction = model.predict(features)
        logger.debug(f"Model output {prediction}")
        # Take the single element explicitly: int() on a 1-element ndarray is
        # deprecated in NumPy >= 1.25.
        output = int(np.ravel(prediction)[0]) == 1
    except Exception as e:
        logger.exception(f"Failed to classify document. Exception {e}")
    # No `return` in `finally`: it used to swallow every exception,
    # including KeyboardInterrupt/SystemExit and errors inside the handler.
    return output
