#!/usr/bin/python3

# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# Module implements raid configuration for Linux software raid

import getopt
import sys
import time
import subprocess
import os
import re
from stordev import BlockDev
from stordev import init_stor
from stordev import reinit_stor
from stordev import get_block_devs
from stordev import mkfs
from stordev import issue_cmd
from stordev import print_block_devs
from raid_configurator import RaidConfigurator
from nv_subps import SubPs

""" List of MD raid present on the system """
raid_list = []

class MD_CONST:
    """Return / status codes shared by the MD raid configuration methods."""
    SUCCESS = 0                   # operation completed
    RAID_EXIST = 1                # target array already present
    CREATE_FAIL = 2               # mdadm create failed
    INVALID_RAID_LEVEL = 3        # unsupported raid level requested
    BOOT_DEVICE = 4               # drive is the boot/OS device
    RAID_NOT_FOUND = 5            # named array not discovered
    NOT_ENOUGH_DRIVES = 6         # too few eligible drives for the level
    OP_ABORTED = 7                # user declined or precondition failed
    FS_CREATE_FAIL = 8            # mkfs on the new array failed
    CACHEFILE_START_FAIL = 9      # cachefilesd could not be (re)started

class MDRaidConfigurator(RaidConfigurator):
    """
    Class implements software raid configuration functionality, create, rebuild,
    and convert raid level.
    """
    def __init__(self, script_name):
        RaidConfigurator.__init__(self, "MD Raid Configurator", script_name)
        # Placeholder device name; calc_md_name() later derives the actual
        # md name from the arrays present on the data drives.
        self.data_array_name = "md0"
        self.manual_mount = False
        # getopt short-option string consumed by run_configurator()
        self.cmd_options = "ihcrf5m:"
        # New default behavior for all DGX platforms: enable fscache if
        # RAID0, otherwise disable it.
        self.enable_fscache = True

    def print_usage(self):
        ''' How to use this script'''
        print("\nUsage :\n")
        print("    <script_name> <options>\n")
        print("    <options>:\n")
        print("      <-h>       Displays help")
        print("      <-c>  Create RAID array with default raid level 0")
        print("              <-5>  create raid5 instead of default raid0")
        print("                    <configure_raid_array.py -c -5>")
        print("              <-f>  Use -f option with -c to force removal of inactive arrays")
        print("                    <configure_raid_array.py -c -f>")
        print("                    <configure_raid_array.py -c -5 -f>")
        print("      <-r>       Rebuild the RAID array by replacing a failed disk with a new one")
        print("      <-m raid0 || raid5>  Convert RAID array to the given raid level")
        print("                           Support raid levels are raid0 and raid5")

    def validate_args(self, args, opts):
        """
        Sanity-check the raw argv list against the getopt-parsed options;
        print usage and exit on any mismatch.
        """
        if len(args) == 0:
            self.print_usage()
            sys.exit()

        if len(opts) == 0:
            self.print_usage()
            sys.exit()

        # Drop the empty value slot from the first (opt, value) tuple so a
        # single argument-less option still matches argv in length.
        # NOTE(review): only opts[0] is examined - presumably sufficient for
        # the single-command invocations this script supports; confirm if
        # several value-taking options are ever combined.
        tmp = [ x for x in opts[0] if x != "" ]

        if len(args) != len(opts) and len(args) != len(tmp):
            self.print_usage()
            sys.exit()

    def run_configurator(self, argv, msg):
        """
        Method parses input arguments used for RAID array configuration.

        argv -- raw command-line arguments (without the program name)
        msg  -- list the callee appends human-readable result strings to
        Returns one of the MD_CONST status codes (or exits on bad usage).
        """
        try:
            opts, args = getopt.getopt(argv, self.cmd_options)
        except getopt.GetoptError:
            self.print_usage()
            sys.exit(2)

        self.validate_args(argv, opts)

        if opts == []:
            self.print_usage()
            sys.exit()

        # Discover block storages and MD raid
        #
        init_stor()
        discover_array()
        return self.process_cmd(opts, args, msg)

    def process_cmd(self, opts, args, msg):
        """
        Decode the parsed options into a command (create / rebuild / convert)
        and dispatch to the matching md_* method. Returns an MD_CONST code.
        """
        # RAID-0 is the default raid level
        raid_level = "raid0"
        active_count = 0
        cmd = None
        init_array = False
        ary_create_prompt = False
        force_ary_removal = False
        # Figure out our MD name dynamically. What we set in init() may not
        # be accurate
        self.calc_md_name()

        # Loop variable renamed from "args" so the option's value no longer
        # shadows the args parameter.
        for opt, optarg in opts:
            if opt == "-i":
                # Non-interactive (re)initialization: wipe and recreate.
                self.manual_mount = True
                self.umount_and_remove_array(self.data_array_name, self.mount_unit_name)
                init_array = True
                force_ary_removal = True
                cmd = "create"
            elif opt == "-h":
                self.print_usage()
                sys.exit()
            elif opt == "-c":
                init_array = True
                ary_create_prompt = True
                if cmd != None:
                    self.print_usage()
                    sys.exit()
                cmd = "create"
            elif opt == "-r":
                if cmd != None:
                    self.print_usage()
                    sys.exit()
                cmd = "rebuild"
            elif opt == "-5":
                # Only valid as a modifier of -c/-i.
                if cmd != "create":
                    self.print_usage()
                    sys.exit()
                self.enable_fscache = False
                raid_level = "raid5"
            elif opt == "-m":
                if cmd != None:
                    self.print_usage()
                    sys.exit()
                if optarg == "raid5":
                    self.enable_fscache = False
                    raid_level = "raid5"
                elif optarg == "raid0":
                    raid_level = "raid0"
                else:
                    self.print_usage()
                    sys.exit()
                cmd = "convert"
            elif opt == "-f":
                if cmd != "create":
                    self.print_usage()
                    sys.exit()
                force_ary_removal = True
            else:
                self.print_usage()
                sys.exit()

        inactive_count = self.check_inactive_array(force_ary_removal)

        if inactive_count > 0:
            msg.append ("Cannot create array while inactive array exists")
            return MD_CONST.OP_ABORTED

        if cmd == "create":
            supported_raid_levels = ["raid0", "raid5"]
            for rlevel in supported_raid_levels:
                active_count = self.check_active_array(rlevel, force_ary_removal)

                # Can't create a new RAID-0 if an active one already exists
                if active_count > 0:
                    msg.append ("Cannot create array while active array exists")
                    return MD_CONST.OP_ABORTED

            return self.md_create_data_array(raid_level, init_array, ary_create_prompt, msg)
        elif cmd == "rebuild":
            return self.md_rebuild_data_array(msg)
        elif cmd == "convert":
            return self.md_convert_data_array(raid_level, msg)

    def md_zero_superblock(self, dev_name):
        """
        Zeroing the drive superblock and erase the partition on the drive
        """

        sps = SubPs()
        # Probe for an existing MD superblock. (Fixed typo: the command was
        # "madm", which does not exist, so rc reflected a spawn failure and
        # the superblock check never actually ran.)
        rc = sps.run_cmd([ "mdadm", "-E", dev_name ])

        # No superblock on drive
        if rc == 1:
            return

        print("zero-superblock " + dev_name)
        sps.run_cmd([ "sgdisk", "-Z", dev_name ])
        sps.run_cmd([ "mdadm",  "--zero-superblock", dev_name ])

    def check_active_array(self, raid_level, force_ary_removal):
        """
        Count active arrays of the given level; when force_ary_removal is
        set, stop them first. Returns the number still active after a
        re-scan of the storage state.
        """
        active_list = []
        active_name = ""
        for raid in raid_list:
            if raid.raid_state == "active" and raid.raid_level == raid_level:
                active_list.append(raid)
                active_name += raid.raid_name + " "

        if len(active_list) > 0:
            print("Found active array(s) " + active_name)

            if force_ary_removal:
                # Try to stop NVSM if it is running
                self.safe_stop_nvsm()

                for raid in active_list:
                    sps = SubPs()
                    print("Removing active array " + raid.raid_name)
                    sps.run_cmd([ "mdadm", "--stop", "/dev/" + raid.raid_name ])

        # Re-scan so the count below reflects any arrays just stopped.
        reinit_stor()
        rediscover_array()

        active_list = []
        for raid in raid_list:
            if raid.raid_state == "active" and raid.raid_level == raid_level:
                active_list.append(raid)

        return len(active_list)

    def check_inactive_array(self, force_ary_removal):
        """
        Count inactive (offline) arrays; when force_ary_removal is set,
        stop them first. Returns the number still inactive after a re-scan.
        """
        inactive_list = []
        inactive_name = ""
        for raid in raid_list:
            if raid.raid_state == "inactive":
                inactive_list.append(raid)
                inactive_name += raid.raid_name + " "

        if len(inactive_list) > 0:
            print("Found inactive array(s) " + inactive_name)

            if force_ary_removal:
                for raid in inactive_list:
                    sps = SubPs()
                    print("Removing inactive array " + raid.raid_name)
                    sps.run_cmd([ "mdadm", "--stop", "/dev/" + raid.raid_name ])

        # Re-scan so the count below reflects any arrays just stopped.
        reinit_stor()
        rediscover_array()

        inactive_list = []
        for raid in raid_list:
            if raid.raid_state == "inactive":
                inactive_list.append(raid)

        return len(inactive_list)

    def get_locked_drives(self, block_devs):
        """
        Given a list of block devs, return the device node names of those
        that are locked.
        """
        return [bd.get_dev_node_name() for bd in block_devs if bd.drive_locked]

    def md_create_data_array(self, raid_level, init_array, ary_create_prompt, msg, elig_devs=None):
        """
        Create data array with given raid_name and raid_level.

        raid_level        -- "raid0" or "raid5"
        init_array        -- when True, mkfs ext4 on the new array
        ary_create_prompt -- when True, ask the user for confirmation
        msg               -- list receiving human-readable result strings
        elig_devs         -- optional pre-selected component devices;
                             discovered automatically when None
        Returns an MD_CONST status code.
        """
        block_devs = get_block_devs()

        locked_drives = self.get_locked_drives(block_devs)

        if len(locked_drives) > 0:
            msg.append("Operation aborted: one or more devices are locked.\n%s" %(" ".join(locked_drives)))

            return MD_CONST.OP_ABORTED

        print("Create array level " + raid_level)
        if is_supported_raid_level(raid_level) == False:
            msg.append("Invalid raid level " + raid_level)
            return MD_CONST.INVALID_RAID_LEVEL

        for raid in raid_list:
            if self.data_array_name == raid.get_raid_name():
                msg.append("RAID array " + self.data_array_name + " already exists")
                return MD_CONST.RAID_EXIST

        # Find devices that are eligible for data array creation.
        # Eligible devices are:
        # 1. Non-OS drive
        # 2. Drive not a member of another array group.
        raid_components = []
        if elig_devs is None:
            self.get_eligible_bd_for_data_array(block_devs, raid_components)
        else:
            raid_components = elig_devs

        if is_enough_drives_for_raid(raid_level, raid_components) == False:
            msg.append("Not enough devices to create RAID array " + self.data_array_name + " to " + raid_level)
            return MD_CONST.NOT_ENOUGH_DRIVES

        drive_name_list = ""
        for bd in raid_components:
            drive_name_list += bd.get_dev_node_name() + " "

        if ary_create_prompt:
            print("Data on drives " + drive_name_list + " will be erased. Are you sure you want to continue? <y/n>")
            answer = None
            while True:
                answer = input().upper()
                if answer == "Y" or answer == "N":
                    break
                else:
                    print("Please answer (y)es or (n)o.")

            if answer == "N":
                msg.append("Operation aborted")
                self.safe_start_nvsm()
                return MD_CONST.OP_ABORTED

        # Try to stop NVSM if it is running
        self.safe_stop_nvsm()

        self.disable_raid_automount(self.data_array_name)

        cmd = "echo 'yes' | mdadm -C /dev/%s --name=nv-data-array -n %d --level=%s" %(self.data_array_name, len(raid_components), raid_level)

        # Zero the superblock and erase the partition on the drives
        # before creating a new array.
        for bd in raid_components:
            self.md_zero_superblock(bd.get_dev_node_name())
            cmd += " " + bd.get_dev_node_name()

        try:
            subprocess.check_output(cmd, shell=True, encoding='UTF-8')
        except subprocess.CalledProcessError as e:
            print("Unable to create %s array\n" %(self.data_array_name))
            print(e.output)
            return MD_CONST.CREATE_FAIL

        if init_array == True:
            rc = mkfs("/dev/" + self.data_array_name, "ext4")

            if rc == False:
                msg.append("Cannot create filesystem type ext4")
                return MD_CONST.FS_CREATE_FAIL

        if not os.path.exists(self.mount_point):
            os.makedirs(self.mount_point)

        self.update_md_conf(self.data_array_name)
        self.update_fstab(self.data_array_name, self.mount_point, "ext4")

        # Mount the new array (directly or via the systemd mount unit) and
        # apply the fscache policy. Failures here are reported but do not
        # change the return code (best-effort, as in the original design).
        try:
            if self.manual_mount == True:
                subprocess.check_output("mount /dev/" + self.data_array_name + " " + self.mount_point, shell = True, encoding='UTF-8')
                subprocess.check_output("systemctl daemon-reload", shell = True, encoding='UTF-8')
            else:
                subprocess.check_output("systemctl daemon-reload", shell = True, encoding='UTF-8')
                subprocess.check_output("systemctl start " + self.mount_unit_name +".mount" , shell = True, encoding='UTF-8')

            if self.fscache_installed:
                if self.enable_fscache:
                    self.set_fscache("enable")
                    self.restart_fscache()
                else:
                    self.set_fscache("stop")
                    self.set_fscache("disable")

            self.safe_start_nvsm()
        except Exception:
            # Narrowed from a bare "except:" so Ctrl-C still propagates.
            print("Unable to mount /raid. cachefilesd not started")

        msg.append("RAID array " + self.data_array_name + " successfully created")
        return MD_CONST.SUCCESS

    def md_rebuild_data_array(self, msg):
        """
        Rebuild the failed array with given raid_name. Rebuilding the array will
        destroy the existing raid group and recreate a new array.
        """
        index = 0
        target_level = "raid0"

        # Umount the given array
        for raid in raid_list:
            if raid.raid_name == self.data_array_name:
                if raid.is_array_mounted():
                    raid.umount_array(self.mount_unit_name)
                    del raid_list[index]
                raid.stop_md()
                if raid.raid_level:
                    target_level = raid.raid_level
                break
            index += 1

        reinit_stor()

        # Try to assemble the array first. If assemble is successfull, the array has recovered with
        # its data intact. There is no need to recreate new array.
        assembled_success = True
        try:
            subprocess.check_output("mdadm /dev/" + self.data_array_name + " --assemble --scan --force", shell = True, encoding='UTF-8')
            reinit_stor()
            if target_level == "raid5" or target_level == "raid1":
                # Redundant levels can re-absorb replacement drives.
                raid_comps = []
                block_devs = get_block_devs()
                self.get_eligible_bd_for_data_array(block_devs, raid_comps)

                raid_devs = self.get_raid_devices()
                if len(raid_comps) + len(raid_devs) < 3:
                    msg.append("Not enough devices to rebuild RAID-5")
                    return MD_CONST.NOT_ENOUGH_DRIVES

                for bd in raid_comps:
                    subprocess.check_output("mdadm /dev/" + self.data_array_name + " --add " + bd.get_dev_node_name(), shell = True, encoding='UTF-8')

        except subprocess.CalledProcessError as e:
            assembled_success = False
            print(e.output)

        rediscover_array()

        if assembled_success == True:
            for raid in raid_list:
                if raid.raid_name == self.data_array_name:
                    if raid.is_array_mounted() == False:
                        raid.mount_array(self.mount_unit_name)
                    break

            msg.append("RAID array " + self.data_array_name + " successfully rebuilt")
            return MD_CONST.SUCCESS
        else:
            # Assembly failed: fall back to recreating the array from scratch.
            return self.md_create_data_array(target_level, True, False, msg)

    def md_convert_data_array(self, raid_level, msg):
        """
        Convert existing data array to the given raid level.
        Destroys the current array after user confirmation and recreates it.
        Returns an MD_CONST status code.
        """
        #
        # Support raid levels are raid0 and raid5
        #
        if is_supported_raid_level(raid_level) == False:
            msg.append("Unsupported raid level " + raid_level)
            return MD_CONST.INVALID_RAID_LEVEL

        target_raid = None
        index = 0

        # Find the raid object represents the given array
        for raid in raid_list:
            if self.data_array_name == raid.get_raid_name():
                target_raid = raid
                break
            index += 1

        if target_raid == None:
            msg.append("RAID " + self.data_array_name + " not found")
            return MD_CONST.RAID_NOT_FOUND

        # Make sure the new raid level is not the same as
        # exsiting array
        if (target_raid.get_raid_level() == raid_level):
            msg.append("RAID array " + self.data_array_name + " already at raid level " + raid_level)
            return MD_CONST.SUCCESS

        raid_components = [] # raid_components contains the block devices belong to the given array

        block_devs = get_block_devs()

        # Find block devices belong to the given array
        for bd in block_devs:
            if bd.is_raid_member():
                # Strip the leading "/dev/" before the membership lookup.
                dev = bd.get_dev_node_name()[5:]
                if target_raid.is_dev_raid_member(dev):
                    bd.remove_raid_member()
                    raid_components.append(bd)

                continue

            if bd.is_boot_device():
                continue

            raid_components.append(bd)

        if is_enough_drives_for_raid(raid_level, raid_components) == False:
            msg.append("Not enough devices to convert RAID array " + self.data_array_name + " to " + raid_level)
            return MD_CONST.NOT_ENOUGH_DRIVES

        raid_level_name = "RAID-0"
        if raid_level == "raid5":
            raid_level_name = "RAID-5"
        elif raid_level == "raid1":
            raid_level_name = "RAID-1"

        print("Data on existing array will be lost after converting array to a new RAID level.")
        print("Are you sure you want to continue? <y/n>")

        answer = None
        while True:
            answer = input().upper()
            if answer == "Y" or answer == "N":
                break
            else:
                print("Please answer (y)es or (n)o.")

        if answer == "N":
            msg.append("Operation aborted")
            return MD_CONST.OP_ABORTED

        # These services to be stopped before umounting the array or else
        # the underlying devices might return busy status
        self.safe_stop_nvsm()

        if self.fscache_installed:
            self.set_fscache("stop")

        if target_raid.is_array_mounted():
            target_raid.umount_array(self.mount_unit_name)

        remove_array(self.data_array_name)

        del raid_list[index]

        return self.md_create_data_array(raid_level, True, False, msg)

    def get_raid_devices(self):
        """Return the component device list of the data array (re-scanned)."""
        rediscover_array()
        raid_devs = []
        for raid in raid_list:
            if self.data_array_name == raid.get_raid_name():
                raid_devs = raid.get_raid_devices()
                break

        return raid_devs

    def umount_and_remove_array(self, raid_name, mount_unit_name):
        """Unmount (if mounted) and stop the named array, then re-scan."""
        index = 0
        for raid in raid_list:
            if raid.raid_name == raid_name:
                if raid.is_array_mounted():
                    raid.umount_array(mount_unit_name)
                    del raid_list[index]
                remove_array(raid_name)
                break
            index += 1

        reinit_stor()
        rediscover_array()

    def get_eligible_bd_for_data_array(self, block_devs, elig_dev):
        """
        Append to elig_dev the block devices eligible for data-array
        creation: not already a raid member and not the boot device.
        """
        for bd in block_devs:
            if bd.is_raid_member():
                continue
            if bd.is_boot_device():
                continue

            elig_dev.append(bd)

    def calc_md_name(self):
        """
        Determine the md device name for the data array by matching existing
        arrays against the platform's data drives; fall back to the next
        free mdN name.
        """
        if not raid_list:
            return

        data_drives = get_data_drives()
        # get_data_drives() may return a non-list sentinel when the platform
        # helper script is unavailable; treat anything but a non-empty list
        # as "no data drives known" instead of crashing on set() below.
        if not isinstance(data_drives, list) or not data_drives:
            return

        #
        # iterate through mds, and find the one occupying the data drives
        #
        for raid in raid_list:
            raid_devs = sorted(raid.get_raid_devices())

            # If raid_devs is a subset of all the data_drives
            if set(raid_devs) <= set(data_drives):
                self.data_array_name = raid.raid_name
                return

        self.data_array_name = get_next_md_name()

