#!/usr/bin/python3

# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

#
# This module implements a platform abstraction driven by a manifest file. The
# abstraction uses the OS boot drive to determine the platform type, as opposed
# to the product name, which could change. (Note: several discovery methods in
# this module instead classify the platform by its short platform name or by
# RAID-configuration helpers from plat_funcs.bash.)
#
# The manifest contains an entry for each supported platform. Each entry is
# called a platform profile; it serves as a class template, and Python
# introspection is used to discover the platform type.
#
# A platform profile contains the following introspection fields:
#
# DiscoveryMethod    - name of the method that checks the system (e.g. the boot
#                      drive) to determine whether it is this platform type.
# AltDiscoveryMethod - optional fallback method, tried only when no profile's
#                      DiscoveryMethod matches.
# PlatClass          - name of the Python class (in the dgx_platform module) for
#                      this platform. The class is instantiated once the
#                      platform type has been determined.
#
# Qualifiers for the DGX-1 platform:
#
# 1. Boot drive is sda
# 2. Drive vendor is AVAGO (or LSI)
#
# Qualifiers for DGX Station:
#
# 1. Boot drive is sda
# 2. Drive vendor is ATA
#
# Qualifiers for Explorer:
#
# 1. Boot drive is NVMe
# 2. Open question: is a check for M.2 also needed?
#
import fileinput
import importlib
import json
import os
import shlex
import subprocess

from nv_subps import SubPs


