import os
import re
import traceback
import warnings


def warn(*args, **kwargs):
    """Accept any positional/keyword arguments, do nothing, and return ``None``."""
    return None


# Replace ``warnings.warn`` with the no-op defined above, so every later call that
# looks the function up on the ``warnings`` module is silently ignored.
# This is a process-wide side effect of importing this module.
warnings.warn = warn
from pathlib import Path

import nltk
import numpy as np
import pandas as pd

from nltk.stem.porter import PorterStemmer
from nltk.stem.wordnet import WordNetLemmatizer
from multiprocessing import Pool
from collections import defaultdict

from cvee_stopwords import stopwords, english_words
from CvEEConfigHelper import get_cpu_cores, get_dll_logger

# Add the nltk_models directory to NLTK's data search path.
# The path is built relative to the current working directory at import time,
# so it depends on where the process was started from.
nltk.data.path.append(os.path.join(os.getcwd(), Path("../ContentAnalyzer/bin/nltk_models/")))

"""
    reason for initialization another logger here instead of receiving it in class params like other pipeline classes,
    we are trying to split the dataframe for multi core processing, as the main function clean_data is a class method
    passing it to the serializer was creating issue in case of logger parameter as it is c++ logger and not serializable 
    from python multiprocessing pickle
"""
# Module-level logger obtained from CvEEConfigHelper. NOTE: despite the name, this is
# not the stdlib ``logging`` module; calls in this file pass (message, prefix).
logging = get_dll_logger()
# Prefix component used when building the "<module>::<function>() - " log prefix.
module_name = "cvee_preprocess_data"


