#!/usr/bin/python3

# Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# Provides a list of physical block devices installed in
# the system. A block device can be a SATA SSD or NVMe drive.

import subprocess
import string
import json
import os
from nv_subps import SubPs
from self_encrypt import Level0Discovery

# Module-level registry of discovered BlockDev instances.
# Populated by create_block_dev() (called from init_stor()) and
# exposed to callers via get_block_devs().
block_devs = []

class BlockDev:
    """
    Abstract representation of a block device from
    which SATA SSD and NVMe block inherit.

    Subclasses must implement get_vendor_info().
    """

    def __init__(self):
        self.dev_node_name = None       # full node path, e.g. "/dev/sda"
        self.raid_member = False        # True when listed in /proc/mdstat
        self.boot_device = False        # True when a partition mounts /boot or /
        self.model = None
        self.serial = None
        self.fw_rev = None
        self.partitions = []            # child partition node names, e.g. "sda1"
        self.node_name = None           # bare node name, e.g. "sda"
        # NOTE(review): sed_supported is never assigned anywhere in this
        # file — presumably it was meant to be set in
        # check_self_encrypt_cap(); confirm intended behavior.
        self.sed_supported = False
        self.locking_supported = False
        self.locking_enabled = False
        self.drive_locked = False

    def init_block_dev(self, node_name, df, vi):
        """
        Initialize from a bare node name ("sda"), a device-files dict
        and a vendor-info dict with "model", "serial" and "fw-rev" keys.
        """
        self.node_name = node_name
        self.dev_node_name = "/dev/" + node_name
        self.device_files = df
        self.model = vi["model"].strip()
        self.serial = vi["serial"].strip()
        self.fw_rev = vi["fw-rev"].strip()

    def check_self_encrypt_cap(self, l0_disc_buf):
        """
        Parse a TCG Level 0 discovery response buffer and record the
        drive's locking capabilities.
        """
        l0_disc = Level0Discovery()

        l0_disc.process_disc_resp_data(l0_disc_buf)

        if not l0_disc.is_sed_aware():
            return

        self.locking_supported = bool(l0_disc.locking_supported)
        self.locking_enabled = bool(l0_disc.locking_enabled)
        self.drive_locked = bool(l0_disc.locked)

    def get_disk_size(self):
        """Return the device size read from sysfs (0 on any failure)."""
        sps = SubPs()
        rc = sps.run_cmd(["cat", "/sys/class/block/" + self.node_name + "/size"])

        if rc != 0:
            return 0

        # Bug fix: the bare 'except: pass' also swallowed SystemExit and
        # KeyboardInterrupt; only conversion failures are expected here.
        try:
            return int(sps.output)
        except (ValueError, TypeError):
            return 0

    def get_controller_name(self):
        """Default: no separate controller node (overridden for NVMe)."""
        return ""

    def get_model(self):
        return self.model

    def get_serial(self):
        return self.serial

    def get_device_files(self):
        return self.device_files

    def get_vendor_info(self, dev_name):
        """Subclasses must return a dict with model/serial/fw-rev keys."""
        raise NotImplementedError("Can't implement abstract method")

    def get_dev_node_name(self):
        return self.dev_node_name

    def is_boot_device(self):
        return self.boot_device

    def has_partition(self):
        return len(self.partitions) > 0

    def is_raid_member(self):
        return self.raid_member

    def remove_raid_member(self):
        self.raid_member = False

    def check_raid_member(self):
        """
        Parse /proc/mdstat to figure if this block device
        is a member of a RAID group.
        """
        mdstat = []
        rc = issue_cmd("cat /proc/mdstat", mdstat)

        if rc == 0:
            dev = self.dev_node_name[5:]
            for line in mdstat:
                if dev in line:
                    self.raid_member = True
                    return

    def add_partitions(self):
        """
        Run lsblk on this device and record its partitions.

        Marks the device as the boot device when a partition is mounted
        under /boot or at /.
        """
        out = []
        rc = issue_cmd("/bin/lsblk -l " + self.dev_node_name, out)

        if rc != 0:
            return

        name = self.dev_node_name[5:]
        for line in out:
            devs = line.split(' ')

            #
            # A partition line names a child of this device ("sda1" for
            # "sda"), so it contains the parent name without matching
            # it exactly.
            #
            if devs[0] != name and name in line:
                # Bug fix: the original compared the whole line to "/"
                # ('"/" == i'), which can never be true here because the
                # line always contains the device name.  The mountpoint
                # is the last whitespace-separated field of lsblk output.
                fields = line.split()
                if ("boot" in line) or (fields and fields[-1] == "/"):
                    self.boot_device = True

                # Add partition
                self.partitions.append(devs[0])

