import json
import logging

import requests
from spacy.tokens import Span, Token

from CvEEConfigHelper import loadRegValue, SOLR_TAGGER_PORT, SOLR_TAGGER_URL
from CvEEUtil import util_get_dict_attr

# Maximum number of requests SolrTaggerComponent.get_tagger_matches sends for a
# single document before it stops retrying.
NUM_ATTEMPTS = 5
# Tagger port obtained through loadRegValue; 22000 is passed as the second
# argument (presumably the fallback when no value is configured - confirm).
solrTaggerPort = loadRegValue(SOLR_TAGGER_PORT, 22000)
# Tagger endpoint: the SOLR_TAGGER_URL template with the port substituted in.
SOLR_TAGGER_URL_WITH_PORT = SOLR_TAGGER_URL.format(solrTaggerPort)


class SolrTaggerComponent(object):
    """spaCy pipeline component that adds gazetteer entities found by the Solr tagger.

    Calling the component on a ``Doc`` posts ``doc.text`` to the tagger endpoint,
    turns the returned PERSON and GEO matches into spans and merges them with the
    entities already present in ``doc.ents``. Any failure is logged and the
    document is returned without modification of ``doc.ents``.
    """

    # requests session shared by all instances of the component
    tagger_session = requests.Session()

    def __init__(self, nlp, params=None):
        """Store the labels used by the component and register its extensions.

        Args:
            nlp: the spaCy language object; its vocab is used to resolve the
                PERSON label id.
            params: optional dict of settings. Only the ``logger`` entry is read
                here; when it is missing a module-level ``logging`` logger is
                used so that log calls never fail on ``None``.
        """
        # ``None`` replaces the former mutable ``{}`` default, which was shared
        # between calls.
        params = {} if params is None else params
        logger = util_get_dict_attr(params, "logger", None)
        if logger is None:
            logger = logging.getLogger(__name__)
        self.LOGGER_Generic = logger
        self.nlp = nlp
        self.PERSON_LABEL = "PERSON"
        self.GPE_DEP_LABEL = "GPE_DE"
        self.GPE_IN_LABEL = "GPE_IN"
        self.GPE_LABEL = "LOCATION"
        self.PERSON_LABEL_LONG = nlp.vocab.strings[self.PERSON_LABEL]
        # set_extension raises when the name is already registered, so guard the
        # registration to allow the component to be instantiated more than once.
        if not Span.has_extension("is_person_with_surname"):
            Span.set_extension("is_person_with_surname", default=False)
        if not Token.has_extension("is_location_gz"):
            Token.set_extension("is_location_gz", default=False)

    def __call__(self, doc):
        """Tag ``doc`` with the person and location matches of the Solr tagger.

        Args:
            doc: the spaCy ``Doc`` being processed.

        Returns:
            The same ``doc``. ``doc.ents`` is replaced only when the tagger call
            and the entity merge both succeed.
        """
        try:
            resp = SolrTaggerComponent.get_tagger_matches(doc.text)
            if isinstance(resp, Exception):
                self.LOGGER_Generic.debug(
                    "Got an exception with SolrTagger request. Exception {}.".format(resp)
                )
                return doc
            if resp.status_code != requests.codes.ok:
                self.LOGGER_Generic.debug(
                    "Unable to contact solr tagger at {}. Status Code {}.".format(
                        SOLR_TAGGER_URL_WITH_PORT, resp.status_code
                    )
                )
                return doc

            processed_response = SolrTaggerComponent.get_processed_response(resp)
            entities = []
            if processed_response["PERSON"]:
                entities.extend(self.get_person_entities(doc, processed_response["PERSON"]))
            # geo matches are only processed when at least one independent
            # location was returned
            if processed_response["GEO_IN"]:
                entities.extend(
                    self.get_geo_entities(
                        doc,
                        processed_response["GEO_IN"],
                        processed_response["GEO_DEP"],
                        processed_response["GEO_OTHERS"],
                    )
                )

            # There is a bug with the way spacy handles token merge.
            # issue is already reported here https://github.com/explosion/spaCy/issues/2550
            # as this is not resolved yet, handling overlapping manually
            final_entities = []
            covered_tokens = set()
            for entity in entities:
                # skip gazetteer entities that touch a token already labelled PERSON
                touches_person = any(
                    self._is_labelled_person(doc[index : index + 1])
                    for index in range(entity.start, entity.end)
                )
                if not touches_person:
                    final_entities.append(entity)
                    covered_tokens.update(range(entity.start, entity.end))

            # keep the existing entities that do not overlap an accepted gazetteer entity
            for entity in doc.ents:
                if covered_tokens.isdisjoint(range(entity.start, entity.end)):
                    final_entities.append(entity)

            doc.ents = final_entities
        except Exception as e:
            self.LOGGER_Generic.error(
                "Exception occurred in SolrTaggerComponent. Exception : {}".format(e)
            )
        return doc

    @staticmethod
    def _is_labelled_person(span):
        """Return True when ``span`` has an ``entity_label_`` list containing "PERSON"."""
        if not span.has_extension("entity_label_"):
            return False
        labels = span._.entity_label_
        return isinstance(labels, list) and "PERSON" in labels

    @staticmethod
    def _flag_is_set(span, name):
        """Return True when ``span`` exists, has extension ``name`` and it is truthy."""
        return span is not None and span.has_extension(name) and bool(getattr(span._, name))

    @staticmethod
    def _append_token_label(doc, span, label):
        """Append ``label`` to the ``entity_label_`` list of every token in ``span``."""
        for index in range(span.start, span.end):
            token = doc[index]
            if token._.entity_label_ is None:
                token._.entity_label_ = []
            token._.entity_label_.append(label)

    def get_person_entities(self, doc, matches):
        """Build PERSON spans from tagger matches.

        A match followed by a span flagged ``is_surname`` additionally yields a
        span extended by one token and marked ``is_person_with_surname``. Spans
        whose text is a single space-separated word are dropped; the tokens of
        the remaining spans get the "PERSON_GZ" label appended.

        Args:
            doc: the spaCy ``Doc`` being processed.
            matches: list of ``(start, end, text)`` tuples with an inclusive end
                character offset.

        Returns:
            List of accepted PERSON spans.
        """
        entities = self.get_entities(doc, matches, self.PERSON_LABEL)
        # merge found person names with surname entity
        # NOTE(review): the extended span is returned together with the shorter
        # span it was built from, so the two can overlap - confirm this is intended.
        with_surname = []
        for entity in entities:
            try:
                next_span = doc[entity.end : entity.end + 1]
                if next_span._.is_surname:
                    new_entity = Span(
                        doc, entity.start, entity.end + 1, label=self.PERSON_LABEL_LONG
                    )
                    new_entity._.is_person_with_surname = True
                    with_surname.append(new_entity)
            except Exception as e:
                self.LOGGER_Generic.info(
                    "Failed to create span for PERSON entity. Exception {}".format(e)
                )

        # drop any person entity which is of length less than 2
        filtered_entities = []
        for entity in list(entities) + with_surname:
            try:
                if len(entity.text.split(" ")) > 1:
                    # add the PERSON_GZ label to each token
                    self._append_token_label(doc, entity, self.PERSON_LABEL + "_GZ")
                    filtered_entities.append(entity)
            except Exception as e:
                self.LOGGER_Generic.info(
                    "Failed to create span for PERSON entity. Exception {}".format(e)
                )
        return filtered_entities

    def get_geo_entities(self, doc, independent_matches, dependent_matches, other_matches):
        """Build LOCATION spans from the three groups of geo matches.

        Independent matches are always relabelled as LOCATION. A dependent match
        is kept only when it is followed by an independent location or by a
        2-3 character upper-case token (like US or USA). Tokens covered by
        ``other_matches`` only receive the ``is_location_gz`` flag.

        TODO: in case of Ohio University we should avoid giving Ohio as location.
              This needs entity disambiguation logic.

        Args:
            doc: the spaCy ``Doc`` being processed.
            independent_matches: ``(start, end, text)`` tuples of independent locations.
            dependent_matches: ``(start, end, text)`` tuples of dependent locations.
            other_matches: ``(start, end, text)`` tuples used only for token flags.

        Returns:
            List of spans labelled LOCATION.
        """
        ind_entities = self.get_entities(doc, independent_matches, self.GPE_IN_LABEL)
        independent_starts = {entity.start for entity in ind_entities}
        dep_entities = self.get_entities(doc, dependent_matches, self.GPE_DEP_LABEL)
        # put location_gz label for all the tokens found in others matches
        self.set_location_attribute(doc, other_matches)

        filtered_entities = []
        for entity in ind_entities + dep_entities:
            try:
                if entity.label_ == self.GPE_DEP_LABEL and not self._has_location_context(
                    doc, entity, independent_starts
                ):
                    continue
                # recreate the span with the final LOCATION label
                relabelled = doc.char_span(
                    entity.start_char, entity.end_char, label=self.GPE_LABEL
                )
                if relabelled is not None:
                    filtered_entities.append(relabelled)
            except Exception as e:
                self.LOGGER_Generic.info(
                    "Failed to create span for GPE entity. Exception {}".format(e)
                )
        # add the LOCATION_GZ to each token. This will be used for entity disambiguation later.
        for entity in filtered_entities:
            self._append_token_label(doc, entity, self.GPE_LABEL + "_GZ")
        return filtered_entities

    @staticmethod
    def _has_location_context(doc, entity, independent_starts):
        """Tell whether a dependent geo entity is followed by supporting context.

        Starting after the entity, tokens that are punctuation, numbers or
        whitespace are skipped. The entity is supported when the token reached
        starts an independent location or is a 2-3 character upper-case token.
        """
        # NOTE(review): ``entity.end`` is already the index of the token after the
        # entity, so ``+ 1`` skips that token as well - kept as in the original
        # logic, confirm that this is intended.
        index = entity.end + 1
        while index < len(doc) and (
            doc[index].pos_ in ("PUNCT", "NUM") or doc[index].is_space
        ):
            index += 1
        if index in independent_starts:
            return True
        # nothing left in the document to inspect
        if index >= len(doc):
            return False
        token = doc[index]
        return token.is_upper and len(token) in (2, 3)

    def set_location_attribute(self, doc, other_matches):
        """ set location attribute to each token which will be used to expand address entity """
        for start, end, _ in other_matches:
            try:
                # match ends are inclusive, char_span expects an exclusive end
                temp_span = doc.char_span(start, end + 1, label=self.GPE_LABEL)
                if temp_span is None:
                    continue
                for index in range(temp_span.start, temp_span.end):
                    doc[index]._.is_location_gz = True
            except Exception as e:
                self.LOGGER_Generic.info(
                    "Failed to create span in set_location_attribute method. "
                    "Exception {}".format(e)
                )

    @staticmethod
    def get_tagger_matches(text):
        """Request matches for ``text`` from the Solr tagger.

        Each request uses a timeout of 2 minutes. The request is repeated until
        an OK response is received or ``NUM_ATTEMPTS`` attempts have been made.

        Args:
            text: the document text posted to the tagger.

        Returns:
            The response of the last attempt, or the exception raised by the
            last attempt when it failed. Callers must check for both.
        """
        # NOTE(review): ``text`` is sent as ``str``; how non-ASCII characters are
        # encoded depends on the HTTP stack - confirm the offsets returned by the
        # tagger line up with ``doc.text`` for such input.
        headers = {"content-type": "text/plain", "cache-control": "no-cache"}
        # returned as-is when NUM_ATTEMPTS allows no attempt at all
        resp = RuntimeError("Solr tagger request was not attempted.")
        for _ in range(NUM_ATTEMPTS):
            try:
                resp = SolrTaggerComponent.tagger_session.request(
                    "POST", SOLR_TAGGER_URL_WITH_PORT, data=text, headers=headers, timeout=120
                )
            except Exception as exp:
                resp = exp
                continue
            if resp.status_code == requests.codes.ok:
                break
        return resp

    @staticmethod
    def get_processed_response(resp):
        """Convert a tagger response into lists of matches per entity type.

        The response body is a JSON object whose ``tags`` field is itself a
        JSON-encoded string holding a list of this shape::

            [
                {
                    "entityName": "person",
                    "tagCount": 1,
                    "tagList": [
                        {
                            "endOffset": 6,
                            "ids": [329648],
                            "matchText": "Johann",
                            "startOffset": 0
                        }
                    ]
                },
                {"entityName": "geo_in", "tagCount": 0, "tagList": []},
                {"entityName": "geo_dep", "tagCount": 0, "tagList": []}
            ]

        Args:
            resp: response object exposing the body as ``resp.text``.

        Returns:
            Dict keyed by the upper-cased entity name ("PERSON", "GEO_IN",
            "GEO_DEP", "GEO_OTHERS" are always present). Every value is a list of
            ``(startOffset, endOffset - 1, matchText)`` tuples, i.e. the end
            offset is converted to an inclusive one.
        """
        processed_response = {"PERSON": [], "GEO_IN": [], "GEO_DEP": [], "GEO_OTHERS": []}
        payload = json.loads(resp.text)
        for tag in json.loads(payload["tags"]):
            matches = [
                (each_tag["startOffset"], each_tag["endOffset"] - 1, each_tag["matchText"])
                for each_tag in tag.get("tagList", [])
            ]
            processed_response[tag["entityName"].upper()] = matches
        return processed_response

    def get_entities(self, doc, matches, label):
        """Create spans labelled ``label`` for the given matches.

        Matches whose text is entirely lower case are ignored. A match that
        starts exactly one character after a previously accepted entity is
        merged with it into a single span, which replaces the earlier entity.

        Args:
            doc: the spaCy ``Doc`` being processed.
            matches: list of ``(start, end, text)`` tuples with an inclusive end
                character offset.
            label: label assigned to the created spans.

        Returns:
            List of spans; matches that do not align with token boundaries are
            left out.
        """
        entities = []
        #  store the mapping of (end,start) of each entity for merging the nearby entities
        endings = dict()
        for start, end, text in matches:
            try:
                if text.islower():
                    continue
                # check if there is an entity ending at current start
                prev_end = start - 2
                if prev_end in endings:
                    prev_start = endings[prev_end]
                    prev_entity = doc.char_span(prev_start, prev_end + 1, label=label)
                    curr_entity = doc.char_span(start, end + 1, label=label)
                    # create a new entity span with prev_start and current end
                    entity = doc.char_span(prev_start, end + 1, label=label)

                    # during entity merge, check if the entity had a flag
                    # is_surname (assigned in previous pipeline); retain
                    # is_surname and is_person_with_surname in the new entity as
                    # this is critical in our filtering logic
                    if entity is not None and (
                        self._flag_is_set(prev_entity, "is_person_with_surname")
                        or self._flag_is_set(prev_entity, "is_surname")
                        or self._flag_is_set(curr_entity, "is_surname")
                    ):
                        entity._.is_person_with_surname = True

                    if entity is not None and prev_entity is not None:
                        entities.remove(prev_entity)
                    start = prev_start
                else:
                    entity = doc.char_span(start, end + 1, label=label)
                if entity is not None:
                    endings[end] = start
                    entities.append(entity)
            except Exception as e:
                self.LOGGER_Generic.info(
                    "Failed to create span for {} entity [{}]. Exception {}".format(label, text, e)
                )
        return entities
