import traceback
import requests

from enum import Enum
from dataclasses import asdict, dataclass
from typing import List, Optional
from time import time
from ctypes import cdll, c_char_p, POINTER, cast, create_string_buffer, c_char, string_at

from dacite import from_dict
from dacite import data
from xmltodict import parse

from CvEEConfigHelper import (
    queryCS,
    getLoadedDLL,
    generic_cache,
    EntityDoesNotExist,
    loadRegValue,
    SOLR_TAGGER_PORT,
)

# Shared requests.Session used by update_entity() for its POSTs. It is created once at
# import time so successive calls can reuse the session's pooled connections.
# No timeout or other options are configured on it.
UPDATE_ENTITY_SESSION = requests.Session()


class EEEntityType(Enum):
    """Entity type identifiers, keyed by integer value.

    NOTE(review): not referenced in this section. EntityDetails.entityType is
    kept as a string and entities_mapping() compares it with str(), so
    presumably these values match the numeric entityType the CommServe
    returns. Confirm before relying on it. Also note that str() of a member
    gives "EEEntityType.<NAME>", not its number.
    """

    NONE = 0
    NER = 1
    RER = 2
    DE = 3
    ML_MODEL = 4


class ModelTrainingStatus(Enum):
    """Training lifecycle states of an ML model entity, keyed by integer value.

    NOTE(review): ClassifierDetails.trainingStatus is a plain string;
    presumably it carries one of these values as a number. Confirm against
    the writer of that field.
    """

    NONE = 0
    ENTITY_CREATED = 1
    TRAINING_STARTED = 2
    TRAINING_FAILED = 3
    TRAINING_COMPLETED = 4
    TRAINING_COMPLETED_WITH_ERROR = 5
    TRAINED_NOT_USABLE = 6


@dataclass
class ContentAnalyzer:
    """A content-analyzer reference as it appears in entity XML.

    Field names match the attributes of the ``<CAUsedInTraining>`` and
    ``<contentAnalyzerList>`` elements built in update_entity(). Every value
    is kept as a string, the same form in which xmltodict returns XML
    attributes.
    """

    caUrl: str = ""
    clientId: str = ""
    cloudId: str = ""
    cloudName: str = ""
    lastModelTrainTime: str = ""


@dataclass
class ContentAnalyzerDetails:
    """Contents of the ``<syncedContentAnalyzers>`` element.

    ``contentAnalyzerList`` defaults to None when absent. fix_datatypes()
    turns a single parsed ``contentAnalyzerList`` entry into a one-item list
    before dacite conversion.
    """

    contentAnalyzerList: Optional[List[ContentAnalyzer]] = None


@dataclass
class Error:
    """Error info of a classifier, serialised as ``<err errorCode=... errLogMessage=... />``."""

    errorCode: str = ""
    errLogMessage: str = ""


@dataclass
class ClassifierDetails:
    """Dataset/model/training metadata held in ``entityXML.classifierDetails``.

    Apart from ``syncedContentAnalyzers``, every field listed here is emitted
    as an attribute or child element by ML_ENTITY_UPDATE_TEMPLATE. All scalar
    values are kept as strings.
    """

    # When None, update_entity() leaves the template's <syncedContentAnalyzers/> untouched.
    syncedContentAnalyzers: Optional[ContentAnalyzerDetails] = None
    datasetStorageType: str = ""
    datasetType: str = ""
    trainDatasetURI: str = ""
    trainingStatus: str = ""
    validationSamplesUsed: str = ""
    trainingSamplesUsed: str = ""
    totalSamples: str = ""
    classifierAccuracy: str = ""
    modelStorageType: str = ""
    modelURI: str = ""
    modelGUID: str = ""
    CAUsedInTraining: Optional[ContentAnalyzer] = None
    # update_entity() back-fills this with an empty Error() when it is None.
    err: Optional[Error] = None