class SCSIBlockDev(BlockDev):
    """
    SCSI block device representation
    """

    def __init__(self):
        BlockDev.__init__(self)

    def init_block_dev(self, node_name, df, vi):
        BlockDev.init_block_dev(self, node_name, df, vi)

    def get_vendor_info(self, dev_name):
        """
        Fetch model and serial number for a SCSI disk from sysfs.

        Returns a dict with "model", "serial" and "fw-rev" keys; values
        default to '' when the information cannot be read.
        """
        vi = {"model": '', "serial": '', 'fw-rev': ''}

        model = []
        rc = issue_cmd("cat /sys/class/block/" + dev_name + "/device/model",
                       model)
        if rc == 0:
            vi["model"] = model[0]

        #
        # SCSI generic driver stores the drive's VPD Page 0x80 in sysfs.
        # The page contains the device serial number at offset 4. All we
        # need is fetch this page from sysfs to get the drive serial number
        vpd80_path = "/sys/class/block/" + dev_name + "/device/vpd_pg80"

        if not os.path.exists(vpd80_path):
            return vi

        try:
            # NOTE(review): SubPs(None) here vs SubPs() elsewhere in this
            # file — confirm both constructor forms are equivalent.
            sps = SubPs(None)
            rc = sps.run_cmd(["cat", vpd80_path])

            if rc != 0:
                return vi

            # Skip the 4-byte VPD page header to reach the serial number.
            vi["serial"] = sps.output.strip()[4:]

        # Bug fix: a bare 'except:' also swallowed SystemExit and
        # KeyboardInterrupt; catch ordinary exceptions only and fall
        # through with whatever info was gathered.
        except Exception:
            pass

        return vi

class VirtIOBlockDev(BlockDev):
    """
    VirtIO block device representation.
    """

    def __init__(self):
        super().__init__()

    def init_block_dev(self, node_name, df, vi):
        super().init_block_dev(node_name, df, vi)

    def get_vendor_info(self, dev_name):
        """VirtIO disks expose no vendor data; return fixed placeholders."""
        return {
            "model": "VirtIO Disk",
            "serial": "NA",
            "fw-rev": "NA",
        }

class NVMeBlockDev(BlockDev):
    # Bug fix: the class docstring was placed after SECURITY_RECV, so it
    # was a discarded string expression rather than __doc__; it must be
    # the first statement in the class body.
    """
    NVMe block device representation
    """

    # NVMe admin opcode for Security Receive (used for the TCG Level 0
    # discovery request issued in init_block_dev()).
    SECURITY_RECV       = 0x82

    def __init__(self):
        BlockDev.__init__(self)

    def init_block_dev(self, node_name, df, vi):
        BlockDev.init_block_dev(self, node_name, df, vi)

        #
        # Send security receive command (TCG protocol 1, ComID 1) to
        # determine if target supports self-encryption.
        protocol = 1
        comID = 1
        cdw10 = protocol << 24 | comID << 8
        buflen = 800

        cmd = "nvme admin-passthru -o 0x%x -4 0x%x -5 %d -l %d %s -r -b" \
            %(NVMeBlockDev.SECURITY_RECV, cdw10, buflen, buflen, self.get_dev_node_name())

        sps = SubPs()
        rc = sps.run_cmd(cmd.split(' '))

        if rc == 0:
            self.check_self_encrypt_cap(sps.output.encode('utf-8'))

    def is_drive_unhealthy(self):
        """
        Return True when 'nvme smart-log' reports a non-zero
        critical_warning for this drive, False otherwise (including on
        command failure or unparsable output).
        """
        ctrl_name = "/dev/" + self.node_name
        # Check if the drive has any critical smart errors
        sps = SubPs()
        nvme_cmd = "nvme smart-log " + ctrl_name  + " -o json"
        rc = sps.run_cmd(nvme_cmd.split(' '))
        if rc != 0:
            return False

        # [bug 200698602]: "nvme smart-log" could return invalid
        # json output (nvme-cli-1.8.1-3.el7.x86_64).
        #
        # example:
        #
        #  "media_errors" : nan,
        #  "num_err_log_entries" : nan,
        try:
            val = json.loads(sps.output)
            # With nvme-cli >= 2.11 critical_warning became a nested
            # json object ({"value": ...}); older versions report a
            # plain key/value int.  Handle both formats.
            warning = val["critical_warning"]
            if (isinstance(warning, int) and warning != 0) or \
               (isinstance(warning, dict) and warning.get("value", 0) != 0):
                print("Drive excluded :" + ctrl_name)
                return True
        # Bug fix: narrow the bare 'except:' to the failures json
        # parsing and key lookup can actually produce.
        except (ValueError, KeyError, TypeError):
            return False
        return False

    def get_controller_name(self):
        ''' Return nvme controler name '''

        # NOTE(review): str.replace removes every "n1" occurrence, which
        # assumes the node name only contains "n1" as the namespace
        # suffix (e.g. "nvme0n1" -> "nvme0") — confirm for
        # multi-namespace drives.
        ctrl_name = self.node_name.replace("n1", "")

        return ctrl_name

    def get_vendor_info(self, dev_name):
        """
        Fetch model and serial number for an NVMe disk from sysfs.

        Bug fix: the original ignored the issue_cmd() return code and
        indexed the output lists unconditionally, raising IndexError
        when the sysfs read failed.
        """
        vi = { "model" : '', "serial": '', 'fw-rev': '' }

        model = []
        rc = issue_cmd("cat /sys/class/block/" + dev_name + "/device/model",
                       model)
        if rc == 0 and model:
            vi["model"] = model[0].strip()

        serial = []
        rc = issue_cmd("cat /sys/class/block/" + dev_name + "/device/serial",
                       serial)
        if rc == 0 and serial:
            vi["serial"] = serial[0].strip()

        return vi

