#!/usr/bin/python3

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

#
# RAID configurator for DGXStation2 software raid
#

import sys
import os
import md_raid_configurator
from nv_subps import SubPs
from md_raid_configurator import MDRaidConfigurator
from md_raid_configurator import MD_CONST
from md_raid_configurator import get_next_md_name

class DGXStation2MDRaidConfigurator(MDRaidConfigurator):
    def __init__(self, script_name):
        """Initialize the DGX Station2 RAID configurator.

        Establishes the platform's maximum supported RAID level before
        delegating to the base-class initializer, then applies this
        platform's defaults.
        """
        # Highest RAID level supported on this platform (raid1)
        self.RAIDLVLMAX = 1
        self.RAIDLVLSTR = "raid%d" % self.RAIDLVLMAX

        super().__init__(script_name)

        # Platform defaults; process_cmd() may override some of these
        self.data_array_name = "md0"
        self.manual_mount = False
        # getopt option string: -i, -h, -c, -r, -f, -<max level>, -m <arg>
        self.cmd_options = "ihcrf%dm:" % self.RAIDLVLMAX
        self.enable_fscache = True

    def print_usage(self):
        """Print command-line usage for this script to stdout."""
        print("\nUsage :\n")
        print("    <script_name> <options>\n")
        print("    <options>:\n")
        print("      <-h>    Displays help")
        print("      <-c>    If there are multiple eligible devices, create a RAID")
        print("              volume. Defaults to raid0.")
        print("              If there is only one eligible device, create a single")
        # Fixed duplicated word: this previously read "single single device volume"
        print("              device volume")
        print(f"              <-{self.RAIDLVLMAX}>  create {self.RAIDLVLSTR} instead of the default raid0. This option")
        print("                    is ignored in the case of a single device volume")
        print(f"                    <configure_raid_array.py -c -{self.RAIDLVLMAX}>")
        print("              <-f>  Use the -f option with the -c to force the creation of")
        print("                    the volume.")
        print("                    <configure_raid_array.py -c -f>")
        print(f"                    <configure_raid_array.py -c -{self.RAIDLVLMAX} -f>")
        print("      <-r>    Rebuild the RAID array by replacing a failed disk with a new one.")
        print("              No-op in the case of a single device volume.")
        print(f"      <-m raid0 || {self.RAIDLVLSTR}>  Convert RAID array to the given raid level")
        print(f"                           Supported raid levels are raid0 and {self.RAIDLVLSTR}")
        print("                           No-op in the case of a single device volume.")

    def process_cmd(self, opts, args, msg):
        """Dispatch the parsed command-line options.

        opts: list of (option, value) pairs as produced by getopt
        args: remaining positional arguments (currently unused)
        msg:  output list; failure descriptions are appended to it

        Returns an MD_CONST status code, or None when no command option
        was supplied.
        """
        # RAID-0 is the default raid level
        raid_level = "raid0"
        active_count = 0
        cmd = None
        init_array = False
        ary_create_prompt = False
        force_ary_removal = False
        # Figure out our MD name dynamically. What we set in init() may not
        # be accurate
        self.calc_md_name()

        # NOTE: the loop variable was renamed from "args" to "optarg" so it
        # no longer shadows the positional-arguments parameter above.
        for opt, optarg in opts:
            if opt == "-i":
                # Install-time initialization: unconditionally tear down any
                # existing array and recreate it without prompting.
                self.manual_mount = True
                self.umount_and_remove_array(self.data_array_name, self.mount_unit_name)
                init_array = True
                force_ary_removal = True
                cmd = "create"
            elif opt == "-h":
                self.print_usage()
                sys.exit()
            elif opt == "-c":
                init_array = True
                ary_create_prompt = True
                # Only one command option (-c/-r/-m) is allowed per invocation
                if cmd is not None:
                    self.print_usage()
                    sys.exit()
                cmd = "create"
            elif opt == "-r":
                if cmd is not None:
                    self.print_usage()
                    sys.exit()
                cmd = "rebuild"
            elif opt == "-" + str(self.RAIDLVLMAX):
                # Raid-level override is only valid together with a create command
                if cmd != "create":
                    self.print_usage()
                    sys.exit()
                self.enable_fscache = False
                raid_level = self.RAIDLVLSTR
            elif opt == "-m":
                if cmd is not None:
                    self.print_usage()
                    sys.exit()
                if optarg == self.RAIDLVLSTR:
                    self.enable_fscache = False
                    raid_level = self.RAIDLVLSTR
                elif optarg == "raid0":
                    raid_level = "raid0"
                else:
                    self.print_usage()
                    sys.exit()
                cmd = "convert"
            elif opt == "-f":
                # -f only modifies a create command
                if cmd != "create":
                    self.print_usage()
                    sys.exit()
                force_ary_removal = True
            else:
                self.print_usage()
                sys.exit()

        inactive_count = self.check_inactive_array(force_ary_removal)

        if inactive_count > 0:
            msg.append("Cannot create array while inactive array exists")
            return MD_CONST.OP_ABORTED

        if cmd == "create":
            # Refuse to create while an active array of any supported level
            # still exists (force removal above may already have cleared it).
            supported_raid_levels = ["raid0", self.RAIDLVLSTR]
            for rlevel in supported_raid_levels:
                active_count = self.check_active_array(rlevel, force_ary_removal)

                if active_count > 0:
                    msg.append("Cannot create array while active array exists")
                    return MD_CONST.OP_ABORTED

            return self.create_data_volume(raid_level, init_array, force_ary_removal, ary_create_prompt, msg)
        elif cmd == "rebuild":
            return self.md_rebuild_data_array(msg)
        elif cmd == "convert":
            return self.md_convert_data_array(raid_level, msg)

    def is_device_mounted(self, unit_name):
        """Return True when "/<unit_name>" is currently a mountpoint."""
        # mountpoint(1) exits 0 when the path is an active mountpoint
        subproc = SubPs()
        return subproc.run_cmd(["mountpoint", "/" + unit_name]) == 0

    def umount_device(self, unit_name):
        """Stop the systemd mount unit for unit_name.

        Returns (exit_code, error_string) from the systemctl invocation.
        """
        subproc = SubPs()
        status = subproc.run_cmd(["systemctl", "stop", unit_name + ".mount"])
        return status, subproc.os_err_str

    def mount_device(self, dev_name, unit_name):
        """Mount dev_name on "/<unit_name>", creating the directory if needed.

        Returns (exit_code, error_string) from the mount invocation.
        """
        sps = SubPs()
        raid_dir = "/" + unit_name

        # Make sure raid directory exists. os.makedirs(exist_ok=True) replaces
        # the previous "mkdir -p" subprocess call with the stdlib equivalent.
        os.makedirs(raid_dir, exist_ok=True)

        # mount unit for /raid may not exist in systemd. So use
        # old-fashion mount command
        rc = sps.run_cmd(["mount", dev_name, raid_dir])
        return rc, sps.os_err_str

    def create_single_drive_data_volume(self, block_dev, force_option, ary_create_prompt, msg):
        """Format a single drive as the data volume and mount it on /raid.

        block_dev:         eligible block device object to format
        force_option:      when True, skip the existing-partition check and
                           the interactive confirmation prompt
        ary_create_prompt: when True (manual -c invocation), ask the user
                           to confirm before destroying data
        msg:               output list; failure descriptions are appended

        Returns MD_CONST.SUCCESS, OP_ABORTED, or FS_CREATE_FAIL.
        """
        # If force option is not given, don't format if the drive already has partition on it
        if not force_option:
            if block_dev.has_partition():
                msg.append("Partition exists on device " + block_dev.get_dev_node_name())
                return MD_CONST.OP_ABORTED

        drive_name = block_dev.get_dev_node_name()

        # User manually creates the array (not at installation time)
        if ary_create_prompt:
            answer = None
            if not force_option:
                print("Data on drive " + drive_name + " will be formatted. Are you sure you want to continue? <y/n>")
                while True:
                    answer = input().upper()
                    if answer == "Y" or answer == "N":
                        break
                    else:
                        print("Please answer (y)es or (n)o.")

            # answer stays None when force_option is set, so we fall through
            if answer == "N":
                msg.append("Operation aborted")
                return MD_CONST.OP_ABORTED

        # NOTE(review): "p1" is appended unconditionally here, but Step 5
        # below only uses the "p1" suffix for nvme devices — confirm this is
        # correct for non-nvme drives.
        self.disable_raid_automount(block_dev.node_name + "p1")

        # Steps to setup data volume:
        #
        # 1. Stop NVSM if it's running
        # 2. Umount /raid
        # 3. Wipe clean the drive
        # 4. Create partition, align the partition to offset 2048
        # 5. mkfs new partition
        # 6. Mount new partition on /raid
        # 7. Update fstab with new partition entry
        # 8. Start cachefilesd on /raid
        # 9. Restart NVSM if it was previously running

        # Step 1
        self.safe_stop_nvsm()

        # Step 2
        if self.is_device_mounted(self.mount_unit_name):
            print("Umounting " + drive_name)
            rc, err_msg = self.umount_device(self.mount_unit_name)

            if rc != 0:
                msg.append("Cannot umount /raid: " + err_msg)
                return MD_CONST.OP_ABORTED

        # Step 3: sgdisk -Z destroys the GPT/MBR structures on the drive
        sps = SubPs()
        sps.run_cmd([ "sgdisk", "-Z", drive_name ])

        # Step 4: leave 2048 sectors at the end; partition starts at sector 2048
        disk_size = block_dev.get_disk_size() - 2048

        rc = sps.run_cmd(["sgdisk", "--clear", "--recompute-chs", "--new", "1:2048:" + str(disk_size),
                           "--change-name", "1:\"Linux filesystem\"", drive_name])

        if rc != 0:
            msg.append("Error partitioning device " + drive_name + ": " + sps.os_err_str)
            return MD_CONST.OP_ABORTED

        # Step 5: nvme partition nodes use a "p" separator (e.g. nvme0n1p1),
        # other devices append the number directly (e.g. sda1)
        if "nvme" in drive_name:
            part_device = drive_name + "p1"
        else:
            part_device = drive_name + "1"

        print("Formatting device " + part_device)
        rc = self.stordev_module.mkfs(part_device, "ext4")

        if not rc:
            msg.append("Cannot create filesystem type ext4")
            return MD_CONST.FS_CREATE_FAIL

        # Step 6
        print("Mounting device " + part_device + " on /" + self.mount_unit_name)
        rc, err_msg = self.mount_device(part_device, self.mount_unit_name)

        if rc != 0:
            msg.append("Cannot mount raid: " + err_msg)
            return MD_CONST.FS_CREATE_FAIL

        # Step 7
        # Get PARTUUID for the partition and use it in fstab for more reliable mounting
        partuuid = self.stordev_module.get_partuuid_from_device(part_device)
        if partuuid:
            # Use /dev/disk/by-partuuid/ format for fstab entry
            fstab_device = "disk/by-partuuid/" + partuuid
            print("Adding /dev/disk/by-partuuid/" + partuuid + " entry to fstab")
        else:
            # Fallback to device name if PARTUUID is not available
            fstab_device = part_device
        self.update_fstab(fstab_device, "/" + self.mount_unit_name, "ext4")

        # Step 8
        if self.fscache_installed:
            if self.enable_fscache:
                print("Starting cachefilesd...")
                self.set_fscache("enable")

                if not self.restart_fscache():
                    msg.append("Cannot restart cachefilesd service")
            else:
                self.set_fscache("stop")
                self.set_fscache("disable")

        # Step 9
        self.safe_start_nvsm()

        return MD_CONST.SUCCESS

    def get_elig_devs(self):
        """Rescan storage and return the block devices eligible for the data array."""
        self.stordev_module.reinit_stor()
        eligible = []
        self.get_eligible_bd_for_data_array(self.stordev_module.get_block_devs(), eligible)
        return eligible

    def create_data_volume(self, raid_level, init_array, force_option, ary_create_prompt, msg):
        """Create the data volume from the eligible U.2 drives.

        A single eligible drive is formatted directly; two or more drives
        are assembled into an MD RAID array at the requested level.
        """
        elig_devs = self.get_elig_devs()
        print("Available data drives: %d" % (len(elig_devs)))

        # Guard: nothing to build a volume from
        if not elig_devs:
            msg.append("Not enough devices to create data volume")
            return MD_CONST.NOT_ENOUGH_DRIVES

        # Multiple drives: build a RAID array
        if len(elig_devs) > 1:
            return self.md_create_data_array(raid_level, init_array, ary_create_prompt, msg)

        # Exactly one drive: cachefilesd defaults to off on a single-disk volume
        self.enable_fscache = False
        return self.create_single_drive_data_volume(elig_devs[0], force_option, ary_create_prompt, msg)
