# Python imports
import pprint
from subprocess import Popen, PIPE
from Queue import Queue
from threading import Thread
import sys

# Project imports
import cvmanager_catalog
import cvmanager_task_step
import cvmanager_task
import cvmanager_task_arg
import cvmanager_utils
import cvmanager_defines
import wrappers
import common
from HsObject import hs_node
from HsObject import hs_defines
from HsObject import hs_utils
from HsObject import hs_ssh

# Module-level logger; assigned the task's logger by Task.check_supported().
LOG = None
# The module registered in sys.modules under cvmanager_defines.DYNAMIC_DEFINE_NAME.
DYNAMIC_DEFINES = sys.modules[cvmanager_defines.DYNAMIC_DEFINE_NAME]
# Passed as python_path= to the hs_node / hs_ssh helpers (presumably an alternate interpreter path; verify).
# This is for debugging purposes only; use with caution.  Leave as None for normal operation.
python_path = None


class Task(cvmanager_task.TaskObject):
    """This task can be executed on an existing node of the cluster, used for adding new remote nodes into
    this same existing cluster.  It will look for any nodes broadcasting their availability, get their configuration,
    and compare it with the current cluster configuration to decide which of them can be added.

    This task does not require any input.  By default it specifies DO_NOT_SAVE option which does not save the status
    of the running task.  Therefore, every time the task is executed, it starts fresh.

    The validated nodes are written to the catalog files Available_Nodes.json and Available_Nodes_Workflow.json.
    """
    # Task input metadata
    task_args = {
        'do_not_save': cvmanager_task_arg.CommonArg.DO_NOT_SAVE,

        'interactive': cvmanager_task_arg.TaskArg('interactive', bool, required=False, default_value=True,
                                                  valid_values=[True, False],
                                                  description="Interactive mode.  Does not require any inputs."),

        'use_avahi': cvmanager_task_arg.CommonArg.USE_AVAHI,

        'node_filter': cvmanager_task_arg.TaskArg('node_filter', list, required=False,
                                                  description="Filter the avahi browse to only these specific list of "
                                                              "nodes. They will always be browsed regardless of avahi "
                                                              "status."),
        'discover_hardware': cvmanager_task_arg.TaskArg('discover_hardware', bool, required=False, default_value=True,
                                                        valid_values=[True, False], description="If True, the avahi "
                                                        "discovered nodes will be contacted and queried for hardware "
                                                        "information.  This could take a long time if you have many"
                                                                                                " avahi nodes"),
    }

    def set_process(self, process):
        """
        This is required to be defined in each and every Task().  It's the definition for the order to call the steps.
        It has 3 different processes available; pre_process, main_process, and post_process.
        All the process types are treated exactly the same during runtime...it's merely logical grouping to show
        different phases to the user.  They are always run in the order, pre -> main -> post.

        :param process: cvmanager_task_process.TaskProcess.  Not required, parent task sets it.
        :return: None
        """
        process.pre_process = [
            self.check_supported,
            self.set_network_routes
        ]

        process.main_process = [
            self.get_local_node_config,
            self.discover_remote_nodes,
            self.validate_remote_nodes
        ]

        process.post_process = [
            self.format_output_for_workflow
        ]

    @staticmethod
    def _is_true_flag(value):
        """
        Interpret a configuration flag that may be a real boolean or a 'true'/'false' string.

        A string is only considered True when it reads 'true' (case-insensitive), so the string 'false' is no
        longer mistaken for True just because it is a non-empty string.
        """
        if isinstance(value, bool):
            return value
        if hasattr(value, 'lower'):
            # str / unicode
            return value.strip().lower() == 'true'
        return bool(value)

    @cvmanager_task_step.OptionStep.run_always
    def set_network_routes(self, *args, **kwargs):
        """
        In order for the avahi-browse to find newly imaged nodes from already configured nodes, we need to set a
        network route from the default_gw of this existing, configured node, to the avahi network.

        1) Using [cvnwconfigmgr.py get], we get this local node network configuration.
        2) Iterate through all local interfaces to find the first one which has the property:
            is_defaultgw: true
           Any bridges that list that device as a subinterface also receive the routes.
        3) add the routes to avahi (169.254.0.0/16) and a default route, through these devices, with metric 99.

        :param args: unused.
        :param kwargs: unused.
        :return: bool - True when every route was added (or already existed), False otherwise.
        """
        # Get the local node network config
        local_node = hs_node.HyperScaleNode()
        local_node_network_json = local_node.get_node_network_json() or {}

        # Iterate through the local interfaces to find the default_gw
        devices_to_add_route = []
        for interface in local_node_network_json.get('interfaces', []):
            if not self._is_true_flag(interface.get('is_defaultgw', False)):
                continue

            # This interface is the default gateway.
            gateway_device = interface.get("device", None)
            if not gateway_device:
                self.log.error("Default gateway interface does not specify a device name.")
                return False
            devices_to_add_route.append(gateway_device)

            # Check if this device is part of any bridge subinterfaces.  If so, need to add route for these bridges.
            # Bridges without a name are skipped; they would otherwise produce a broken 'dev ' route command.
            bridges = [bridge_info.get('bridge', '') for bridge_info in local_node_network_json.get('bridges', [])
                       if gateway_device in bridge_info.get('subinterfaces', [])]
            bridges = [bridge for bridge in bridges if bridge]

            if bridges:
                self.log.debug("Including bridges [{0}] in device routes.".format(",".join(bridges)))
                devices_to_add_route.extend(bridges)

            break
        else:
            self.log.error("Unable to find the default gateway device name.")
            return False

        # Add the routes.
        routes = [
            'route add -net 169.254.0.0 netmask 255.255.0.0 dev {0} metric 99',
            'route add default dev {0} metric 99'
        ]
        for route_command in routes:
            for device in devices_to_add_route:
                command = route_command.format(device)
                process = Popen(command, stdout=PIPE, stderr=PIPE, shell=True)
                output, error = process.communicate()

                if process.returncode == 7 and 'File exists' in error:
                    # This route is already there.
                    continue

                if process.returncode != 0:
                    self.log.error("Failed adding route: {0}\n{1}\n{2}".format(command, output, error))
                    return False

        return True

    @cvmanager_task_step.OptionStep.run_always
    def check_supported(self, *args, **kwargs):
        """
        Verify that this node is a HyperScale X node; discovering nodes to add is only supported there.

        Side effects: publishes the task logger as the module-level LOG and caches the local node object as
        self.LOCAL_NODE for the later steps.

        :return: bool - True if supported, False otherwise.
        """
        global LOG

        LOG = self.log

        local_node = hs_node.HyperScaleNode(wrappers.get_hostname())
        self.LOCAL_NODE = local_node

        if not local_node.node_type == hs_defines.NodeTypes.HEDVIG:
            self.log.error("Adding new nodes only supported on HyperScale X based architectures.")
            return False

        return True

    @cvmanager_task_step.TaskStep
    def get_local_node_config(self, *args, **kwargs):
        """
        Collect this node's drive information and store it as self.LOCAL_NODE_INFO.

        The stored value is {'BLOCK_DEVICE': <drive info parsed from the cvhyperscale block-device XML>}, which is
        later used to check the drive count and drive sizes of the remote nodes.  Relies on self.LOCAL_NODE having
        been set by check_supported().

        :return: bool - True.
        """
        local_node = self.LOCAL_NODE

        # Get the local node hardware configuration, using cvhyperscale for the block devices
        local_node_xml = local_node.get_node_block_device_xml(python_path=python_path)

        # Read the relevant information from the config file and save it.
        self.LOCAL_NODE_INFO = {
            'BLOCK_DEVICE': hs_utils.get_block_device_info_from_xml(local_node_xml)
        }
        return True

    @cvmanager_task_step.TaskStep
    def discover_remote_nodes(self, *args, **kwargs):
        """
        Use avahi to find broadcasting remote nodes, pick the ones to process and store them as self.REMOTE_NODES.

        Without a node_filter, nodes whose avahi status is 'HyperScaleConfigured' are skipped; with a node_filter,
        only the named nodes (matched by avahi section name or host) are kept, whatever their status.  In interactive
        mode the user chooses among the discovered nodes (or types IPs/hostnames); otherwise all discovered nodes are
        used.  With discover_hardware, the hardware of each selected node is collected over SSH.

        :return: bool - False if the avahi browse failed, True otherwise.
        """
        # Run the avahi browse to locate broadcasting nodes. Only start if NOT using avahi at the task level.
        (ret, available_node_info) = hs_utils.browse_avahi_nodes(start_service=not self.kwargs.get('use_avahi', False))
        if not ret == 0:
            self.log.error("Failed to browse avahi nodes, please make sure avahi-daemon is running on this machine.")
            return False

        node_filter = self.kwargs.get('node_filter') or []
        discover_hardware = self.kwargs.get('discover_hardware', False)

        # Enumerate through all node detected, and only get ones where status is NOT HyperScaleConfigured.
        # Each section in the avahi output is a detected node; the section name is the node key, e.g.
        #   [ym6d010364]
        #   ip = 172.24.33.97
        #   host = hs3300devb0103
        #   status = HyperScaleConfigured
        discovered_nodes = {}
        avahi_node_info = available_node_info.read_config_file()
        for node in avahi_node_info.sections():
            if node_filter:
                # This is to filter the results; use the filter only as the requirement.
                if avahi_node_info.get(node, 'host') not in node_filter and node not in node_filter:
                    continue
            elif avahi_node_info.get(node, 'status') == 'HyperScaleConfigured':
                continue

            # Turn the discovered avahi nodes into JSON.
            discovered_nodes[node] = cvmanager_utils.get_json_from_config_section(avahi_node_info, node)

        # Test connection to each node, using SSH, and trying to see if it's 2.0 or not.  Only do this if we aren't
        # performing hardware discovery; otherwise the deeper checks below cover it anyway.
        if not discover_hardware:
            discovered_nodes = self.test_connection_to_avahi_nodes(discovered_nodes)

        if self.kwargs.get('interactive', False):
            # Interactive mode, ask user to select the nodes to discover and validate.
            selected_nodes = self._select_nodes_interactively(discovered_nodes)
        else:
            # Non-interactive mode, traverse all discovered nodes.
            selected_nodes = list(discovered_nodes.keys())

        nodes = {}
        if not discover_hardware:
            # The command center entry point will perform 2 discovery; 1 with hardware, 1 without.  This is controlled
            # via the discover_hardware input.  If its false, we just send back the avahi discovered nodes.
            self.log.info("Only performed avahi browse and not collecting remote node hardware information.")
            for node in selected_nodes:
                nodes[node] = {
                    'BLOCK_DEVICE': None,
                    'AVAHI': discovered_nodes.get(node),
                    'ARCH': None
                }
        else:
            # Do the deeper dive to get remote node hardware information; this is parallel execution.
            self.log.info("Reading hardware configuration from [{0}] remote nodes..."
                          "this may take some time, please wait...".format(len(selected_nodes)))

            nodes = self.get_remote_node_hardware_information(discovered_nodes, selected_nodes)

        self.REMOTE_NODES = nodes
        return True

    def _select_nodes_interactively(self, discovered_nodes):
        """
        Prompt until every selected node is either avahi-discovered or a reachable IP address / hostname.

        Entries that were not discovered by avahi must all be valid IP addresses or hostnames (e.g. a node that got
        its address via DHCP) and must all pass the SSH connection test; they are then merged into discovered_nodes
        (in place, keyed by the entered value) so every returned node has an entry there.

        :param discovered_nodes: dict - node key -> avahi JSON.  Updated in place.
        :return: list - the node keys chosen by the user.
        """
        display_nodes = True
        while True:
            selected_nodes = self.get_nodes_from_user(discovered_nodes, display_nodes)
            undiscovered = [node for node in selected_nodes if node not in discovered_nodes]
            if not undiscovered:
                return selected_nodes

            # User entered a node which was not discovered.  Check to see if its an IP or hostname, if it is
            # then we can assume DHCP was used, and we can use that to configure.  Check connection.
            if all(hs_utils.check_ip(node) or common.is_valid_hostname(node) for node in undiscovered):
                reachable_nodes = self.test_connection_to_avahi_nodes(undiscovered)
                if all(node in reachable_nodes for node in undiscovered):
                    # The IP or hostname is valid, can continue with deep dive.
                    discovered_nodes.update(reachable_nodes)
                    return selected_nodes

            display_nodes = False
            print("You've entered a node which has not been discovered or is not a valid IP address or FQDN. "
                  "Please try again.")
            raw_input("Press Enter to continue...")

    @cvmanager_task_step.TaskStep
    def validate_remote_nodes(self, *args, **kwargs):
        """
        Keep only the remote nodes that are compatible with this cluster; store them as self.VALIDATED_NODES.

        When discover_hardware is set, a node must match the local drive count, meet the drive size requirement and
        have all data drives mounted.  Without hardware discovery every remote node is accepted as-is.

        :return: bool - False if no node passed validation, True otherwise.
        """
        check_hardware = self.kwargs.get('discover_hardware', False)

        valid_nodes = {}
        # Loop through all the collected nodes, and validate if they can be added to this cluster.
        for node, node_info in getattr(self, 'REMOTE_NODES', {}).items():
            if check_hardware:
                if not self.validate_drive_count(node_info):
                    self.log.warning("Node [{0}] does not satisfy disk drive count requirements.".format(node))
                    continue

                if not self.validate_drive_size(node_info):
                    self.log.warning("Node [{0}] does not meet minimum drive sizing for this cluster.".format(node))
                    continue

                if not self.validate_drive_mounted(node_info):
                    self.log.warning("Node [{0}] does not have all required data drives mounted, "
                                     "please correct this and ensure all data drives are mounted prior to discovery.".
                                     format(node))
                    continue

            # Node was validated successfully.
            valid_nodes[node] = node_info

        if not valid_nodes:
            # We did not find any valid nodes to configure.
            self.log.error("Did not discover any valid nodes for adding into this cluster.")
            return False

        self.log.info("Discovered [{0}] available node(s) for adding into this cluster.".format(len(valid_nodes)))

        self.VALIDATED_NODES = valid_nodes

        return True

    @cvmanager_task_step.TaskStep
    def format_output_for_workflow(self, *args, **kwargs):
        """
        Write the validated node information into the catalog, so that workflow and future tasks can read from it.

        Available_Nodes.json receives the full node information; Available_Nodes_Workflow.json receives a list of
        {"serial_number": <node key>} entries.

        :return: bool - False if there are no validated nodes, True otherwise.
        """
        task_catalog = cvmanager_catalog.Catalog()
        available_node_info = task_catalog.get_file("Available_Nodes.json", True)
        available_node_info_serial_nums = task_catalog.get_file("Available_Nodes_Workflow.json", True)
        self.log.info("Saving available nodes to: {0}".format(available_node_info.file_path))

        # Do any formatting needed for the validated information about available nodes.
        node_info = getattr(self, 'VALIDATED_NODES', {})
        if not node_info:
            # Something went wrong.  Abort.
            self.log.error("Unable to find any validated remote nodes for addition to this cluster.")
            return False

        # Save the validated information into the catalog.
        available_node_info.write(node_info, cvmanager_catalog.FileType.JSON)

        workflow_data = [{"serial_number": node_sn} for node_sn in node_info]
        available_node_info_serial_nums.write(workflow_data, cvmanager_catalog.FileType.JSON)

        return True

    def validate_drive_mounted(self, remote_node_info):
        """
        All data drives on the remote node should be mounted, if not, we should not add this node.

        A drive without MOUNT_POINT information is treated as not mounted instead of raising.

        :param remote_node_info: dict - node info with 'BLOCK_DEVICE' and 'AVAHI' (whose 'ip' is used to connect).
        :return: bool - True if every block device reports as mounted.
        """
        remote_block_device = remote_node_info.get('BLOCK_DEVICE') or {}
        if not remote_block_device:
            return True

        # One remote node object is enough for all the per-device checks.
        remote_ip = (remote_node_info.get('AVAHI') or {}).get('ip', None)
        remote_node = hs_node.HyperScaleNode(remote_ip)

        # Ensure that all block devices are mounted properly.
        for device, device_info in remote_block_device.items():
            mount_point_info = device_info.get("MOUNT_POINT", None)
            if not mount_point_info:
                self.log.debug("No mount point information for device [{0}].".format(device))
                return False

            if not remote_node.check_if_device_mounted(mount_point_info.get('blkdev', None),
                                                       mount_point_info.get('mntpath', None),
                                                       use_default_image_cred=True):
                return False

        return True

    def validate_drive_count(self, remote_node_info):
        """
        Ensure the remote node # of drives is equal to the current node # of drives.  We cannot add a node to the
        cluster which has a different number of drives.

        :param remote_node_info: dict - node info with a 'BLOCK_DEVICE' mapping.
        :return: bool - True if the drive counts match.
        """
        local_block_device = self.LOCAL_NODE_INFO['BLOCK_DEVICE'] or {}
        remote_block_device = remote_node_info.get('BLOCK_DEVICE') or {}

        # Ensure local node & remote node have same number of sas data drives (cvdrive=data)
        if len(remote_block_device) != len(local_block_device):
            self.log.warning("Remote node does not have the same number of drives [{0}], as existing cluster "
                             "node drives[{1}].".format(len(remote_block_device), len(local_block_device)))
            return False

        return True

    @staticmethod
    def _get_drive_sizes(block_device):
        """Return the 'devsize' value of each drive in block_device with the trailing 'G' unit removed."""
        return [str(drive_info['devsize']).strip().rstrip('G')
                for drive_info in block_device.values() if 'devsize' in drive_info]

    def validate_drive_size(self, remote_node_info):
        """
        Ensure the size of drives on the remote node, is not less than the size of the current node.  They can
        be larger, but we should not permit smaller drives.

        Therefore, all remote disks should be at least greater OR equal than the largest local drive, and all
        remote disks must have the same size.

        Node information is JSON and always structured as follows.
        {
            'sas#': {
                'cvdrive': 'data',
                'devname': '/dev/sde',
                'type': 'sas',
                'devsize': '9314G',
                'devalias': 'sas#'
            }
        }

        :param remote_node_info: json - Node information gathered from XML of remote node.
        :return: bool - True if the remote drive sizes are acceptable.
        """
        local_block_device = self.LOCAL_NODE_INFO['BLOCK_DEVICE'] or {}
        remote_block_device = remote_node_info.get('BLOCK_DEVICE') or {}

        local_disk_sizes = self._get_drive_sizes(local_block_device)
        remote_disk_sizes = self._get_drive_sizes(remote_block_device)
        if not local_disk_sizes or not remote_disk_sizes:
            self.log.error("Unable to compare drive sizes; found [{0}] local and [{1}] remote drive sizes.".format(
                len(local_disk_sizes), len(remote_disk_sizes)))
            return False

        try:
            # Compare numerically; comparing the raw strings would rank '9314' above '10000'.
            largest_local_disk = max(local_disk_sizes, key=float)
            smallest_remote_disk = min(remote_disk_sizes, key=float)
            expected_value = float(remote_disk_sizes[0])
            sizes_mismatched = any(float(size) != expected_value for size in remote_disk_sizes)
        except ValueError:
            self.log.error("Unable to parse drive sizes. Local: {0} Remote: {1}".format(
                local_disk_sizes, remote_disk_sizes))
            return False

        # Ensure ALL remote disk sizes are the same.
        if sizes_mismatched:
            self.log.error("Remote node has mismatched data drive sizes. Node information:\n{0}".format(
                pprint.pformat(remote_block_device)
            ))
            return False

        if float(smallest_remote_disk) < float(largest_local_disk):
            self.log.error("The smallest remote drive size [{0}] must be a minimum size of [{1}].".format(
                smallest_remote_disk, largest_local_disk
            ))
            return False

        return True

    def get_nodes_from_user(self, discovered_nodes, display_node_list=True):
        """
        Ask the user which nodes to configure.

        Input forms:
          - '*'               : every discovered node.
          - text containing * : discovered nodes whose key contains the text with leading/trailing '*' removed
                                (case-insensitive).
          - a,b,c             : the comma separated entries, each stripped of surrounding whitespace.

        :param discovered_nodes: dict - node key -> avahi JSON (its 'host' is shown next to the key).
        :param display_node_list: bool - print the numbered list of discovered nodes first.
        :return: list - the selected node keys / entries.
        """
        if display_node_list:
            for index, (node, avahi_info) in enumerate(discovered_nodes.items(), 1):
                print("{0}) {1}({2})".format(index, node, avahi_info.get('host', '')))

        print("\nEnter the <node serial number>(s) to configure; comma separated. * may be used as a "
              "wildcard character for serial number(s).\n"
              "The above format is:\n#) <node serial number>(<node hostname>)\n\n"
              "Example List)\n"
              "1) 123456(testnode1)\n"
              "2) 123457(testnode2)\n"
              "3) node 3 with a space(testnode3)\n"
              "4) 654321(testnode4)\n\n"
              "Example 1)\n\tSelected Node(s): 12345*\t\t(Selects node 1 & 2 using wildcard)\n"
              "Example 2)\n\tSelected Node(s): 123456,654321,node 3 with a space\t\t(Selects node 1, 4, 3)\n")
        input_nodes = raw_input("Selected Node(s): ").strip()
        if input_nodes == '*':
            # Select all nodes.
            return list(discovered_nodes.keys())

        if '*' in input_nodes:
            # Pattern matching; both sides lowercased so the match is case-insensitive.
            pattern = input_nodes.strip('*').lower()
            return [key for key in discovered_nodes.keys() if pattern in key.lower()]

        return [entry.strip() for entry in input_nodes.split(',')]

    def _run_in_parallel(self, work_items, handler):
        """
        Call handler(item) for every item using a bounded pool of daemon threads and block until all are done.

        The pool size is capped at cvmanager_defines.MAX_REMOTE_NODE_SSH_CONNECTIONS.  A stop marker is queued per
        worker after the real items so the threads exit once the queue drains instead of blocking forever.
        Exceptions raised by handler are logged at debug level and do not stop the other items.
        """
        work_items = list(work_items)
        if not work_items:
            return

        work_queue = Queue()
        stop_marker = object()
        worker_count = max(1, min(cvmanager_defines.MAX_REMOTE_NODE_SSH_CONNECTIONS, len(work_items)))

        for item in work_items:
            work_queue.put(item)
        for _ in range(worker_count):
            work_queue.put(stop_marker)

        def worker():
            while True:
                item = work_queue.get()
                try:
                    if item is stop_marker:
                        return
                    handler(item)
                except Exception as err:
                    self.log.debug("Parallel remote node operation failed for [{0}]: {1}".format(item, err))
                finally:
                    work_queue.task_done()

        threads = []
        for _ in range(worker_count):
            thread = Thread(target=worker)
            thread.daemon = True
            thread.start()
            threads.append(thread)

        # block until all tasks are done
        work_queue.join()
        for thread in threads:
            thread.join()

    def get_remote_node_hardware_information(self, discovery_info, nodes):
        """
        In parallel, go to all nodes over ssh using default credentials (not passwordless ssh) and collect the
        hardware device information for this node.

        :param discovery_info: dict - node key -> avahi JSON; must contain every key in nodes.
        :param nodes: iterable - the node keys to query.
        :return: dict - node key -> info from get_node_device_info(), only for nodes that returned information.
        """
        node_hardware = {}

        def collect(item):
            node_serial_number, node_info = item
            node_device_info = self.get_node_device_info(node_serial_number, node_info)
            if node_device_info:
                node_hardware[node_serial_number] = node_device_info

        self._run_in_parallel([(node, discovery_info[node]) for node in nodes], collect)
        return node_hardware

    def test_connection_to_avahi_nodes(self, nodes):
        """
        In parallel, go to all avahi nodes over ssh using default credentials (not passwordless ssh) and check if
        we can add this node.

        :param nodes: dict of node key -> avahi JSON, or an iterable of IP addresses / hostnames (each is then used
                      as both the key and the 'ip').
        :return: dict - node key -> JSON, only for nodes that passed test_ssh_conn().
        """
        if isinstance(nodes, dict):
            work_items = [{key: value} for key, value in nodes.items()]
        else:
            # If we get a list of ips or hostnames, try connecting to them.
            work_items = [{address: {'ip': address}} for address in nodes]

        reachable_nodes = {}

        def check(item):
            if self.test_ssh_conn(**item):
                reachable_nodes.update(item)

        self._run_in_parallel(work_items, check)
        return reachable_nodes

    def get_node_device_info(self, node, node_json):
        """
        Collect block-device and IP information from a remote node.

        :param node: node key, used for logging.
        :param node_json: dict - avahi JSON for the node; its 'ip' is used to reach the node.
        :return: dict with 'BLOCK_DEVICE', 'AVAHI' and 'ARCH' keys, or False if the information could not be read.
        """
        try:
            # This is an available node; get all its options and values.
            remote_node = hs_node.HyperScaleNode(node_json.get('ip', None))
            remote_node_xml = remote_node.get_node_block_device_xml(python_path=python_path)

            if not remote_node_xml:
                # Unable to get the remote node XML; drop this node from discovered list and consider it not valid node.
                return False

            # 9/16/2020: EF - Removed this; doesn't look like we use ARCH info anywhere and its causing issues.
            # remote_node_hw_info = remote_node.get_arch_hw_info(json_format=True)
            remote_node_hw_info = remote_node.get_node_ips(json_format=True, python_path=python_path)

            # Read the relevant information from the config file
            return {
                'BLOCK_DEVICE': hs_utils.get_block_device_info_from_xml(remote_node_xml),
                'AVAHI': node_json,
                'ARCH': remote_node_hw_info  # Is this used anywhere?  Remove if no issues found.
            }
        except Exception as err:
            # If any node throws an exception, just drop it off the list of nodes.
            self.log.error("Failed getting config for node: {0}".format(node))
            self.log.error("Exception: {0}".format(err))
            return False

    def test_ssh_conn(self, **kwargs):
        """
        Check that one node accepts SSH with the default image credentials and is a HyperScale 2.x node.

        Called as test_ssh_conn(**{node_key: {'ip': ...}}): a single keyword whose name is the node key and whose
        value holds the node's 'ip'.  Timeouts come from hs_defines.SSH_TIMEOUT / SSH_BANNER_TIMEOUT.

        :return: bool - True if the node connected and reported HyperScale 2.x, otherwise False.
        """
        try:
            node_key, node_json = kwargs.popitem()

            ssh_conn = hs_ssh.RemoteSSH(node_json['ip'], username=cvmanager_defines.DEFAULT_IMAGE_USER,
                                        password=cvmanager_defines.DEFAULT_IMAGE_PASSWORD,
                                        timeout=hs_defines.SSH_TIMEOUT, banner_timeout=hs_defines.SSH_BANNER_TIMEOUT)

            if not ssh_conn.hs_2dot0(python_path=python_path):
                self.log.debug("Node [{0}] is not HyperScale 2.x node.".format(node_key))
                return False

            return True

        except Exception as err:
            # Unreachable node / bad credentials: exclude it, but leave a trace for debugging.
            self.log.debug("SSH connection test failed: {0}".format(err))
            return False
