import logging
import traceback
import os
import sys
from time import time

import pandas as pd

from CvEEConfigHelper import get_available_memory, NotEnoughData

# Module label; every method of Dataset prepends it to the function tag
# ("<module_name>::<func_name>() - ") that is passed to the logger.
module_name = "cvee_get_train_data"

# Minimum number of documents Solr must report before a dataset is loaded;
# Dataset.__call__ raises NotEnoughData when the count is below this value.
MIN_DOCUMENTS_THRESHOLD = 50


class Dataset:
    """Load training documents from Solr into a pandas DataFrame.

    Documents are fetched in batches of ``batch_size``. The number of
    documents requested is capped so that, based on the average document
    size reported by Solr, the data is expected to fit into the memory that
    ``get_available_memory()`` reports as free.

    Attributes:
        logging: project logger; its methods are called as
            ``logger.<level>(message, func_str)``.
        solr_helper: object exposing ``query(params)`` whose result has a
            ``json()`` method.
        batch_size: number of rows requested from Solr per query.
        raw_dataset, train_dataset, validation_dataset: initialised here but
            not populated by any method of this class.
        dataset_df: the DataFrame built by the last successful call, or None.
    """

    # Maximum number of whitespace-separated tokens kept from each document.
    MAX_WORDS_PER_DOC = 20000

    def __init__(self, logging, solr_helper, batch_size=100):
        self.logging = logging
        self.solr_helper = solr_helper
        self.batch_size = batch_size
        self.raw_dataset = []
        self.dataset_df = None
        self.train_dataset = None
        self.validation_dataset = None

    def __call__(self):
        """Fetch the training documents and return them as a DataFrame.

        Steps:
            1. read the document count and total size from Solr,
            2. reject datasets smaller than MIN_DOCUMENTS_THRESHOLD,
            3. work out how many documents fit into the available memory,
            4. fetch that many documents batch by batch,
            5. lower-case each text, collapse whitespace and keep at most
               MAX_WORDS_PER_DOC words,
            6. sort the rows by text.

        No train/validation split is performed here.

        Returns:
            pandas.DataFrame with the columns ``file_name`` and ``text``
            (also stored in ``self.dataset_df``).

        Raises:
            NotEnoughData: Solr reports fewer than MIN_DOCUMENTS_THRESHOLD
                documents.
            Exception: any other failure is logged and re-raised unchanged.
        """
        func_name = "Dataset"
        func_str = f"{module_name}::{func_name}() - "

        try:
            # document count and total size (bytes) of the training data
            num_docs, total_size = self.get_stats()

            if num_docs < MIN_DOCUMENTS_THRESHOLD:
                raise NotEnoughData(
                    f"Please provide at least {MIN_DOCUMENTS_THRESHOLD} documents for training."
                )

            # Available system memory, treated as GB (it is converted with
            # 2 ** 30 below). Leave 2 GB out for other calculations when
            # possible; with an ongoing NER or RER job 2 GB might not be
            # enough. If less than 2 GB is free, the full amount is used.
            available_memory = get_available_memory()
            usable_memory = max(0, available_memory - 2)
            if usable_memory != 0:
                available_memory = usable_memory

            # Compare in bytes: rounding the dataset size to whole GB would
            # treat anything below 0.5 GB as "0 GB" and never limit it.
            available_bytes = available_memory * (2 ** 30)

            get_num_docs = num_docs
            if available_bytes < total_size:
                avg_doc_size = total_size / num_docs
                # range() below needs an int; float floor-division yields a float.
                get_num_docs = min(num_docs, int(available_bytes // avg_doc_size))

            self.logging.info(f"Preparing to fetch {get_num_docs} from solr.", func_str)

            # Collect plain rows and build the frame once: DataFrame.append
            # copied the whole frame per row and no longer exists in pandas 2.
            rows = []
            start_time = time()
            for batch_index in range(0, get_num_docs, self.batch_size):
                for file_name, text in self.get_data(batch_index):
                    words = text.lower().split()[: self.MAX_WORDS_PER_DOC]
                    rows.append({"file_name": file_name, "text": " ".join(words)})

            # Explicit columns keep the frame sortable even when no rows came back.
            self.dataset_df = pd.DataFrame(rows, columns=["file_name", "text"])
            end_time = time()

            self.logging.debug(
                f"Number of training data fetched {self.dataset_df.shape[0]}.", func_str
            )
            self.logging.debug(
                f"Memory footprint {sys.getsizeof(self.dataset_df)//(2**20)} MB", func_str
            )
            self.logging.debug(
                f"Time taken to load the dataset {end_time - start_time:.2f}", func_str
            )

            self.dataset_df.sort_values(by="text", inplace=True)
            return self.dataset_df
        except NotEnoughData:
            # expected condition for the caller; do not log it as a failure
            raise
        except Exception as e:
            self.logging.exception(f"Failed to get the trainig docs. Exception {e}", func_str)
            raise

    def get_data(self, batch_index):
        """Yield ``(FileName, content)`` for one batch of Solr documents.

        This is a generator: the query is only sent once iteration starts,
        so errors surface while the caller is iterating.

        Example of a request of this kind:
            http://localhost:22000/solr/training_dataset_9615ea24-47d9-11ea-b212-f8b156d21fc5/select?q=used_in_training:false&rows=10&start=10&fl=content,FileName&wt=json

        Args:
            batch_index: offset of the first row to return (Solr ``start``).

        Yields:
            tuple of the document's ``FileName`` and ``content`` fields.

        Raises:
            Exception: any failure is logged and re-raised unchanged.
        """
        func_name = "get_data"
        func_str = f"{module_name}::{func_name}() - "

        try:
            params = {
                # "fq": "used_in_training:false",
                "rows": self.batch_size,
                "start": batch_index,
                "fl": "content,FileName",
                "wt": "json",
            }
            resp = self.solr_helper.query(params)
            resp_json = resp.json()
            for doc in resp_json["response"]["docs"]:
                yield doc["FileName"], doc["content"]
        except Exception as e:
            self.logging.error(f"Something went wrong while accesing solr. Exception {e}", func_str)
            raise

    def get_stats(self):
        """Return the document count and summed size of the Solr dataset.

        Issues a zero-row query with a JSON facet summing the ``size`` field.

        Returns:
            tuple ``(num_docs, total_size)`` taken from the ``facets`` section
            of the response; ``total_size`` is logged as MB by dividing by
            1024 * 1024, i.e. it is treated as bytes.

        Raises:
            Exception: any failure is logged and re-raised unchanged.
        """
        func_name = "get_stats"
        func_str = f"{module_name}::{func_name}() - "

        try:
            params = {
                "facet": "true",
                "json.facet": '{total_size:"sum(size)"}',
                "rows": 0,  # only the facet values are needed, no documents
            }
            resp = self.solr_helper.query(params)
            facet_resp = resp.json()
            num_docs = facet_resp["facets"]["count"]
            total_size = facet_resp["facets"]["total_size"]
            self.logging.info(
                f"Number of trainig docs {num_docs}. Training dataset size {total_size/(1024*1024):.2f} MB",
                func_str,
            )
            return num_docs, total_size
        except Exception as e:
            self.logging.error(f"Something went wrong while accesing solr. Exception {e}", func_str)
            raise


if __name__ == "__main__":
    # Manual smoke run against a local Solr instance: build a Dataset for one
    # hard-coded training collection and load it.
    from CvEEConfigHelper import get_local_logger
    from cvee_solr_helper import SolrHelper

    # Bound to `logger` so the stdlib `logging` module imported at the top of
    # the file is not rebound to a project logger object.
    logger = get_local_logger("CVEEGetData.log")
    dataset_name = "training_dataset_9615ea24-47d9-11ea-b212-f8b156d21fc5"
    dataset_info = {
        "name": dataset_name,
    }

    # plain literal: the f-prefix had no placeholders to interpolate
    solr_url = "http://localhost:22000/solr/"

    solr_helper = SolrHelper(logger, solr_url, dataset_info["name"])

    dataset = Dataset(logger, solr_helper)

    dataset()