@dataclass
class EntityXML:
    """Contents of an entity's ``<entityXML>`` element."""

    entityKey: str = ""
    isSystemDefinedEntity: str = ""
    classifierDetails: Optional[ClassifierDetails] = None
    keywords: str = ""
    # Not written by ML_ENTITY_UPDATE_TEMPLATE; only populated when parsed.
    proximityRange: str = "300"


@dataclass
class EntityDetails:
    """One ``<entityDetails>`` record as returned by the CommServe.

    Values are strings. entities_mapping() converts ``entityId`` with int()
    to index entities.
    """

    entityId: str = ""
    entityKey: str = ""
    entityName: str = ""
    entityType: str = ""
    flags: str = ""
    regularExpression: str = ""
    entityXML: Optional[EntityXML] = None
    # isSelected and parentEntityId are not part of ML_ENTITY_UPDATE_TEMPLATE.
    isSelected: str = "1"
    parentEntityId: str = "0"


@dataclass
class EntityDetailsResp:
    """Body of ``<DM2ContentIndexing_EntityDetailsResp>``.

    get_entities() wraps a single ``entityDetails`` element in a list before
    building this with dacite.
    """

    entityDetails: List[EntityDetails]


# XML request body sent by update_entity() for an ML entity.
# It is filled with str.format(**asdict(entity_details)):
#   - top-level placeholders are EntityDetails fields;
#   - entityXML[...] indexes the nested dicts produced by asdict(), so
#     entityXML.classifierDetails and its err and CAUsedInTraining members must
#     all be non-None when formatting (indexing None raises TypeError).
# str.format inserts values verbatim, so any XML escaping must be applied to the
# values before formatting.
# EntityXML.proximityRange, EntityDetails.isSelected and EntityDetails.parentEntityId
# are not emitted.
# "<syncedContentAnalyzers/>" is a literal placeholder that update_entity()
# replaces textually when the entity has synced content analyzers.
ML_ENTITY_UPDATE_TEMPLATE = """<DM2ContentIndexing_EntityDetailsReq opType="2">
    <entity entityId="{entityId}" entityKey="{entityKey}" entityName="{entityName}" entityType="{entityType}" flags="{flags}" regularExpression="{regularExpression}">
        <entityXML entityKey="{entityXML[entityKey]}" isSystemDefinedEntity="{entityXML[isSystemDefinedEntity]}" keywords="{entityXML[keywords]}">
            <classifierDetails datasetStorageType="{entityXML[classifierDetails][datasetStorageType]}" datasetType="{entityXML[classifierDetails][datasetType]}" trainDatasetURI="{entityXML[classifierDetails][trainDatasetURI]}" trainingStatus="{entityXML[classifierDetails][trainingStatus]}" modelURI="{entityXML[classifierDetails][modelURI]}" validationSamplesUsed="{entityXML[classifierDetails][validationSamplesUsed]}" classifierAccuracy="{entityXML[classifierDetails][classifierAccuracy]}" modelStorageType="{entityXML[classifierDetails][modelStorageType]}" trainingSamplesUsed="{entityXML[classifierDetails][trainingSamplesUsed]}" totalSamples="{entityXML[classifierDetails][totalSamples]}" modelGUID="{entityXML[classifierDetails][modelGUID]}">
                <err errorCode="{entityXML[classifierDetails][err][errorCode]}" errLogMessage="{entityXML[classifierDetails][err][errLogMessage]}" />
                <CAUsedInTraining caUrl="{entityXML[classifierDetails][CAUsedInTraining][caUrl]}" clientId="{entityXML[classifierDetails][CAUsedInTraining][clientId]}" cloudId="{entityXML[classifierDetails][CAUsedInTraining][cloudId]}" cloudName="{entityXML[classifierDetails][CAUsedInTraining][cloudName]}" lastModelTrainTime="{entityXML[classifierDetails][CAUsedInTraining][lastModelTrainTime]}" />
                <syncedContentAnalyzers/>
            </classifierDetails>
        </entityXML>
    </entity>
</DM2ContentIndexing_EntityDetailsReq>"""