class DataPreprocessor:
    """Clean the text held in one column of a pandas DataFrame.

    Calling an instance runs every cell of the chosen column through:
    tokenization (lowercase, punctuation to spaces, whitespace split),
    removal of tokens containing anything but ASCII letters, stopword
    removal, dictionary-based lemmatization, Porter stemming, and finally
    re-joining the tokens with single spaces.

    Frames with more than 20 rows are split and processed in a process pool.
    """

    # Runs of these punctuation characters are replaced by one space before splitting.
    _PUNCTUATION_RE = re.compile(r"[:\-,?']+")
    # Matches any single character that is not an ASCII letter.
    _NON_ALPHA_RE = re.compile(r"[^a-zA-Z]")
    # Upper bound on the number of worker processes.
    _MAX_WORKERS = 8

    def __init__(self):
        """Load stopwords, the stemmer/lemmatizer objects and the lemma dictionary."""
        self.pipe_name = "preprocessor"
        # Custom stopwords merged with the english_words list; both are filtered out.
        self.stopwords_set = set(stopwords) | set(english_words)
        self.stemmer = PorterStemmer()
        # Only used by lemmatize_words(), which is not part of the clean_data pipeline.
        self.lemmatizer = WordNetLemmatizer()
        self.lemmatizer_dict = DataPreprocessor.get_dict_lemmatizer()
        # Kept for the alternative tokenizer; convert_text_to_tokens does not use it.
        self.wpt = nltk.WordPunctTokenizer()
        self.status = "start_process"  # we will update this variable at each intermediate step, multiprocessing will be issue

    def __call__(self, data_frame, column_name="text"):
        """Clean ``data_frame[column_name]`` and return the resulting frame.

        Args:
            data_frame: pandas DataFrame containing the text column.
            column_name: name of the column to clean (default ``"text"``).

        Returns:
            The cleaned DataFrame. For frames of at most 20 rows this is the
            same object that was passed in, modified in place.
        """
        func_name = "DataPreprocessor"
        func_str = f"{module_name}::{func_name}() - "

        if len(data_frame) > 20:
            logging.info("Doing parallel preprocessing opertaion.", func_str)
            return self.parallelize_dataframe(data_frame, column_name)
        return self.clean_data((data_frame, column_name))

    def parallelize_dataframe(self, data_frame, column_name):
        """Split the frame into chunks and clean them in a multiprocessing pool.

        snippet from https://towardsdatascience.com/make-your-own-super-pandas-using-multiproc-1c04f41944a1 \n
        Note: I have tried dask with multiple partition options but didn't see any speed gain as it is using
        threads internally to operate on the data which will not give speed gain in case of CPU intensive operation. \n
        We are getting more than 2x speed gain with this method with n_cores = 4

        Args:
            data_frame: pandas DataFrame containing the text column.
            column_name: name of the column to clean.

        Returns:
            A DataFrame built by concatenating the cleaned chunks.

        Raises:
            Any exception raised while cleaning a chunk is propagated after
            the pool has been shut down.
        """
        func_name = "parallelize_dataframe"
        func_str = f"{module_name}::{func_name}() - "

        cpu_cores_available = get_cpu_cores(is_logical=False)  # get physical cores only
        # Guard against a 0/None core count (assumption: the helper may be unable to
        # detect physical cores on some hosts) so we always have at least one worker.
        n_cores = max(1, min(cpu_cores_available or 1, self._MAX_WORKERS))
        logging.debug(f"Number of cores used for preprocessing is {n_cores}", func_str)

        data_frame_split = np.array_split(data_frame, n_cores)
        # clean_data takes a single (frame, column) tuple because Pool.map passes one argument.
        work_items = [(chunk, column_name) for chunk in data_frame_split]

        pool = Pool(n_cores)
        try:
            cleaned_chunks = pool.map(self.clean_data, work_items)
        finally:
            # Always release the worker processes, even when a chunk fails.
            pool.close()
            pool.join()
        return pd.concat(cleaned_chunks)

    def clean_data(self, params):
        """Run the full cleaning pipeline on one column of a frame.

        Steps applied to every cell, in order:
            1. Lowercase, replace ``: - , ? '`` with spaces, split on whitespace.
            2. Drop any token containing a character other than an ASCII letter.
            3. Remove stopwords (custom list, more extensive than nltk's).
            4. Replace tokens by their lemma using the lemma dictionary.
            5. Stem the remaining tokens with the Porter stemmer.
            6. Join the tokens back into a single space-separated string.

        Args:
            params: ``(df, column_name)`` tuple; packed into one argument so the
                method can be used directly with ``Pool.map``.

        Returns:
            The same ``df`` object, with ``df[column_name]`` overwritten in place.
        """
        df, column_name = params
        df[column_name] = (
            df[column_name]
            .map(self.convert_text_to_tokens)
            .map(self.remove_non_alphabets)
            .map(self.remove_stopwords)
            .map(self.lemmatize_words_using_dict)
            .map(self.stem_words)
            .map(self.convert_tokens_to_text)
        )
        return df

    def convert_text_to_tokens(self, text):
        """Lowercase ``text``, turn ``: - , ? '`` into spaces and split on whitespace.

        Non-string cells (e.g. ``None`` or NaN from missing data) yield an empty
        token list instead of raising.
        """
        if not isinstance(text, str):
            return []
        text = self._PUNCTUATION_RE.sub(" ", text.lower())
        return text.split()

    def convert_tokens_to_text(self, tokens):
        """Join tokens into one string separated by single spaces."""
        return " ".join(tokens)

    def remove_stopwords(self, tokens):
        """Return the tokens that are not in ``self.stopwords_set``."""
        return [word for word in tokens if word not in self.stopwords_set]

    def stem_words(self, tokens):
        """Drop tokens for which ``str.isnumeric()`` is true and stem the rest."""
        return [self.stemmer.stem(word) for word in tokens if not word.isnumeric()]

    def lemmatize_words(self, tokens):
        """Lemmatize each token with the WordNet lemmatizer."""
        return [self.lemmatizer.lemmatize(word) for word in tokens]

    def lemmatize_words_using_dict(self, tokens):
        """Replace each token by its entry in ``self.lemmatizer_dict``, if any.

        Tokens without an entry are kept unchanged. ``get`` is used so a
        ``defaultdict`` is never grown by lookups of unknown tokens.
        """
        lemma_of = self.lemmatizer_dict.get
        return [lemma_of(token, token) for token in tokens]

    def remove_non_alphabets(self, tokens):
        """Keep only tokens made up entirely of ASCII letters (a-z, A-Z)."""
        has_non_alpha = self._NON_ALPHA_RE.search
        return [word for word in tokens if has_non_alpha(word) is None]

    @staticmethod
    def get_dict_lemmatizer():
        """Build the word -> lemma mapping from ``lemmatization-en.txt``.

        The file is opened relative to the current working directory. Lines
        starting with ``#`` are skipped; on every other line with at least two
        whitespace-separated fields, the first field is the lemma and the
        second is the inflected word.

        Returns:
            ``defaultdict(str)`` mapping each inflected word to its lemma.
        """
        lemmas_ = defaultdict(str)
        # NOTE(review): the file is read with the platform default encoding - confirm
        # that matches the encoding of the shipped lemma list.
        with open("lemmatization-en.txt") as f:
            for line in f:
                if line.startswith("#"):
                    continue
                words = line.split()
                if len(words) > 1:
                    root, word = words[0], words[1]
                    lemmas_[word] = root
        return lemmas_


if __name__ == "__main__":
    # Manual smoke test: run the preprocessor on a single hand-written sample row
    # and print the text before and after cleaning.
    # from CvEEConfigHelper import get_local_logger

    # logging = get_local_logger("CVEEGetData.log")
    original_text = "hey there this seems a nice test for puc,tuation , lemmatizations and we are lot of\n\n\n stopwords dfafa323 324 22 www.google.com"
    # DataFrame.append was removed in pandas 2.0, so build the one-row frame directly.
    df = pd.DataFrame([{"file_name": "test.txt", "text": original_text}])
    preprocessor = DataPreprocessor()
    df = preprocessor(df)
    print(f"Original text: \n {original_text}")
    print(f"Processed text: \n {df['text'][0]}")
