#!/usr/bin/python3

# Copyright (c) 2017, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import psutil
import subprocess
import fileinput
import importlib
import inspect
import re
from nv_subps import SubPs

class RaidConfigurator:
    """Shared state and helpers for the RAID configurator scripts.

    Subclasses must implement run_configurator() and print_usage().  The
    helpers here edit mdadm.conf, fstab and the cachefilesd configuration,
    and start/stop/restart systemd services (cachefilesd, nvsm).
    """

    def __init__(self, configurator_name, script_name):
        """Record names and defaults, then probe systemd for related services.

        Args:
            configurator_name: name of this configurator, stored as-is.
            script_name: name of the invoking script, stored as-is.
        """
        self.configurator_name = configurator_name
        # Not assigned anywhere in this class; presumably set by subclasses.
        self.data_array_name = None
        self.script_name = script_name
        self.mount_point = "/raid"
        self.mount_unit_name = "raid"
        self.enable_fscache = False
        self.stordev_module = importlib.import_module("stordev")
        self.fscache_installed = self.get_service_installed("cachefilesd.service")
        self.nvsm_installed = self.get_service_installed("nvsm.service")
        # Remembered so safe_start_nvsm() only restarts nvsm if it was running.
        self.nvsm_was_active = self.get_service_active("nvsm.service")
        self.nvsm_is_parent = self.get_nvsm_is_parent()

    def run_configurator(self, argvs, msg):
        """Run the configuration; must be overridden by subclasses."""
        raise NotImplementedError("Can't implement abstract method")

    def get_cachefilesd_conf(self):
        """Return the cachefilesd configuration path for this distribution."""
        if self.is_redhat_like():
            return "/etc/cachefilesd.conf"

        return "/etc/default/cachefilesd"

    def get_service_installed(self, service_name):
        """Return True if systemd lists a unit file named *service_name*.

        Output of `systemctl list-unit-files`:

          Match:
            UNIT FILE           STATE
            cachefilesd.service enabled

            1 unit files listed.

          No match:
            UNIT FILE STATE

            0 unit files listed.

        Both cases exit 0, so the line count (more than 3) decides.
        """
        cmd = ("systemctl list-unit-files --type service " + service_name +
               " | wc -l")

        (_, output) = subprocess.getstatusoutput(cmd)
        # getstatusoutput() merges stderr into the output, so error text from
        # systemctl (e.g. when systemd is not running) can precede the count.
        # The count printed by wc is always the last line.
        lines = output.strip().splitlines()
        try:
            return int(lines[-1]) > 3
        except (IndexError, ValueError):
            return False

    def get_service_active(self, service_name):
        """Return True if `systemctl is-active` reports *service_name* active."""
        cmd = ("systemctl is-active " + service_name)
        (status, output) = subprocess.getstatusoutput(cmd)
        return (output == "active")

    def is_redhat_like(self):
        """Return True if /etc/os-release identifies a Red Hat family distro.

        Any line mentioning Red Hat/CentOS/Fedora counts, as does an ID_LIKE=
        line containing rhel/centos/fedora.  Returns False if the file
        cannot be read.
        """
        sps = SubPs()
        rc = sps.run_cmd(["cat", "/etc/os-release"])

        if rc != 0:
            return False

        for line in sps.output.split('\n'):
            if ((("Red Hat" in line) or ("CentOS" in line) or ("Fedora" in line)) or
                (line.startswith("ID_LIKE=") and ("rhel" in line or "centos" in line or "fedora" in line))):
                return True

        return False

    @staticmethod
    def _is_md_array_line(line, raid_name):
        """Return True if *line* is an uncommented mdadm ARRAY line for *raid_name*.

        The device field is compared exactly, so "md1" does not also match
        "/dev/md10" or "/dev/md127" as a plain substring test would.
        """
        fields = line.split()
        if len(fields) < 2 or fields[0] != "ARRAY":
            return False
        return fields[1] in (raid_name, "/dev/" + raid_name, "/dev/md/" + raid_name)

    def update_md_conf(self, raid_name):
        """Record the ARRAY line for *raid_name* in mdadm.conf, then rebuild initramfs.

        The ARRAY line comes from `mdadm --detail --scan`.  If no scanned
        array matches *raid_name*, nothing is changed.  The first existing
        ARRAY entry for the same device in mdadm.conf is replaced and later
        duplicates are dropped; if none exists the line is appended.  Finally
        the initramfs is regenerated (update-initramfs, or dracut on Red
        Hat-like systems).

        Args:
            raid_name: device name relative to /dev (e.g. "md1"; TODO confirm
                the naming callers use).

        Raises:
            subprocess.CalledProcessError: if `mdadm --detail --scan` fails.
        """
        scan = subprocess.check_output("mdadm --detail --scan", shell=True,
                                       encoding='UTF-8').split('\n')

        array_line = None
        for line in scan:
            if self._is_md_array_line(line, raid_name):
                array_line = line.rstrip()
                break

        if array_line is None:
            return

        mdadm_conf = "/etc/mdadm/mdadm.conf"
        initram_fs = "/usr/sbin/update-initramfs -u"
        if self.is_redhat_like():
            mdadm_conf = "/etc/mdadm.conf"

            sps = SubPs()
            rc = sps.run_cmd(["uname", "-r"])
            if rc != 0:
                return

            uname = sps.output.rstrip("\n")
            initram_fs = "dracut --force /boot/initramfs-" + uname + ".img " + uname

        # mdadm.conf may not exist yet, so create it with a minimal header.
        if not os.path.exists(mdadm_conf):
            with open(mdadm_conf, 'a+') as mdout:
                mdout.writelines(["#\n",
                                  "# This file was auto-created by raid_configurator.py\n",
                                  "#\n",
                                  "MAILADDR root\n"])

        # Replace the first ARRAY entry for this device; drop any duplicates
        # so the array is described exactly once.
        found = False
        with fileinput.FileInput(mdadm_conf, inplace=True) as conf:
            for line in conf:
                if not self._is_md_array_line(line, raid_name):
                    print(line.rstrip())
                elif not found:
                    # print() already adds the newline; no extra "\n" here.
                    print(array_line)
                    found = True

        # No ARRAY entry for this device existed: append one.
        if not found:
            with open(mdadm_conf, 'a') as mdout:
                mdout.write(array_line + "\n")

        print(initram_fs)

        sps = SubPs()
        rc = sps.run_cmd(initram_fs.split(' '))

        if rc != 0:
            print("Error updating initramfs")

    #
    # get_filesystem_uuid, update_fstab_future, disable_raid_automount_future:
    #
    # In the future, these 3 functions will be used to update the fstab with
    # the device's filesystem UUID.

    def get_filesystem_uuid(self, dev_name):
        """Return (True, uuid) for the filesystem on *dev_name*, else (False, "").

        The UUID reported by blkid is accepted only if
        /dev/disk/by-uuid/<uuid> resolves to the same device as *dev_name*.
        """
        sps = SubPs()
        rc = sps.run_cmd(["blkid", "-s", "UUID", "-o", "value", dev_name])

        if rc != 0:
            return False, ""

        try:
            # Resolve both sides so symlinked device names (e.g. /dev/md/x) match.
            target = os.path.realpath(dev_name)
            for uuid in sps.output.split('\n'):
                uuid = uuid.strip()
                if uuid and os.path.realpath("/dev/disk/by-uuid/" + uuid) == target:
                    return True, uuid
        except OSError:
            pass

        return False, ""

    @staticmethod
    def _same_mount_point(path, mount_point):
        """Return True if fstab field *path* names *mount_point* (trailing '/' ignored)."""
        return os.path.normpath(path) == os.path.normpath(mount_point)

    def update_fstab_future(self, raid_name, mount_point, fstype, fstab="/etc/fstab"):
        """Replace active fstab entries for *mount_point* with one for *raid_name*.

        The new entry uses the filesystem UUID when it can be resolved,
        falling back to /dev/<raid_name>.  Uncommented entries that mount on
        *mount_point* are removed.  Comments and all other lines are kept.
        The new entry is then appended.

        Args:
            raid_name: device name relative to /dev.
            mount_point: mount point written into (and matched in) fstab.
            fstype: filesystem type written into the new entry.
            fstab: fstab path to edit; defaults to /etc/fstab.
        """
        # REVISIT: need to check what the fsck order (last fstab field)
        # should be for the RAID-0 partition.
        has_uuid, uuid = self.get_filesystem_uuid("/dev/" + raid_name)
        fstab_dev = ("UUID=" + uuid) if has_uuid else ("/dev/" + raid_name)
        fstab_entry = "%s %s %s defaults,nofail,discard 0 0" % (fstab_dev, mount_point, fstype)

        with fileinput.FileInput(fstab, inplace=True) as lines:
            for line in lines:
                fields = line.split()
                stale = (len(fields) >= 2 and not self.is_comment(line)
                         and self._same_mount_point(fields[1], mount_point))
                if not stale:
                    print(line.rstrip())

        with open(fstab, "a") as out:
            out.write(fstab_entry + "\n")

    def disable_raid_automount_future(self, raid_name, fstab="/etc/fstab"):
        """Comment out active fstab entries whose device field names *raid_name*.

        The device is matched as UUID=<uuid> when blkid resolves one,
        otherwise as /dev/<raid_name>.  Afterwards runs
        `systemctl daemon-reload`.

        Raises:
            subprocess.CalledProcessError: if daemon-reload fails.
        """
        has_uuid, uuid = self.get_filesystem_uuid("/dev/" + raid_name)
        fstab_dev = ("UUID=" + uuid) if has_uuid else ("/dev/" + raid_name)

        with fileinput.FileInput(fstab, inplace=True) as lines:
            for line in lines:
                fields = line.split()
                # Exact field match (md1 != md10); skip lines already commented.
                if fields and fields[0] == fstab_dev and not self.is_comment(line):
                    print("#" + line.rstrip())
                else:
                    print(line.rstrip())

        subprocess.check_output("systemctl daemon-reload", shell=True, encoding='UTF-8')

    def is_comment(self, line):
        """Return True if *line* is a comment (optional spaces/tabs, then '#')."""
        return re.match(r'[ \t]*#', line) is not None

    def update_fstab(self, raid_name, mount_point, fstype, fstab="/etc/fstab"):
        """Make /dev/<raid_name> the only active fstab entry for *mount_point*.

        1. If an active entry mounting /dev/<raid_name> on *mount_point*
           exists, it is kept.  Otherwise the first commented-out one is
           replaced by a fresh active entry.  If neither exists, a fresh
           entry is appended.
        2. Every other active entry that mounts on *mount_point* is
           commented out.

        Args:
            raid_name: device name relative to /dev.
            mount_point: mount point path (e.g. self.mount_point).
            fstype: filesystem type written into a fresh entry.
            fstab: fstab path to edit; defaults to /etc/fstab.
        """
        # REVISIT: need to check what the fsck order (last fstab field)
        # should be for the RAID-0 partition.
        device = "/dev/" + raid_name
        entry = device + " " + mount_point + " " + fstype + " defaults,nofail,discard 0 0"

        # "<device> <mount_point>[/...] f3 f4 f5 f6", separated by spaces or
        # tabs.  re.search() also finds it inside a commented-out line.
        mount_dir = mount_point.rstrip("/") or "/"
        entry_re = re.compile(r"[ \t]*" + re.escape(device) + r"[ \t]+" +
                              re.escape(mount_dir) +
                              r"/*[ \t]+.+[ \t]+.+[ \t]+.+[ \t]+.+")

        with open(fstab) as current:
            has_active = any(entry_re.search(line) and not self.is_comment(line)
                             for line in current)

        found = False
        with fileinput.FileInput(fstab, inplace=True) as lines:
            for line in lines:
                if entry_re.search(line):
                    # Re-enable a commented entry only if no active one exists,
                    # and only once, so the mount is never listed twice.
                    if self.is_comment(line) and not (has_active or found):
                        print(entry)
                    else:
                        print(line.rstrip())
                    found = True
                else:
                    print(line.rstrip())

        if not found:
            with open(fstab, "a") as out:
                out.write(entry + "\n")

        # Comment out other active entries that mount on mount_point.
        with fileinput.FileInput(fstab, inplace=True) as lines:
            for line in lines:
                fields = line.split()  # any run of spaces/tabs
                if (len(fields) >= 2 and not self.is_comment(line)
                        and fields[0] != device
                        and self._same_mount_point(fields[1], mount_point)):
                    print("#" + line.rstrip())
                else:
                    print(line.rstrip())

    def disable_raid_automount(self, raid_name, fstab="/etc/fstab"):
        """Comment out active fstab entries whose device field is /dev/<raid_name>.

        Afterwards runs `systemctl daemon-reload`.

        Raises:
            subprocess.CalledProcessError: if daemon-reload fails.
        """
        device = "/dev/" + raid_name

        with fileinput.FileInput(fstab, inplace=True) as lines:
            for line in lines:
                fields = line.split()
                # Exact device match so md1 does not also disable md10;
                # lines that are already comments are left alone.
                if fields and fields[0] == device and not self.is_comment(line):
                    print("#" + line.rstrip())
                else:
                    print(line.rstrip())

        subprocess.check_output("systemctl daemon-reload", shell=True, encoding='UTF-8')

    def restart_fscache(self):
        """Set the RUN flag in the cachefilesd config and restart cachefilesd.

        A "#RUN=yes" line is rewritten as "RUN=yes".  If the config file
        cannot be rewritten, a message is printed and the restart is still
        attempted.

        Returns:
            True if `systemctl restart cachefilesd` succeeded, else False.
        """
        cachefile = self.get_cachefilesd_conf()

        try:
            # The with-block restores sys.stdout even if rewriting fails.
            with fileinput.FileInput(cachefile, inplace=True) as lines:
                for line in lines:
                    if "#RUN=yes" in line:
                        print("RUN=yes")
                    else:
                        print(line.rstrip())
        except OSError:
            # This file doesn't exist for CentOS/RHEL
            print("Could not open %s" % cachefile)

        return self.set_fscache("restart")

    def get_nvsm_is_parent(self):
        """Return True if this process or one of its ancestors is named *nvsm*.

        Follows the parent links in /proc/<pid>/stat, starting at this
        process.  The walk stops before PID 1 (init) or PID 0 (no visible
        parent).  A read or parse failure prints a warning and returns False.
        """
        pid = str(os.getpid())
        while True:
            try:
                with open("/proc/%s/stat" % pid, 'r') as procfile:
                    stat = procfile.read()
                # Field 2 is "(comm)" and comm may contain spaces or ')', so
                # the remaining fields start after the LAST ')'.
                name_end = stat.rindex(')')
                proc_name = stat[stat.index('(') + 1:name_end]
                parent_pid = stat[name_end + 1:].split()[1]
            except (OSError, ValueError, IndexError):
                print("WARNING: error iterating through process parents")
                return False

            if "nvsm" in proc_name:
                return True
            if parent_pid in ("0", "1"):
                return False
            pid = parent_pid

    def _may_manage_nvsm(self):
        """Return True if nvsm is installed, was active at start-up, and is not our ancestor.

        The ancestor check presumably avoids stopping the service that
        launched this script.
        """
        return (self.nvsm_installed and self.nvsm_was_active
                and not self.nvsm_is_parent)

    def safe_stop_nvsm(self):
        """Stop nvsm when _may_manage_nvsm() allows it."""
        if self._may_manage_nvsm():
            self.set_nvsm("stop")

    def safe_start_nvsm(self):
        """Start nvsm again when _may_manage_nvsm() allows it."""
        if self._may_manage_nvsm():
            self.set_nvsm("start")

    def set_fscache(self, action):
        """Apply systemctl *action* to cachefilesd; return True on success."""
        return self.set_service(action, "cachefilesd")

    def set_nvsm(self, action):
        """Apply systemctl *action* to nvsm; return True on success."""
        return self.set_service(action, "nvsm")

    def set_service(self, action, service_name):
        """Run `systemctl <action> <service_name>`; return True on success.

        Returns False if systemctl exits non-zero or cannot be launched.
        """
        try:
            subprocess.check_output("systemctl " + action + " " + service_name, shell=True, encoding='UTF-8')
        except (subprocess.CalledProcessError, OSError):
            return False

        return True

    def print_usage(self):
        """Print usage text; must be overridden by subclasses."""
        raise NotImplementedError("Can't implement abstract method")
