import re
import os
import pickle
import pandas as pd
from pathlib import Path
from collections import defaultdict

from nltk.stem.porter import PorterStemmer

from cvee_stopwords import stopwords, english_words

# logger could be received as argument for all entry points

# File names of the pickled transformer and model. load_model() wraps them in
# Path objects as-is, so they are resolved relative to the current working
# directory.
transformer_file = "transformer.pkl"
model_file = "model.pkl"


def load_binary_file(path_):
    """Load and return the object pickled in the file at ``path_``.

    Args:
        path_: Location of the pickle file, as a ``str`` or any path-like
            object (it is converted to ``pathlib.Path`` internally).

    Returns:
        The object obtained by unpickling the file contents.

    Raises:
        FileNotFoundError: If ``path_`` does not exist. The message contains
            the path joined onto the current working directory.
        Exception: Any error raised while opening or unpickling the file is
            propagated unchanged.
    """
    path_ = Path(path_)
    if not path_.exists():
        raise FileNotFoundError(f"File [{Path.cwd() / path_}] doesn't exists.")
    # SECURITY: unpickling can execute arbitrary code. Only use this on files
    # that come from a trusted source.
    with open(path_, "rb") as f:
        return pickle.load(f)


def get_dict_lemmatizer():
    """Build a word -> root lookup from ``lemmatization-en.txt``.

    The file is opened relative to the current working directory. Lines that
    start with ``#`` are skipped, as are lines with fewer than two
    whitespace-separated fields. On every other line the first field is the
    root and the second field is the word that maps to it; further fields are
    ignored. When a word occurs more than once, the last occurrence wins.

    Returns:
        A ``defaultdict(str)`` mapping each word to its root.
    """
    mapping = defaultdict(str)
    with open("lemmatization-en.txt") as handle:
        for raw_line in handle:
            if raw_line.startswith("#"):
                continue
            fields = raw_line.split()
            if len(fields) <= 1:
                continue
            lemma, inflected = fields[0], fields[1]
            mapping[inflected] = lemma
    return mapping


def clean_data(df):
    """Normalise the ``text`` column of ``df`` into space-joined stemmed tokens.

    Every entry of ``df["text"]`` goes through these steps in order:
    lower-casing and tokenising, dropping tokens that contain any non-letter
    character, dropping tokens found in the combined stopword set, dictionary
    lemmatization, Porter stemming (numeric tokens are skipped), and finally
    re-joining the tokens with single spaces.

    The column is overwritten on ``df`` itself and that same frame is returned.
    """
    # Tokens to discard: the stopword list merged with the english-words list.
    discard = set(stopwords) | set(english_words)
    porter = PorterStemmer()
    lemma_lookup = get_dict_lemmatizer()

    def tokenize(raw):
        # Lower-case, turn runs of : - , ? ' into one space, split on whitespace.
        lowered = raw.lower()
        return re.sub(r"[:\-,?']+", " ", lowered).split()

    def keep_alphabetic(words):
        # Keep a token only if it has no character outside a-z / A-Z.
        return [w for w in words if re.search("[^a-zA-Z]", w) is None]

    def drop_stopwords(words):
        return [w for w in words if w not in discard]

    def lemmatize(words):
        # Tokens without a dictionary entry pass through unchanged.
        return [lemma_lookup[w] if w in lemma_lookup else w for w in words]

    def stem(words):
        return [porter.stem(w) for w in words if not w.isnumeric()]

    def join_tokens(words):
        return " ".join(words)

    pipeline = (
        tokenize,
        keep_alphabetic,
        drop_stopwords,
        lemmatize,
        stem,
        join_tokens,
    )
    column = df["text"]
    for step in pipeline:
        column = column.map(step)
    df["text"] = column
    return df


# TODO: load the model location from a YAML file instead of the hard-coded file names above
def load_model():
    """Load the pickled transformer and model from their module-level file names.

    Returns:
        A dict with two keys: ``"transformer"`` and ``"model"``, holding the
        objects unpickled from ``transformer_file`` and ``model_file``.

    Raises:
        Exception: If either file cannot be loaded. The original error is
            attached as ``__cause__`` and its text is included in the message.
    """
    try:
        transformer = load_binary_file(Path(transformer_file))
        model = load_binary_file(Path(model_file))
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying failure as
        # the direct cause of this one.
        raise Exception(f"Failed to load model. Exception {e}") from e
    return {"transformer": transformer, "model": model}


def predict(text, params=None):
    """Classify ``text`` with the transformer/model pair supplied in ``params``.

    Args:
        text: The document to classify. ``None``, or text shorter than 100
            characters after stripping surrounding whitespace, is rejected
            without running the model.
        params: Mapping with the keys ``"transformer"`` and ``"model"`` and,
            optionally, ``"logger"``. When ``None`` no classification is
            attempted.

    Returns:
        ``True`` only when the model's prediction converts to the integer 1.
        ``False`` in every other case: rejected input, no ``params``, no
        non-zero features after transformation, or any ``Exception`` raised
        along the way (which is logged if a logger was supplied).
    """
    output = False
    logger = None
    try:
        # Resolve the optional logger first so both the success path and the
        # error handler can use it without assuming it exists.
        if params is not None and "logger" in params:
            logger = params["logger"]

        if text is None:
            return output
        text = text.strip()
        if len(text) < 100:
            return output
        if params is None:
            # Nothing to classify with, so skip the preprocessing entirely.
            return output

        # Build the single-row frame directly; DataFrame.append no longer
        # exists in pandas 2.x.
        dataset_df = clean_data(pd.DataFrame({"text": [text]}))

        transformer = params["transformer"]
        model = params["model"]

        features = transformer.transform(dataset_df["text"].astype(str))
        if features.nnz == 0:
            return output
        prediction = model.predict(features)
        if logger is not None:
            # NOTE(review): the model output is logged at error level,
            # presumably so it is visible under restrictive log settings.
            # Confirm whether info/debug was intended.
            logger.error(f"Model output {prediction}")
        output = int(prediction) == 1
    except Exception as e:
        # Best-effort classification: failures are reported, never raised.
        if logger is not None:
            logger.exception(f"Failed to classify document. Exception {e}")
    # Returning here rather than from a ``finally`` clause lets
    # KeyboardInterrupt / SystemExit propagate instead of being swallowed.
    return output