class PlatManifest:
    """Discover the running platform using a JSON platform manifest.

    The manifest is expected to contain a ``"PlatformProfiles"`` list. Each
    profile names a discovery method of this class (``DiscoveryMethod`` and/or
    ``AltDiscoveryMethod``) and the class to instantiate from the
    ``dgx_platform`` module (``PlatClass``) when that method matches.
    """

    # Shell library providing get_platform_short and the
    # plat_uses_*_raid_config predicates. Overridable per instance.
    PLAT_FUNCS_PATH = "/usr/local/sbin/nv_scripts/plat_funcs.bash"

    # Root of the sysfs block-device tree used to read drive vendor strings.
    SYS_BLOCK_DIR = "/sys/class/block"

    def __init__(self, manifest_path):
        """Load the manifest and cache the product and short platform names.

        Args:
            manifest_path: path of the JSON manifest file.
        """
        # Use a context manager so the manifest file handle is always closed.
        with open(manifest_path) as manifest_file:
            self.plat_manifest = json.load(manifest_file)
        self.sys_prod_name = self.get_system_product_name()
        self.plat_short = self.get_plat_short()

    def get_platform_profile(self):
        """Return a platform instance for the running system, or None.

        Every profile's ``DiscoveryMethod`` is tried first. Only if none of
        them matches are the ``AltDiscoveryMethod`` entries tried, in order.
        Profiles that lack the key, or that name a method this object does
        not have, are skipped.
        """
        if "PlatformProfiles" not in self.plat_manifest:
            return None

        profile_list = self.plat_manifest["PlatformProfiles"]

        # Preferred discovery methods first; alternates only as a fallback.
        for method_key in ("DiscoveryMethod", "AltDiscoveryMethod"):
            for profile in profile_list:
                try:
                    method = getattr(self, profile[method_key])
                except (AttributeError, KeyError):
                    # No such key in this profile, or unknown method name.
                    continue

                if method(profile):
                    return self.platform_constructor(profile)

        return None

    def platform_constructor(self, profile):
        """Instantiate ``profile["PlatClass"]`` from the dgx_platform module.

        The class is constructed with no arguments.
        """
        module = importlib.import_module("dgx_platform")
        plat_class = getattr(module, profile["PlatClass"])
        return plat_class()

    def get_boot_device(self):
        """Return the source device mounted on /boot/efi, or "" if not found.

        Runs ``findmnt /boot/efi`` and returns the second column (the mount
        source) of the first output line mentioning /boot/efi.
        """
        boot_part = "/boot/efi"

        sps = SubPs()
        sps.run_cmd(["findmnt", boot_part])

        for line in sps.output.split('\n'):
            if boot_part not in line:
                continue

            # split() rather than split(' '): runs of padding spaces between
            # columns must not produce empty fields.
            fields = line.split()
            if len(fields) > 1:
                return fields[1]

        return ""

    def _run_plat_func(self, func_name, *args):
        """Source PLAT_FUNCS_PATH and call one of its shell functions.

        Every word is passed through shlex.quote, so arguments such as the
        system product name cannot inject shell syntax into the command.

        Args:
            func_name: shell function defined in PLAT_FUNCS_PATH.
            *args: string arguments passed to that function.

        Returns:
            The (exit_status, output) pair from subprocess.getstatusoutput.
        """
        words = [func_name, *args]
        cmd = ". %s && %s" % (
            shlex.quote(self.PLAT_FUNCS_PATH),
            " ".join(shlex.quote(word) for word in words),
        )
        # NOTE(review): getstatusoutput runs cmd through /bin/sh, so the
        # .bash library must be sourceable by that shell -- confirm this
        # holds on systems where /bin/sh is not bash.
        return subprocess.getstatusoutput(cmd)

    def get_plat_short(self):
        """Return the short platform name from get_platform_short, or "other".

        "other" is returned when the shell helper exits non-zero.
        """
        status, output = self._run_plat_func("get_platform_short",
                                             self.sys_prod_name)
        return output if status == 0 else "other"

    def get_system_product_name(self):
        """Return the DMI system product name, or "" if it cannot be read.

        Surrounding whitespace (e.g. a trailing newline) is stripped.
        """
        sps = SubPs()

        try:
            sps.run_cmd(["dmidecode", "--string", "system-product-name"])
            return sps.output.strip()
        except Exception:
            # Best effort: an unreadable product name maps to "".
            return ""

    def is_dgx_1(self, profile):
        """Return True when the short platform name is "dgx1"."""
        return self.plat_short == "dgx1"

    def is_explorer(self, profile):
        """Return True when the short platform name is "dgx2"."""
        return self.plat_short == "dgx2"

    def is_dcs(self, profile):
        """Return True for the "dcs" and "dcs_legacy" short platform names."""
        return self.plat_short in ("dcs", "dcs_legacy")

    def is_dcs_storage(self, profile):
        """Storage-based check for DCS; same result as is_dcs."""
        return self.is_dcs(profile)

    def is_dgx_a100(self, profile):
        """Return True when plat_uses_dgx_a100_raid_config exits with 0."""
        # Several platforms use the same RAID configuration here.
        status, _ = self._run_plat_func("plat_uses_dgx_a100_raid_config")
        return status == 0

    def is_dgx_a100_storage(self, profile):
        """Storage-based check for DGX A100; same result as is_dgx_a100."""
        return self.is_dgx_a100(profile)

    def is_dgx_station(self, profile):
        """Return True when the short platform name is "dgxstation"."""
        return self.plat_short == "dgxstation"

    def is_dgx_station2(self, profile):
        """Return True when plat_uses_dgxstation_a100_raid_config exits with 0."""
        # Several platforms use the same RAID configuration here.
        status, _ = self._run_plat_func(
            "plat_uses_dgxstation_a100_raid_config")
        return status == 0

    def is_dgx_station2_storage(self, profile):
        """Storage-based check for DGX Station 2; same as is_dgx_station2."""
        return self.is_dgx_station2(profile)

    def _boot_drive_vendor_matches(self, profile, vendor_tags):
        """Check the boot drive name and its sysfs vendor string.

        Returns True only if ``profile["BootDevice"]`` occurs in the boot
        device path and the drive's vendor string contains any of
        ``vendor_tags``. Returns False if the profile has no BootDevice or the
        vendor file cannot be read.
        """
        boot_dev_name = profile.get("BootDevice")
        if not boot_dev_name or boot_dev_name not in self.get_boot_device():
            return False

        vendor_path = os.path.join(self.SYS_BLOCK_DIR, boot_dev_name,
                                   "device", "vendor")
        try:
            # The with-block guarantees the file is closed (the previous
            # fileinput.close() call did not close the FileInput instance).
            with open(vendor_path, errors="replace") as vendor_file:
                vendor = vendor_file.read()
        except OSError:
            # No readable vendor attribute: treat as "not this platform".
            return False

        return any(tag in vendor for tag in vendor_tags)

    def is_dgx_1_storage(self, profile):
        """Return True if the boot drive matches and its vendor is AVAGO or LSI."""
        return self._boot_drive_vendor_matches(profile, ("AVAGO", "LSI"))

    def is_explorer_storage(self, profile):
        """Return True if the boot device path contains "nvme"."""
        return "nvme" in self.get_boot_device()

    def is_dgx_station_storage(self, profile):
        """Return True if the boot drive matches and its vendor contains ATA."""
        return self._boot_drive_vendor_matches(profile, ("ATA",))
