# -*- coding: utf-8 -*-

# --------------------------------------------------------------------------
# Copyright Commvault Systems, Inc.
# See LICENSE.txt in the project root for
# license information.
# --------------------------------------------------------------------------

"""Does all the Operation for Amazon vm"""

import time
from AutomationUtils import machine
from VirtualServer.VSAUtils.VMHelper import HypervisorVM
from VirtualServer.VSAUtils import VirtualServerUtils
from AutomationUtils import logger

class AmazonVM(HypervisorVM):
    """
    This is the main class for all Amazon VM operations
    """

    # pylint: disable=too-many-instance-attributes
    # VM property mandates many attributes.
    def __init__(self, hvobj, vm_name):
        """
        Build the AWS VM object and fetch the basic properties of the instance

        Args:
            hvobj               (obj):  Hypervisor Object

            vm_name             (str):  Name of the VM
        """
        super(AmazonVM, self).__init__(hvobj, vm_name)
        self.host_machine = machine.Machine()
        self.server_name = hvobj.server_host_name
        self.vm_name = vm_name
        self._aws_access_key = self.hvobj._aws_access_key
        self._aws_secret_key = self.hvobj._aws_secret_key
        self.aws_region = self.hvobj.aws_region  # TODO: check if this can be removed

        # properties that stay unset until the instance has been queried
        for _unset in ('instance', 'guid', 'ip', 'guest_os', 'volumes', '_disk_list',
                       'disk_count', 'no_of_cpu', 'vpc', 'subnet', 'nic',
                       'ec2_instance_type', 'iam'):
            setattr(self, _unset, None)

        # containers are created one by one so that none of them is shared
        self.security_groups = list()
        self.disk_dict = dict()
        self.tags = dict()
        self.volume_tags = dict()
        self.memory = 0

        self._basic_props_initialized = False
        self.connection = self.hvobj.connection
        # default call only collects the basic properties of the instance
        self.update_vm_info()

        # TODO: don't need this right now
        self.availability_zone = None
        self.termination_protection = None

    def __eq__(self, other):
        """compares the source vm and restored vm"""
        import copy
        tags1 = copy.deepcopy(self.tags)
        tags2 = copy.deepcopy(other.tags)
        if not self.validate_name_tag or not other.validate_name_tag:
            tags1 = list(filter(lambda i: i['Key'].lower() != 'name', tags1))
            tags2 = list(filter(lambda i: i['Key'].lower() != 'name', tags2))
        tags1 = sorted(tags1, key=lambda i: i['Key'])
        tags2 = sorted(tags2, key=lambda i: i['Key'])
        self.log.info('tags1 :{}'.format(tags1))
        self.log.info('tags2 :{}'.format(tags2))
        self.log.info("Source VM:{0}::\nSecurity Group:{1}\nInstance type:{2}\niam:{3}".format(
            self.vm_name, self.security_groups, self.ec2_instance_type, self.iam
        ))
        self.log.info("Destination VM:{0}::\nSecurity Group:{1}\nInstance type:{2}\niam:{3}".format(
            other.vm_name, other.security_groups, other.ec2_instance_type, other.iam
        ))
        return (tags1 == tags2 and
                self.security_groups == other.security_groups and
                self.ec2_instance_type == other.ec2_instance_type and
                self.iam == other.iam)

    class LiveSyncVmValidation(object):
        """Validates a live sync replicated VM against its source VM"""

        def __init__(self, vmobj, schedule, replicationjob=None):
            """
            Stores the objects needed for the live sync validation

            Args:
                vmobj               (obj):  Wrapper whose ``vm`` attribute is the VM object

                schedule            (obj):  Schedule exposing virtualServerRstOptions

                replicationjob      (obj):  Replication job, only stored on the object
            """
            self.vm = vmobj
            self.schedule = schedule
            self.replicationjob = replicationjob
            self.log = logger.get_log()

        def __eq__(self, other):
            """
            Validates livesync replication

            CPU count, disk count and memory must match. For the schedule entry
            named after the source VM, the first configured subnet and the first
            configured security group must match the destination VM. Finally the
            instance types must match.

            Args:
                other               (obj):  Validation object of the destination VM

            Returns:
                bool    -   True if the destination VM matches, else False
            """
            source_vm = self.vm.vm
            dest_vm = other.vm.vm

            config_val = (int(source_vm.no_of_cpu) == int(dest_vm.no_of_cpu) and
                          int(source_vm.disk_count) == int(dest_vm.disk_count) and
                          int(source_vm.memory) == int(dest_vm.memory))
            if not config_val:
                return False

            # network and security group validation
            scheduleprops = self.schedule.virtualServerRstOptions
            schdetails = scheduleprops['diskLevelVMRestoreOption']['advancedRestoreOptions']
            for vmdetails in schdetails:
                if vmdetails['name'] != source_vm.vm_name:
                    continue

                # an empty list means nothing was configured, so nothing to validate
                nics = vmdetails.get('nics')
                if nics and nics[0]['subnetId'] != dest_vm.subnet:
                    return False

                groups = vmdetails.get('securityGroups')
                if groups:
                    dest_groups = dest_vm.security_groups
                    # a destination without any security group is a mismatch,
                    # not an IndexError
                    if not dest_groups or groups[0]['groupId'] != dest_groups[0]:
                        return False

            return source_vm.ec2_instance_type == dest_vm.ec2_instance_type

    def _set_credentials(self, os_name):
        """
        Set the credentials for VM by reading the config INI file.
        Overridden because root login is not possible in out of place restored AWS instance.

        The users configured for ``os_name`` are tried first. If none of them
        works, the users of the 'aws_linux' section are tried together with the
        configured key files. On success self.machine holds the connected
        Machine object and self.user_name / self.password hold the login that
        worked. On failure an error listing every rejected user name is logged
        and nothing is raised.

        Args:
            os_name             (str):  Name of the config section to read first; its value
                                        is parsed as comma separated "user:password" entries
        """
        # collected over both attempts so the final message names every failed user
        incorrect_usernames = []

        # first try root credentials
        sections = VirtualServerUtils.get_details_from_config_file(os_name.lower())
        for each_user in sections.split(","):
            fields = each_user.split(":")
            self.user_name = fields[0]
            self.password = VirtualServerUtils.decode_password(fields[1])
            try:
                vm_machine = machine.Machine(self.vm_hostname,
                                             username=self.user_name,
                                             password=self.password)
                if vm_machine:
                    self.machine = vm_machine
                    return
            except Exception:
                # Exception instead of a bare except, so that KeyboardInterrupt
                # and SystemExit are not swallowed
                incorrect_usernames.append(self.user_name)

        # if root user doesn't work (for Linux only), try ec2-user with key
        # NOTE(review): this fallback is attempted for every os_name - confirm
        # whether it should be skipped for Windows guests
        sections = VirtualServerUtils.get_details_from_config_file('aws_linux')
        keys = VirtualServerUtils.get_details_from_config_file('aws_linux', 'keys')
        key_list = keys.split(",")
        for each_user in sections.split(","):
            fields = each_user.split(":")
            self.user_name = fields[0]
            # unlike the first attempt, this password is used as it is in the config
            self.password = fields[1]
            try:
                run_as_sudo = self.user_name == "ec2-user"
                vm_machine = machine.Machine(self.vm_hostname, username=self.user_name,
                                             password=self.password, key_filename=key_list,
                                             run_as_sudo=run_as_sudo)
                if vm_machine:
                    self.machine = vm_machine
                    return
            except Exception:
                incorrect_usernames.append(self.user_name)

        # no exception is being handled at this point, so log a plain error
        self.log.error("Could not create Machine object for machine : '{0}'! "
                       "The following user names are incorrect: {1}"
                       .format(self.vm_hostname, incorrect_usernames))

    def clean_up(self):
        """
        Clean up the VM and ts reources.

        Raises:
            Exception:
                When cleanup failed or unexpected error code is returned

        """

        try:

            self.log.info("Powering off VMs after restore")
            self.power_off()

        except Exception as exp:
            self.log.exception("Exception in Cleanup")
            raise Exception("Exception in Cleanup:" + str(exp))

    def _get_vm_info(self):
        """
        Get the basic or all or specific properties of VM

        Args:
                prop                (str):  basic, All or specific property like Memory

                extra_args          (str):  Extra arguments needed for property listed by ","

        Raises:
            Exception:
                if failed to get all the properties of the VM

        """
        try:
            if ' ' in self.vm_name.strip():
                raise Exception("Amazon Instance name should not have any spaces")
            _resource = self.connection.resource('ec2')
            instances = _resource.instances.filter(Filters=[
                {'Name': 'tag:Name',
                 'Values': [self.vm_name]
                 },
            ],
                DryRun=False
            )
            count = 0
            for instance in instances:
                if instance.state['Code'] != 48:
                    count = count + 1
                    if count > 1:
                        self.log.Error('multiple instances with same name')
                        raise Exception('multiple instances with same name')
                    self.guid = instance.id
                    self.power_state = instance.state['Code']
                    self.ip = instance.private_ip_address
                    self.guest_os = instance.platform
                    if not self.guest_os:
                        self.guest_os = 'unix'
                    self.instance = instance
            if count == 0:
                self.log.error('No Instance found by this name : {0}'.format(self.vm_name))
                raise Exception('No Instance found by this name')
            self._basic_props_initialized = True
        except Exception as err:
            self.log.exception("Failed to Get basic info of the instance")
            raise Exception(err)

    def update_vm_info(self, prop='Basic', os_info=False, force_update=False):
        """
        Fetches all the properties of the VM

        Args:
            prop                (str):  Basic - Basic properties of VM like HostName,
                                                especially the properties with which
                                                VM can be added as dynamic content

                                        All   - All the possible properties of the VM

            os_info             (bool): To fetch os info or not

            force_update        (bool): to refresh all the properties always
                    True : Always collect  properties
                    False: refresh only if properties are not initialized

        Raises:
            Exception:
                if failed to update all the properties of the VM

        """
        try:
            if not self._basic_props_initialized or force_update:
                self._get_vm_info()
            if self.power_state == 48:
                self.log.error(self.vm_name, "has been terminated. Unable to create "
                                             "VMHelper object")
                return
            elif self.power_state in (0, 32, 64):
                time.sleep(120)
                self.power_on()
                self._get_vm_info()
            elif self.power_state == 80:
                self.power_on()
                self._get_vm_info()
            if os_info or prop == 'All':
                self.vm_guest_os = self.guest_os
                self.get_other_detail()
                self.get_drive_list()
                self.set_security_groups()
                self.set_volume_tags()
                self.get_memory()
                self.disk_count = len(self.disk_list)
        except Exception as err:
            self.log.exception("Failed to Get info of the instance")
            raise Exception(err)

    def power_on(self):
        """
        Power on the VM.

        Raises:
            Exception:
                When power on fails or unexpected error code is returned

        """

        try:
            if self.instance.state['Code'] != 48:
                if self.instance.state['Code'] != 16:
                    self.instance.start()
                    time.sleep(180)
            else:
                self.log.error("Power On failed. Instance has been TERMINATED already")
            if self.instance.state['Code'] != 16:
                raise Exception("Instance not powered On")
        except Exception as exp:
            self.log.exception("Exception in PowerOn")
            raise Exception("Exception in PowerOn:" + str(exp))

    def power_off(self):
        """
        Power off the VM.

        Raises:
            Exception:
                When power off fails or unexpected error code is returned

        """

        try:
            if self.instance.state['Code'] != 48:
                if self.instance.state['Code'] != 80:
                    self.instance.stop()
                    time.sleep(180)
            else:
                self.log.error("Power Off failed. Instance has been terminated already")
            if self.instance.state['Code'] != 80:
                raise Exception("Instance not powered off")
        except Exception as exp:
            self.log.exception("Exception in PowerOfff")
            raise Exception("Exception in PowerOff:" + str(exp))

    def delete_vm(self):
        """
        Terminates the ec2 instance.

        Raises:
            Exception:
                When deleting the instance fails or unexpected error code is returned

        """

        try:
            if self.instance.state['Code'] != 80:
                if self.instance.state['Code'] == 16:
                    self.power_off()
            self.instance.terminate()
            time.sleep(60)
            if self.instance.state['Code'] != 48:
                raise Exception("Instance not deleted")
        except Exception as exp:
            self.log.exception("Exception in Delete")
            raise Exception("Exception in Delete:" + str(exp))

    @property
    def disk_list(self):
        """
        To fetch the disk in the VM

        Returns:
            disk_list           (list): List of volumes in AWS instance

        """
        try:
            self.volumes = self.instance.volumes.all()
            self._disk_list = [v.id for v in self.volumes]
            if self._disk_list:
                return self._disk_list
            else:
                return []
        except Exception as exp:
            self.log.exception("Exception in getting disk list")
            raise Exception("Exception in getting disk list" + str(exp))

    def set_security_groups(self):
        """
        Sets the security groups associated with the AWS ec2 instance

        Raises:
            Exception:
                issues when unable to get the security groups
        """
        try:
            if not self.security_groups:
                for sgi in self.instance.security_groups:
                    self.security_groups.append(sgi['GroupId'])
        except Exception as err:
            self.log.exception("Failed to get security groups")
            raise Exception(err)

    def set_volume_tags(self):
        """
        Sets the tags associated with each volume for the instance and stores in a dict with
        each key as the volume id

        Raises:
            Exception:
                issues when unable to get the tags of a volume
        """
        try:
            for _vol in self.volumes:
                _resource = self.connection.resource('ec2')
                volume = _resource.Volume(_vol.id)
                _tag_dict = {}
                for _v in volume.tags:
                    _key = _v['Key'].strip()
                    if _key:
                        try:
                            _value = _v['Value'].strip()
                        except IndexError:
                            _value = ''
                    _tag_dict[_v['Key'].strip()] = _v['Value'].strip()
                self.volume_tags[_vol.id] = _tag_dict
        except Exception as err:
            self.log.exception("Failed to get volume tags")
            raise Exception(err)

    def get_other_detail(self):
        """
        Sets the tags associated with each volume for the instance and stores in a dict with
        each key as the volume id

        Raises:
            Exception:
                issues when unable to get the tags of a volume
        """
        try:
            self.volumes = self.instance.volumes.all()
            self.vpc = self.instance.vpc.id
            self.subnet = self.instance.subnet.id
            self.nic = self.instance.network_interfaces[0].id
            self.ec2_instance_type = self.instance.instance_type
            _sts = self.connection.client('sts')
            self.iam = _sts.get_caller_identity()["Account"]
            self.tags = self.instance.tags
            self.no_of_cpu = self.instance.cpu_options['CoreCount']
            for device in self.instance.block_device_mappings:
                self.disk_dict[device['Ebs']['VolumeId']] = device['DeviceName']
        except Exception as err:
            self.log.exception("Failed to get other detail")
            raise Exception(err)

    def get_memory(self):
        """
        Gets the memory of the vm
        Raises:
            Exception:
                issues when unable to get memory of the vm
        """
        try:
            if self.guest_os.lower() == 'windows':
                _output = self.machine.execute_command(
                    'get-ciminstance -class "cim_physicalmemory" | % {$_.Capacity}')
                self.memory = int(_output.formatted_output)/1024/1024/1024
            else:
                _output = self.machine.execute_command('cat /proc/meminfo | grep DirectMap')
                _sum = 0
                for _data in _output.formatted_output:
                    _sum += int(_data[1])
                self.memory = _sum/1024/1024
        except Exception as err:
            self.log.exception("Failed to fetch memory of the vm: {}".format(self.vm_name))
            raise Exception(err)