class MDRaidArray(BlockDev):
    """
    Class implements MD array representation: state, level and component
    devices of one /proc/mdstat entry.
    """
    def __init__(self, md_name, mdstat):
        """
        md_name -- md device name (e.g. "md0")
        mdstat  -- the full /proc/mdstat output as a list of lines
        """
        BlockDev.__init__(self)
        self.raid_name = md_name
        self.raid_state = None   # "active" / "inactive" (from mdstat)
        self.raid_level = None   # e.g. "raid0", read from sysfs
        self.raid_devices = []   # component device names (no "/dev/" prefix)
        for line in mdstat:
            fields = line.split(' ')
            # Exact first-token match: a plain substring test would let
            # "md1" match the "md10" line.
            if fields and fields[0] == self.raid_name:
                # Drop "<name>" and ":" - the rest starts with the state.
                self.init_raid(fields[2:])

    def init_block_dev(self, dev_name, df, vi):
        BlockDev.init_block_dev(self, dev_name, df, vi)

    def init_raid(self, mdstat):
        """Fill in state, level and component devices for this array."""
        self.raid_state = mdstat[0]

        # RAID level lives in sysfs; the file is absent for some arrays,
        # in which case the level stays None.
        try:
            with open("/sys/class/block/" + self.raid_name + "/md/level", "r") as level_file:
                self.raid_level = str(level_file.read()).rstrip()
        except IOError:
            self.raid_level = None

        issue_cmd("ls /sys/class/block/" + self.raid_name + "/slaves", self.raid_devices)

        if len(self.raid_devices) == 0:
            return

        # Prune empty entries and devices that no longer exist. Iterate
        # from the end so deletions do not shift unvisited indices.
        # (Fixed off-by-one: the loop previously stopped at index 1 and
        # never examined element 0.)
        idx = len(self.raid_devices) - 1
        while idx >= 0:
            if len(self.raid_devices[idx]) == 0:
                del self.raid_devices[idx]
            elif not os.path.exists("/dev/" + self.raid_devices[idx]):
                del self.raid_devices[idx]
            idx -= 1

    def is_array_mounted(self):
        """Return True when /dev/<raid_name> appears in `mount` output."""
        tmp = []
        rc = issue_cmd("mount", tmp)

        if rc != 0:
            return False

        for line in tmp:
            # Prefix match up to the following space so "/dev/md1" does not
            # falsely match a "/dev/md10 on ..." line.
            if line.startswith("/dev/" + self.raid_name + " "):
                return True

        return False

    def mount_array(self, mount_unit_name):
        """Start the systemd mount unit for this array (runtime only)."""
        tmp = []
        try:
            cmd = "systemctl start " + mount_unit_name + ".mount --runtime"
            print(cmd)
            issue_cmd(cmd , tmp)
            issue_cmd("systemctl daemon-reload", tmp)
            return True
        except subprocess.CalledProcessError as e:
            print("Failed to mount filesystem at mount-point")
            return False

    def umount_array(self, mount_unit_name):
        """Stop the systemd mount unit for this array (runtime only)."""
        tmp = []
        try:
            cmd = "systemctl stop " + mount_unit_name + ".mount --runtime"
            print(cmd)
            issue_cmd(cmd, tmp)
            return True
        except subprocess.CalledProcessError as e:
            print("Failed to umount RAID")
            return False

    def stop_md(self):
        """Stop the md device via mdadm; return False on failure."""
        tmp = []
        try:
            issue_cmd("mdadm --stop /dev/" + self.raid_name, tmp)
        except subprocess.CalledProcessError as e:
            print("Failed to stop md")
            return False

        return True

    def print_array(self):
        """Debug dump of name, state, level and component devices."""
        print((self.raid_name, self.raid_state, self.raid_level, self.raid_devices))

    def is_dev_raid_member(self, dev_name):
        """
        Return True when dev_name matches one of this array's components.
        NOTE(review): substring match by design, so a whole-disk name also
        matches a partition entry (e.g. "sda" vs "sda1") - confirm callers
        rely on this.
        """
        for member in self.raid_devices:
            if dev_name in member:
                return True

        return False

    def get_raid_name(self):
        return self.raid_name

    def get_raid_state(self):
        return self.raid_state

    def get_raid_level(self):
        return self.raid_level

    def get_raid_devices(self):
        return self.raid_devices