def create_block_dev(dev_name, df):
    """
    Create a block device object for dev_name, register it in the
    global list block_devs, and return it.

    Returns None for unrecognized device types.
    """
    if "nvme" in dev_name:
        bd = NVMeBlockDev()
    elif "sd" in dev_name:
        bd = SCSIBlockDev()
    elif "vd" in dev_name:
        bd = VirtIOBlockDev()
    else:
        # Return if we encounter unrecognized device types. Otherwise we will
        # backtrace when calling get_vendor_info() below.
        return None

    vi = bd.get_vendor_info(dev_name)

    # Bug fix: dropped the dead 'if bd != None' guard — bd is always a
    # valid object here because the unrecognized case returned above.
    bd.init_block_dev(dev_name, df, vi)
    bd.add_partitions()
    bd.check_raid_member()
    block_devs.append(bd)

    return bd

def print_block_devs():
    """Print a one-line summary of every discovered block device."""
    for dev in block_devs:
        print("Model: ", dev.get_model(),
              "  Device: ", dev.get_dev_node_name(),
              "Has partition:", dev.has_partition(),
              "  Boot: ", dev.is_boot_device(),
              "raid member", dev.is_raid_member())

def get_block_devs():
    """Return the module-level list of discovered block devices."""
    return block_devs

def init_stor():
    """
    Discover all block devices in the system and register each one
    in the global block_devs list.
    """
    names = []
    enum_block_devs(names)

    for dev_name in names:
        device_files = {"by-id": '', "by-wwn": '', "by-path": ''}
        get_device_files(dev_name, device_files)
        create_block_dev(dev_name, device_files)

def reinit_stor():
    """
    Drop all previously discovered devices and rescan the system.

    The list is emptied in place (rather than rebound) so that callers
    holding a reference via get_block_devs() see the refreshed contents
    — same effect as the original reverse index-deletion loop, without
    the manual index bookkeeping.
    """
    block_devs.clear()

    init_stor()

def mkfs(dev_node, fs_type):
    '''
    Create a filesystem on dev_node.

    Support only ext4 for now. This will change if we support other
    fstypes. Returns True on success, False on failure or for an
    unsupported fs_type.  (Bug fix: removed the duplicated, dead
    'rc = True' assignments.)
    '''
    if fs_type != "ext4":
        return False

    sps = SubPs()
    # -F forces mkfs.ext4 to proceed without interactive confirmation.
    ret = sps.run_cmd(["mkfs.ext4", "-F", dev_node])

    return ret == 0

def issue_cmd(cmd, output):
    """
    Issue a system command (space-separated string) and append its
    stdout lines to the *output* list (ignored when output is None).

    Returns the command's exit status (0 on success).
    Bug fix: use 'is None' instead of '== None', skip the output split
    entirely when the caller passed None, and extend the list instead
    of appending line by line.
    """
    sps = SubPs()

    rc = sps.run_cmd(cmd.split(' '))

    if rc != 0:
        return rc

    if output is not None:
        output.extend(sps.output.split('\n'))

    return 0