def get_entities(params=None):
    """Query the CommServe for all entities and convert them to EntityDetails.

    Processing steps:
      - parse the raw XML reply with xmltodict;
      - strip the ``@`` prefix from attribute keys (clean_dict);
      - normalise list-valued fields (fix_datatypes);
      - load the result into dataclasses with dacite.

    Args:
        params: optional dict of keyword arguments, forwarded as
            ``queryCS(1, **params)``.

    Returns:
        list[EntityDetails]: one entry per ``entityDetails`` element in the reply.

    Raises:
        Exception: with a "verify that Commserver is reachable" message when any
            of the following happens:
              - ``queryCS`` raises (the original error is chained as ``__cause__``);
              - the reply is empty;
              - the reply has no
                ``DM2ContentIndexing_EntityDetailsResp/entityDetails`` element.
        Errors from XML parsing or dacite conversion propagate unchanged.
    """
    exception_msg = (
        "Failed to get the entities from database. Please verify that Commserver is reachable."
    )
    try:
        entities_xml = queryCS(1, **(params if params is not None else {}))
    except Exception as e:
        raise Exception(exception_msg) from e
    if entities_xml is None or entities_xml == "":
        raise Exception(exception_msg)

    entities_dict = fix_datatypes(clean_dict(parse(entities_xml)))
    resp = entities_dict.get("DM2ContentIndexing_EntityDetailsResp")
    # An empty root element parses to None and a text-only one to a str, so check
    # the type before looking for the child key (a str would do a substring test).
    if not isinstance(resp, dict) or "entityDetails" not in resp:
        raise Exception(exception_msg)

    # xmltodict yields a single dict rather than a list when only one
    # <entityDetails> element is present; dacite expects List[EntityDetails].
    if not isinstance(resp["entityDetails"], list):
        resp["entityDetails"] = [resp["entityDetails"]]
    return from_dict(data_class=EntityDetailsResp, data=resp).entityDetails


@generic_cache(parent_cache_key="all_entities")
def get_entities_wrapper(cache_token=None):
    """Return all entities through get_entities(), wrapped by ``generic_cache``.

    ``cache_token`` is not read in this body. Presumably generic_cache uses it
    to key or invalidate the "all_entities" cache; confirm in CvEEConfigHelper.
    Any exception from get_entities() propagates unchanged.
    """
    query_params = {"retry": True, "ee_cache_token": str(int(time()))}
    return get_entities(params=query_params)


@generic_cache(parent_cache_key="entity")
def get_entity_wrapper(entityId, cache_token=None):
    """Return the entity whose id equals ``int(entityId)``, using the cached entity list.

    Args:
        entityId: entity id; anything ``int()`` accepts.
        cache_token: forwarded to get_entities_wrapper(). Presumably also
            consumed by ``generic_cache``; confirm in CvEEConfigHelper.

    Raises:
        EntityDoesNotExist: if no entity has that id.
        TypeError/ValueError: if ``entityId`` cannot be converted with ``int()``.
    """
    by_id = entities_mapping(get_entities_wrapper(cache_token=cache_token))
    wanted = int(entityId)
    if wanted not in by_id:
        raise EntityDoesNotExist(f"Entity [{wanted}] does not exists.")
    return by_id[wanted]


def get_entity(entityId=None):
    """Fetch the entity list directly via get_entities() and return the entity ``int(entityId)``.

    Unlike get_entity_wrapper(), this bypasses the generic_cache-decorated wrappers.

    Args:
        entityId: entity id; anything ``int()`` accepts. The default ``None``
            is not accepted by ``int()`` and raises TypeError.

    Returns:
        EntityDetails: the matching entity.

    Raises:
        EntityDoesNotExist: if no entity has that id.
        Exception: propagated from get_entities() when the CommServe query fails.
    """
    query_params = {"retry": True, "ee_cache_token": str(int(time()))}
    by_id = entities_mapping(get_entities(params=query_params))
    wanted = int(entityId)
    match = by_id.get(wanted)
    if match is None:
        raise EntityDoesNotExist(f"Entity [{wanted}] does not exists.")
    return match