def discover_array():
    """
    Parse /proc/mdstat and append an MDRaidArray to the module-global
    raid_list for every array line found (lines containing "active";
    this also covers the "inactive" state).
    """
    mdstat = []
    rc = issue_cmd("cat /proc/mdstat", mdstat)

    if rc != 0:
        return

    for line in mdstat:
        if line == '':
            continue
        if "active" in line:
            # First token of an array line is the md device name.
            # (Local renamed: the original shadowed the builtin "str".)
            fields = line.split(' ')
            md = MDRaidArray(fields[0], mdstat)
            raid_list.append(md)

def rediscover_array():
    """
    Flush the cached array objects and re-scan /proc/mdstat.
    """
    # Empty the module-level list in place so every existing reference to
    # raid_list keeps seeing the refreshed contents.
    raid_list.clear()
    discover_array()

def remove_array(raid_name):
    """
    Stop the named MD array via mdadm and give it time to settle.

    raid_name -- md device name without the "/dev/" prefix
    """
    dev_node = "/dev/" + raid_name
    runner = SubPs()
    runner.run_cmd([ "mdadm",  "--stop", dev_node ])

    # [bug 200459889]: Leave a little delay here for the array to full stop
    time.sleep(15)

def is_supported_raid_level(raid_level):
    """Return True when raid_level is one this module can configure."""
    return raid_level in ("raid0", "raid1", "raid5")

