from builtins import range
from builtins import object
import csv, re, os


class StructuredDocumentHandler(object):
    """Detect delimiter-separated text (CSV/TSV) and extract entities column by column.

    Only CSV and TSV text is handled for now. The ``xls``/``xlsx`` extensions are
    parsed as TSV text (presumably the caller exports spreadsheets as
    tab-separated text — confirm).
    TODO: native support for Excel workbooks and other tables.
    """

    # Number of leading lines inspected when sniffing the document type.
    _SNIFF_LINES = 30
    # Number of consistent non-blank rows required to call a document structured.
    _MIN_CONSISTENT_ROWS = 6
    # Upper bound on the number of cells per column handed to the NLP model.
    _SAMPLE_ROWS = 100
    # Fraction of sampled cells that must carry an entity label for the whole
    # column to be classified with that label.
    _MIN_ENTITY_RATIO = 0.40

    # Escaped (literal backslash) "\r\n" sequences followed by whitespace.
    _ESCAPED_CRLF_RUN = re.compile(r"(\\r\\n[\s]+)+")
    # Escaped (literal backslash) "\t" sequences, turned into real tabs.
    _ESCAPED_TAB = re.compile(r"\\t")
    # Any character not expected in a person/location name.
    _EXCLUDED_CHARS = re.compile(r"[^a-zA-Z\.\-\'\",\(\)\s]+")
    _LEADING_NON_ALPHA = re.compile(r"^[^a-zA-Z]+")
    _TRAILING_NON_ALPHA = re.compile(r"[^a-zA-Z]+$")

    def __init__(self, text, Doc):
        """Classify *text* as CSV, TSV or unstructured.

        :param text: raw document text.
        :param Doc: stored unchanged as ``self.Doc``; not used by this class
            (presumably the caller's document model — confirm).
        """
        self.Doc = Doc
        self.records = None
        # Keep the raw text regardless of the detected type so that
        # perform_extraction() can still parse it when the caller forces a
        # type through the ``extension`` argument.
        self.text = text
        self.type = self.get_document_type(text)

    def get_document_type(self, text):
        """Return ``"tsv"``, ``"csv"`` or ``None`` based on the first lines of *text*.

        TSV is tried first. Any error while sniffing is treated as
        "not structured" rather than propagated.
        """
        content = StructuredDocumentHandler.clean_content(text)
        try:
            if self.is_tsv(content):
                return "tsv"
            if self.is_csv(content):
                return "csv"
        except Exception:
            # Best effort: malformed input (e.g. csv.Error) only means the
            # document is not recognised as structured.
            return None
        return None

    def is_csv(self, content):
        """Return True if the first lines of *content* look like a comma-separated table.

        Backslashes are removed before parsing, mirroring load_csv().
        """
        content = content.replace("\\", "")
        records = StructuredDocumentHandler._parse_rows(content, ",", self._SNIFF_LINES)
        return StructuredDocumentHandler.is_structured(records)

    def is_tsv(self, content):
        """Return True if the first lines of *content* look like a tab-separated table."""
        records = StructuredDocumentHandler._parse_rows(content, "\t", self._SNIFF_LINES)
        return StructuredDocumentHandler.is_structured(records)

    @staticmethod
    def _is_blank_row(row):
        """Return True for an empty row or a single whitespace-only cell."""
        # csv.reader yields [] for an empty line and [" "] for a whitespace line.
        return not row or (len(row) == 1 and row[0].strip() == "")

    @staticmethod
    def is_structured(records):
        """Return True if *records* looks like a table.

        Requires more than 5 records, a first record with more than one field,
        and the first 6 non-blank records all having the same field count.
        """
        if len(records) <= 5 or len(records[0]) <= 1:
            return False
        expected_len = len(records[0])
        consistent = 0
        for row in records:
            if StructuredDocumentHandler._is_blank_row(row):
                continue
            if len(row) != expected_len:
                return False
            consistent += 1
            if consistent == StructuredDocumentHandler._MIN_CONSISTENT_ROWS:
                return True
        return False

    @staticmethod
    def clean_content(content):
        """Normalise line breaks in *content* and strip surrounding whitespace.

        Runs of escaped ``\\r\\n`` sequences (literal backslashes) followed by
        whitespace become one OS line separator, as does every pair of
        consecutive newline characters.
        """
        content = StructuredDocumentHandler._ESCAPED_CRLF_RUN.sub(os.linesep, content)
        content = content.replace("\n\n", os.linesep)
        return content.strip()

    @staticmethod
    def _parse_rows(content, delimiter, max_lines=None):
        """Parse *content* line by line with csv, turning literal ``\\t`` into tabs.

        :param max_lines: if given, only the first *max_lines* lines are parsed.
        """
        lines = content.splitlines()
        if max_lines is not None:
            lines = lines[:max_lines]
        unescaped = [StructuredDocumentHandler._ESCAPED_TAB.sub("\t", line) for line in lines]
        return list(csv.reader(unescaped, delimiter=delimiter))

    @staticmethod
    def load_csv(content):
        """Parse the whole of *content* as comma-separated rows (backslashes removed)."""
        content = StructuredDocumentHandler.clean_content(content).replace("\\", "")
        return StructuredDocumentHandler._parse_rows(content, ",")

    @staticmethod
    def load_tsv(content):
        """Parse the whole of *content* as tab-separated rows."""
        content = StructuredDocumentHandler.clean_content(content)
        return StructuredDocumentHandler._parse_rows(content, "\t")

    @staticmethod
    def _first_content_row(records):
        """Return the index of the first non-blank record, or None if there is none."""
        return next(
            (
                idx
                for idx, row in enumerate(records)
                if not StructuredDocumentHandler._is_blank_row(row)
            ),
            None,
        )

    def perform_extraction(self, nlp, extension=None):
        """Load the full document and extract entities column by column.

        Only the first 30 lines were used to classify the document; here the
        whole text is parsed. When *extension* is given it overrides the
        detected type (``csv`` -> CSV; ``tsv``/``xls``/``xlsx`` -> TSV).

        :param nlp: callable taking a string and returning an object with an
            ``ents`` iterable whose items expose ``label_`` (spaCy-like).
        :returns: ``{"ents": [...]}``, a mix of ``{"text", "label_"}`` dicts for
            columns classified as PERSON/LOCATION and the model's own entities
            for other columns.
        """
        if extension == "csv":
            self.type = "csv"
        elif extension in ("tsv", "xls", "xlsx"):
            self.type = "tsv"

        if self.type == "csv":
            self.records = StructuredDocumentHandler.load_csv(self.text)
        else:
            self.records = StructuredDocumentHandler.load_tsv(self.text)

        entities = []
        header_idx = StructuredDocumentHandler._first_content_row(self.records)
        if header_idx is None:
            return {"ents": entities}

        # The first non-blank row is the header; only later rows with the same
        # number of fields contribute values.
        num_columns = len(self.records[header_idx])
        data_rows = [
            row for row in self.records[header_idx + 1:] if len(row) == num_columns
        ]
        for column_idx in range(num_columns):
            column_values = [row[column_idx] for row in data_rows]
            entities.extend(self._column_entities(nlp, column_values))
        return {"ents": entities}

    def _column_entities(self, nlp, column_values):
        """Classify one column and return its entities.

        The NLP model is run over at most the first 100 cells. If more than 40%
        of them produce PERSON (checked first) or LOCATION entities, every cell
        of the column that passes basic_filtering() is returned with that label.
        Otherwise the model's own entities for the sample are returned.
        """
        sample_size = min(len(column_values), self._SAMPLE_ROWS)
        min_count_required = self._MIN_ENTITY_RATIO * sample_size
        doc = nlp(f".{os.linesep} ".join(column_values[:sample_size]))
        # Classification is specific to person and location entities.
        persons_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
        locations_count = sum(1 for ent in doc.ents if ent.label_ == "LOCATION")
        if persons_count > min_count_required:
            label = "PERSON"
        elif locations_count > min_count_required:
            label = "LOCATION"
        else:
            return list(doc.ents)
        return [
            {"text": self.clean_entity(value), "label_": label}
            for value in column_values
            if self.basic_filtering(value)
        ]

    def basic_filtering(self, entity_text):
        """Return True if the cleaned *entity_text* looks like a plausible name.

        The cleaned text must be at least 3 characters long and contain only
        letters, whitespace and the characters . - ' " , ( ).
        """
        cleaned = self.clean_entity(entity_text)
        return self._EXCLUDED_CHARS.search(cleaned) is None and len(cleaned) >= 3

    def clean_entity(self, entity_text):
        """Strip leading and trailing non-ASCII-letter characters from *entity_text*."""
        entity_text = self._LEADING_NON_ALPHA.sub("", entity_text)
        return self._TRAILING_NON_ALPHA.sub("", entity_text)