def find_device_attr(dev_name, by_attr_name):
    """
    Find the attribute name (symlink under /dev/disk/<by_attr_name>,
    e.g. "by-id" or "by-path") that resolves to dev_name.

    Returns the attribute name, or "" when no symlink matches.
    Bug fix: removed the dead, never-read 'done = False' variable.
    """
    full_path = "/dev/disk/" + by_attr_name

    for _, _, attr_list in os.walk(full_path):
        for attr in attr_list:
            #
            # A disk id is just a symlink to the device's node name.
            real_path = os.path.realpath(full_path + "/" + attr)
            if real_path == dev_name:
                return attr

    return ""

def get_device_files(dev_name, df):
    """
    Fill df["by-id"] and df["by-path"] with the /dev/disk symlink names
    that resolve to /dev/<dev_name> ('' when not found).
    """
    dev_path = "/dev/" + dev_name
    for attr in ("by-id", "by-path"):
        df[attr] = find_device_attr(dev_path, attr)

def is_removable_device(dev_node_name):
    """
    Return the sysfs "removable" flag for the device (0 or 1),
    or -1 when it cannot be read or parsed.

    Bug fix: the original leaked the open file (and let the exception
    escape) when int() raised ValueError on malformed sysfs content;
    use a context manager and catch the parse failure too.
    """
    try:
        with open("/sys/class/block/" + dev_node_name + "/removable", "r") as f:
            return int(f.read())
    except (IOError, ValueError):
        return -1

def is_usb_device(dev_transport):
    """Return True when the reported transport type is USB."""
    return dev_transport == "usb"

def enum_block_devs(block_devs):
    """
    Enumerate disk-type block devices via lsblk and append their node
    names to *block_devs* (a list supplied by the caller).

    USB and removable devices are skipped, unless running inside
    VirtualBox.  Bug fix: removed the dead 'dev_type' variable (the
    TYPE column was parsed but never used, and indexing it could raise
    IndexError on short lines).
    """
    cmd = "/bin/lsblk -do name,type,tran"

    sps = SubPs()
    rc = sps.run_cmd(cmd.split(' '))

    if rc != 0:
        return

    for line in sps.output.split('\n'):
        if "disk" not in line:
            continue

        fields = line.split()

        dev_name = fields[0]

        # The TRAN column may be empty, leaving fewer than 3 fields.
        dev_transport = fields[2] if len(fields) > 2 else "unknown"

        # NOTE(review): is_removable_device() returns -1 on error, which
        # is truthy and therefore also skips the device — presumably a
        # deliberate conservative choice; confirm.
        if (is_usb_device(dev_transport) or
            is_removable_device(dev_name)):
            if not is_virtual_box():
                continue

        block_devs.append(dev_name)

def get_partuuid_from_device(device_name):
    """
    Return the PARTUUID reported by blkid for the given device
    (bare name or /dev/ path), or "" when it cannot be determined.
    """
    # Normalize to a full /dev/ path.
    if not device_name.startswith('/dev/'):
        device_name = '/dev/' + device_name

    lines = []
    rc = issue_cmd("blkid -s PARTUUID -o value " + device_name, lines)

    if rc != 0 or not lines:
        return ""

    # The PARTUUID is the first output line.
    return lines[0].strip()

def is_virtual_box():
    """
    Return True when dmidecode reports the system product name as
    "VirtualBox", False otherwise (including on dmidecode failure).

    Bug fix: removed the misspelled, never-read 'manfacturer' variable
    and the Manufacturer parsing — only the product name was ever used.
    """
    output = []
    rc = issue_cmd("dmidecode --type 1", output)

    if rc != 0:
        return False

    product_name = None

    for line in output:
        if "Product Name:" in line:
            product_name = line.split(':')[1].strip()
            break

    return product_name == "VirtualBox"


def test():
    """Discover devices, print them, and exercise PARTUUID retrieval."""
    init_stor()
    print_block_devs()

    # Query the PARTUUID of each partition (or of the whole device
    # when it has no partitions).
    print("\nTesting PARTUUID retrieval:")
    for bd in block_devs:
        if bd.has_partition():
            for part in bd.partitions:
                print(f"Device: {part}, PARTUUID: {get_partuuid_from_device(part)}")
        else:
            node = bd.get_dev_node_name()
            print(f"Device: {node}, PARTUUID: {get_partuuid_from_device(node)}")