def is_enough_drives_for_raid(raid_level, dev_list):
    """
    Return True when dev_list contains enough drives for raid_level:
    two for raid0/raid1, three for anything else (raid5).
    """
    # Renamed locally: the original called this "max_drives" although it
    # is the minimum required count.
    min_required = 2 if raid_level in ("raid0", "raid1") else 3
    return len(dev_list) >= min_required

def get_next_md_name():
    pattern = re.compile("md[0-9]*[0-9]$")
    # grab a list of current md devices
    mdcurrent = [x for x in os.listdir("/dev/") if pattern.match(x)]
    if mdcurrent:
        for i in range(0, 127):
            mdn = "md" + str(i)
            # just pick the first available device number
            if mdn not in mdcurrent:
                return mdn
    # Return default of md0
    return "md0"

def get_data_drives():
    """
    Return the sorted list of data-drive names (without the "/dev/" prefix)
    reported by the platform helper script.

    Returns an empty list when the script is missing, fails, or its output
    cannot be parsed. (Fixed: this used to return True, which the caller
    calc_md_name() then passed to set(), raising TypeError; an empty list
    makes the caller skip, as the original comment intended for M2-drive
    systems.)
    """
    cmd = "/usr/local/sbin/nv_scripts/get_data_drives.bash"

    try:
        with open(os.devnull, 'w') as devnull:
            out = subprocess.check_output(cmd, shell=True, stderr=devnull, encoding='UTF-8')

        # Expected output shape: "<label>:/dev/sdX,/dev/sdY,..."
        drives = out.split(":")[1].split(',')

        drive_list = []
        for line in drives:
            drive_list.append(line.strip().split("/dev/")[1])

        return sorted(drive_list)
    except Exception:
        # Helper script absent/failed or output malformed - no data drives.
        return []

def test():
    """
    Ad-hoc smoke test: discover the arrays, print each one, then collect
    the block devices that carry no partitions.
    """
    discover_array()

    for md in raid_list:
        md.print_array()

    all_devs = get_block_devs()
    unpartitioned = []
    for dev in all_devs:
        if dev.has_partition() == False:
            unpartitioned.append(dev)