def entities_mapping(entityDetails, entityType=None):
    """Index entities by ``int(entity.entityId)``, optionally filtered by type.

    Args:
        entityDetails: iterable of objects exposing ``entityId`` and ``entityType``.
        entityType: when not None, keep only entities whose
            ``str(entity.entityType)`` equals ``str(entityType)``. Pass the
            raw value (for example ``4`` or ``"4"``). An Enum member
            stringifies as ``"Class.NAME"`` and would match nothing.

    Returns:
        dict[int, object]: id -> entity. On duplicate ids the later entity wins.
    """
    wanted_type = None if entityType is None else str(entityType)
    return {
        int(entity.entityId): entity
        for entity in entityDetails
        if wanted_type is None or str(entity.entityType) == wanted_type
    }


# Replacements applied, in addition to &, < and >, when escaping values placed inside
# double-quoted XML attributes. Whitespace control characters are written as
# character references because XML attribute-value normalisation would otherwise
# turn them into plain spaces on the receiving side.
_XML_ATTR_ENTITIES = {'"': "&quot;", "\n": "&#10;", "\r": "&#13;", "\t": "&#9;"}


def _escape_attr_values(value):
    """Return a copy of ``value`` whose scalar leaves are XML-attribute-safe strings.

    Dicts and lists (as produced by ``dataclasses.asdict``) are copied
    recursively. Every other value is rendered with ``str()`` and then escaped.
    """
    from xml.sax.saxutils import escape

    if isinstance(value, dict):
        return {key: _escape_attr_values(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_escape_attr_values(item) for item in value]
    return escape(str(value), _XML_ATTR_ENTITIES)


def _synced_content_analyzers_xml(synced):
    """Build the ``<syncedContentAnalyzers>`` element for a ContentAnalyzerDetails.

    An absent (None) or empty ``contentAnalyzerList`` is written as a single
    empty ``<contentAnalyzerList />`` child.
    """
    template = (
        '<contentAnalyzerList caUrl="{caUrl}" clientId="{clientId}" cloudId="{cloudId}" '
        'cloudName="{cloudName}" lastModelTrainTime="{lastModelTrainTime}" />'
    )
    analyzers = synced.contentAnalyzerList or []
    if analyzers:
        children = "".join(
            template.format(**_escape_attr_values(asdict(analyzer))) for analyzer in analyzers
        )
    else:
        children = "<contentAnalyzerList />"
    return f"<syncedContentAnalyzers>{children}</syncedContentAnalyzers>"


def update_entity(entityDetails):
    """Send an ML entity's state to the local content-preview UpdateEntity endpoint.

    ``entityDetails`` is serialised with ML_ENTITY_UPDATE_TEMPLATE, with every
    attribute value XML-escaped. The result is POSTed as the form field
    ``EntityDetailsReq``.

    Side effect: if ``entityDetails.entityXML.classifierDetails.err`` is None,
    it is replaced with an empty ``Error()`` on the caller's object. This is
    kept for compatibility. A missing ``CAUsedInTraining`` is rendered with
    empty attributes without touching the caller's object.

    Args:
        entityDetails: EntityDetails whose ``entityXML.classifierDetails`` is set.

    Returns:
        tuple[int, str]: ``(errorCode, errLogMessage)`` taken from the ``err``
        element of the ``DM2ContentIndexing_EntityDetailsResp`` reply.

    Raises:
        Exception: if the endpoint answers with a non-200 HTTP status.
        AttributeError: if ``entityXML`` or ``classifierDetails`` is None.
        KeyError/TypeError: if the reply lacks the expected ``err`` attributes.
    """
    classifier = entityDetails.entityXML.classifierDetails
    if classifier.err is None:
        classifier.err = Error()

    request_fields = asdict(entityDetails)
    if classifier.CAUsedInTraining is None:
        # The template always indexes CAUsedInTraining; formatting None would raise TypeError.
        request_fields["entityXML"]["classifierDetails"]["CAUsedInTraining"] = asdict(
            ContentAnalyzer()
        )
    # Escape values so quotes/<>/& in e.g. regularExpression or errLogMessage
    # cannot break the XML.
    entity_update_req = ML_ENTITY_UPDATE_TEMPLATE.format(**_escape_attr_values(request_fields))

    if classifier.syncedContentAnalyzers is not None:
        # Escaped values contain no '<', so this only matches the template's placeholder.
        entity_update_req = entity_update_req.replace(
            "<syncedContentAnalyzers/>",
            _synced_content_analyzers_xml(classifier.syncedContentAnalyzers),
        )

    content_preview_port = loadRegValue(SOLR_TAGGER_PORT, 22000)
    url = f"http://localhost:{content_preview_port}/CvContentPreviewGenApp/rest/messagequeue/UpdateEntity"
    # Second positional argument of Session.post is ``data``: sent as a form-encoded body.
    resp = UPDATE_ENTITY_SESSION.post(url, {"EntityDetailsReq": entity_update_req})
    if resp.status_code != requests.codes.OK:
        raise Exception(f"Failed to update Entity. HTTP status [{resp.status_code}].")

    result_dict = parse(resp.content.decode("utf-8"))
    err = result_dict["DM2ContentIndexing_EntityDetailsResp"]["err"]
    return int(err["@errorCode"]), err["@errLogMessage"]


def fix_datatypes(dict_):
    """Normalise parsed XML in place so ``contentAnalyzerList`` is always a list.

    xmltodict returns a single dict (not a list) when a repeatable element
    occurs once, and None for an empty element such as
    ``<contentAnalyzerList />``. ContentAnalyzerDetails declares
    ``contentAnalyzerList`` as a list, so every value under that exact key is
    coerced: None becomes ``[]`` and a lone value becomes ``[value]``. All
    other dict values, and dict items inside lists, are walked recursively.
    Non-dict list items (text or None) are left as they are.

    Args:
        dict_: mapping produced by ``xmltodict.parse`` (typically after clean_dict).

    Returns:
        The same mapping, mutated in place.
    """
    for key, value in list(dict_.items()):
        # Exact match: the old `key in "contentAnalyzerList"` was a substring test.
        if key == "contentAnalyzerList":
            if value is None:
                dict_[key] = []
            elif not isinstance(value, list):
                dict_[key] = [value]
        elif isinstance(value, dict):
            dict_[key] = fix_datatypes(value)
        elif isinstance(value, list):
            for i, item in enumerate(value):
                if isinstance(item, dict):
                    value[i] = fix_datatypes(item)
    return dict_


def clean_dict(dict_):
    """Strip xmltodict's ``@`` attribute prefix from keys, recursively and in place.

    Nested dicts, and dict items inside lists, are cleaned too. Non-dict list
    items, such as text of repeated elements or None for empty ones, are left
    unchanged instead of crashing. A renamed key is re-inserted at the end of
    the mapping.

    NOTE(review): if an element has both an attribute ``@x`` and a child ``x``,
    the two collide after renaming. This behaviour is unchanged from before.

    Args:
        dict_: mapping produced by ``xmltodict.parse``.

    Returns:
        The same mapping, mutated in place.
    """
    for key, value in list(dict_.items()):
        if key.startswith("@"):
            dict_[key[1:]] = value
            del dict_[key]
            key = key[1:]
        if isinstance(value, dict):
            dict_[key] = clean_dict(value)
        elif isinstance(value, list):
            for i, item in enumerate(value):
                if isinstance(item, dict):
                    value[i] = clean_dict(item)
    return dict_
