# -*- coding: utf-8; Mode: Python; indent-tabs-mode: nil; tab-width: 4 -*-
#
# «crypt-passwd» - Add new encryption key, and get rid of the default one
#
# Copyright (C) 2020 NVIDIA Corp
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

from __future__ import print_function

import os
import subprocess
import re
import debconf
import sys

from ubiquity import misc, plugin, validation


# Plugin identifier; also used as the prefix of this module's log messages.
NAME = 'nv-crypt-passwd'
# Ubiquity plugin ordering metadata (presumably: place this page after the
# nv-config-bmc plugin -- semantics are defined by ubiquity's plugin loader).
AFTER = 'nv-config-bmc'
# Ubiquity plugin weight metadata (interpreted by ubiquity, not this file).
WEIGHT = 20

def FrontendIsUbiquity():
    """Tell whether the page is being driven by ubiquity's debconf_ui frontend.

    Returns True only when UBIQUITY_FRONTEND is set to exactly 'debconf_ui';
    an unset variable or any other value yields False.
    """
    frontend = os.environ.get('UBIQUITY_FRONTEND')
    return frontend == 'debconf_ui'

def make_error_string(controller, errors):
    """Translate each error key via the controller and join them with newlines.

    :param controller: object providing ``get_string(key)``.
    :param errors: iterable of string keys to translate.
    :return: the translated strings, one per line ('' when *errors* is empty).
    """
    translated_lines = []
    for error_key in errors:
        translated_lines.append(controller.get_string(error_key))
    return "\n".join(translated_lines)

def setCryptPassword(device, password):
    tmpfile = "/tmp/tmpfile"
    defpw = "nvidia3d"
    defslot = "0"

    # Sanity check parameters
    if not device or not password:
        sys.stderr.write("Invalid function parameters\n")
        return False

    # Test the default passphrase before proceeding
    (status, output) = subprocess.getstatusoutput(
        "printf \"%s\" | /usr/sbin/cryptsetup luksOpen --test-passphrase " \
        "--key-slot %s %s" % \
        (defpw, defslot, device))

    if status != 0:
        sys.stderr.write("Test of default password failed\n")
        if os.path.exists(tmpfile):
            os.remove(tmpfile)
        return False

    # Set the user-provided password
    with open(tmpfile, "w") as f:
        f.write(password)

    # Add user-specified password to crypt
    (status, output) = subprocess.getstatusoutput(
        "printf \"%s\" | /usr/sbin/cryptsetup luksAddKey %s %s" % \
        (defpw, device, tmpfile))

    if status != 0:
        sys.stderr.write("Failed to add new password\n")
        if os.path.exists(tmpfile):
            os.remove(tmpfile)
        return False

    # Delete the default password from slot 0
    (status, output) = subprocess.getstatusoutput(
        "printf \"%s\" | /usr/sbin/cryptsetup luksKillSlot %s %s" % \
        (password, device, defslot))

    if status != 0:
        sys.stderr.write("Failed to remove default password\n")
        if os.path.exists(tmpfile):
            os.remove(tmpfile)
        return False

    if os.path.exists(tmpfile):
        os.remove(tmpfile)
    return True

def find_crypt_root():
    """Find the physical block device underneath the root filesystem.

    The SOURCE of "/" (from findmnt) is resolved to its kernel block-device
    name, e.g. /dev/mapper/root -> dm-0. The function then looks for a
    device under /sys/class/block whose holders/ directory lists that name.

    Returns:
        (result, physroot): (True, "/dev/<name>") when a backing device is
        found; (False, "") otherwise. If several devices match, the last one
        listed wins, as before.
    """
    result = False
    physroot = ""
    fmnt_cmd = "findmnt -fn -o SOURCE /"

    try:
        (status, output) = subprocess.getstatusoutput(fmnt_cmd)
    except Exception as e:
        sys.stderr.write("%s\n" % e)
        return (result, physroot)

    # Check findmnt's status *before* using its output as a path.
    if status != 0 or not output:
        sys.stderr.write("Failed to set dm_link\n")
        return (result, physroot)

    # realpath() follows the /dev/mapper symlink. It also accepts a SOURCE
    # that is already a plain device node, where readlink() would raise.
    dm_base = os.path.basename(os.path.realpath(output))
    if not dm_base:
        # An empty name would make "holders/" match every device.
        sys.stderr.write("Failed to set dm_link\n")
        return (result, physroot)

    try:
        block_devs = os.listdir("/sys/class/block/")
    except OSError as e:
        sys.stderr.write("%s\n" % e)
        return (result, physroot)

    for dev in block_devs:
        fullpath = os.path.join("/sys/class/block", dev, "holders", dm_base)
        if os.path.exists(fullpath):
            physroot = "/dev/" + dev
            result = True

    sys.stdout.write("Result is %s, physroot is %s\n" % (result, physroot))
    return (result, physroot)

def dev_has_luks_header(cryptdev):
    """Report whether ``cryptsetup luksDump`` accepts *cryptdev*.

    Only the command's exit status is examined. A one-line diagnostic goes
    to stdout on success and to stderr on failure.

    :param cryptdev: device path (str) appended to the command line.
    :return: True when luksDump exits 0, False otherwise.
    """
    exit_code = subprocess.getstatusoutput(
        "/usr/sbin/cryptsetup luksDump " + cryptdev)[0]

    has_header = exit_code == 0
    if has_header:
        sys.stdout.write("LUKS header found\n")
    else:
        sys.stderr.write("Could not find LUKS header\n")
    return has_header

def crypt_is_using_default_password(cryptdev):
    """Check whether key slot 0 of *cryptdev* still opens with the default passphrase.

    The default passphrase is piped through the shell into
    ``cryptsetup luksOpen --test-passphrase --key-slot 0``; only the exit
    status matters.

    :param cryptdev: device path (str) appended to the command line.
    :return: True if the test passphrase succeeded, False otherwise.
    """
    default_pw = "nvidia3d"
    default_slot = "0"
    pipeline = " ".join([
        'printf "' + default_pw + '"', "|",
        "/usr/sbin/cryptsetup", "luksOpen",
        "--test-passphrase", "--key-slot", default_slot, cryptdev,
    ])

    exit_code = subprocess.getstatusoutput(pipeline)[0]
    if exit_code == 0:
        sys.stdout.write("Found default password in slot %s\n" % (default_slot))
        return True

    sys.stderr.write("Not using default password\n")
    return False

def get_min_passwd_len():
    """Return the minimum number of characters required for the passphrase."""
    minimum_length = 8  # fixed floor: at least 8 characters, always
    return minimum_length

class PageBase(plugin.PluginUI):
    """Frontend-independent interface for the crypt passphrase page.

    Concrete frontends override the accessor and error hooks below. Here the
    accessors raise NotImplementedError and the remaining hooks are no-ops.
    """

    def __init__(self):
        self.suffix = misc.dmimodel()
        # Handed to the password validator; empty passphrases are rejected.
        self.allow_password_empty = False

    # --- values every frontend must be able to supply ---------------------

    def get_device(self):
        """Return the LUKS device whose passphrase is being replaced."""
        raise NotImplementedError('get_device')

    def get_password(self):
        """Return the passphrase entered by the user."""
        raise NotImplementedError('get_password')

    def get_verified_password(self):
        """Return the confirmation copy of the passphrase."""
        raise NotImplementedError('get_verified_password')

    # --- error reporting and validation hooks -----------------------------

    def password_error(self, msg):
        """Show *msg* as the reason the chosen passphrase was rejected."""
        raise NotImplementedError('password_error')

    def clear_errors(self):
        """Hide any previously shown error; does nothing by default."""

    def info_loop(self, *args):
        """Re-validate the user's input; does nothing by default."""

class PageGtk(PageBase):
    """GTK frontend for the crypt passphrase page.

    The widgets are only built when the root filesystem sits on a LUKS
    device whose key slot 0 still accepts the default passphrase. Otherwise
    __init__ returns early and no widgets (plugin_widgets, page, ...) exist.
    """

    plugin_title = 'partman-crypto/text/specify_keytype'

    def __init__(self, controller, *args, **kwargs):
        self.physroot = ""
        # Set unconditionally so the attribute exists even when one of the
        # early returns below skips building the page.
        self.controller = controller

        # Early return so that we don't render this page if any of
        # the following are true:
        #   1) We fail to find the root device
        #   2) Crypt doesn't have a LUKS header
        #   3) The default password is not being used in slot 0
        # NOTE(review): after an early return the widget attributes used by
        # info_loop()/get_password() do not exist; this relies on ubiquity
        # not driving a page without plugin_widgets -- confirm.
        (result, self.physroot) = find_crypt_root()
        if not result or not self.physroot:
            sys.stdout.write("%s: PageGtk early return\n" % (NAME))
            return

        if not dev_has_luks_header(self.physroot):
            sys.stdout.write("%s: PageGtk early return\n" % (NAME))
            return

        if not crypt_is_using_default_password(self.physroot):
            sys.stdout.write("%s: PageGtk early return\n" % (NAME))
            return

        from gi.repository import Gio, Gtk

        PageBase.__init__(self, *args, **kwargs)
        # NOTE(review): resolver/resolver_ok are not used anywhere in this
        # file (apparently inherited from the user-setup page).
        self.resolver = Gio.Resolver.get_default()

        builder = Gtk.Builder()
        self.controller.add_builder(builder)
        builder.add_from_file(os.path.join(
            os.environ['UBIQUITY_GLADE'], 'stepNVcryptpasswd.ui'))
        builder.connect_signals(self)
        self.page = builder.get_object('stepNVcryptpasswd')
        self.password = builder.get_object('password')
        self.verified_password = builder.get_object('verified_password')
        self.password_error_label = builder.get_object('password_error_label')

        self.password_ok = builder.get_object('password_ok')
        self.password_strength = builder.get_object('password_strength')
        self.password_min_length = get_min_passwd_len()

        # Keep the page contents centred without them shifting as the
        # password-ok icon and error label appear and disappear. Once
        # realized, the padding box stops expanding and requests half of its
        # allocated width.
        paddingbox = builder.get_object('paddingbox')

        def func(box):
            box.get_parent().child_set_property(box, 'expand', False)
            # Floor division: set_size_request() takes integer pixel sizes;
            # "/" would pass a float on Python 3.
            box.set_size_request(box.get_allocation().width // 2, -1)

        paddingbox.connect('realize', func)

        # NOTE(review): requests apt-install of the oem-config GTK packages;
        # carried over from ubiquity's user-setup page -- confirm still needed.
        misc.execute_root('apt-install', 'oem-config-gtk', 'oem-config-slideshow-ubuntu')

        self.resolver_ok = True
        self.plugin_widgets = self.page

    # Functions called by the Page.
    def get_device(self):
        """Return the physical LUKS device found during __init__."""
        return self.physroot

    def get_password(self):
        """Return the text of the passphrase entry."""
        return self.password.get_text()

    def get_verified_password(self):
        """Return the text of the confirmation entry."""
        return self.verified_password.get_text()

    def password_error(self, msg):
        """Hide the strength meter and show *msg* in dark red."""
        self.password_strength.hide()
        m = '<small><span foreground="darkred"><b>%s</b></span></small>' % msg
        self.password_error_label.set_markup(m)
        self.password_error_label.show()

    def clear_errors(self):
        """Hide the password error label."""
        self.password_error_label.hide()

    # Callback functions.

    def info_loop(self, widget):
        """Validate the passphrase entries and enable Forward only if valid.

        Callback defined in the ui file; Page.prepare also calls it with
        widget=None so the Forward button starts out disabled.
        """
        self.password_ok.hide()
        self.clear_errors()

        # Check password length first
        if len(self.get_password()) < self.password_min_length:
            self.password_error(
                "Shorter than %s characters" % self.password_min_length)
            password_ok = False
        else:
            password_ok = validation.gtk_password_validate(
                self.controller,
                self.password,
                self.verified_password,
                self.password_ok,
                self.password_error_label,
                self.password_strength,
                self.allow_password_empty,
            )

        self.controller.allow_go_forward(password_ok)


class PageKde(PageBase):
    """Qt/KDE frontend for the crypt passphrase page."""

    plugin_breadcrumb = 'ubiquity/text/breadcrumb_user'

    def __init__(self, controller, *args, **kwargs):
        PageBase.__init__(self, *args, **kwargs)
        self.controller = controller

        from PyQt5 import uic
        from PyQt5.QtGui import QPixmap

        self.plugin_widgets = uic.loadUi(
            '/usr/share/ubiquity/qt/stepNVcryptpasswd.ui')
        self.page = self.plugin_widgets

        if self.controller.oem_config:
            self.page.login_pass.hide()

        # NOTE(review): requests apt-install of oem-config-kde; carried over
        # from ubiquity's user-setup page -- confirm still needed.
        misc.execute_root('apt-install', 'oem-config-kde')

        warningIcon = QPixmap(
            "/usr/share/icons/oxygen/48x48/status/dialog-warning.png")
        self.page.password_error_image.setPixmap(warningIcon)

        self.clear_errors()

        # login_pass / login_auto are deliberately not connected to any
        # handler. This class defines none, and connecting them to the
        # undefined on_login_*_clicked methods raised AttributeError here.

        self.page.password_debug_warning_label.setVisible(
            'UBIQUITY_DEBUG' in os.environ)

    def on_password_changed(self):
        """Placeholder; not connected to any signal."""
        pass

    def on_verified_password_changed(self):
        """Placeholder; not connected to any signal."""
        pass

    def get_device(self):
        """Return the crypt device.

        NOTE(review): always None here, unlike PageGtk which discovers it with
        find_crypt_root(); Page.ok_handler preseeds whatever this returns.
        """
        return None

    def get_password(self):
        """Return the text of the passphrase field."""
        return str(self.page.password.text())

    def get_verified_password(self):
        """Return the text of the confirmation field."""
        return str(self.page.verified_password.text())

    def password_error(self, msg):
        """Show *msg* next to the warning icon."""
        self.page.password_error_reason.setText(msg)
        self.page.password_error_image.show()
        self.page.password_error_reason.show()

    def clear_errors(self):
        """Hide the warning icon and error text."""
        self.page.password_error_image.hide()
        self.page.password_error_reason.hide()


class PageDebconf(PageBase):
    """Debconf frontend: builds no widgets and only keeps the controller."""

    plugin_title = 'partman-crypto/text/specify_keytype'

    def __init__(self, controller, *args, **kwargs):
        # PageBase.__init__ is not invoked, so suffix and
        # allow_password_empty are not set on instances of this class.
        self.controller = controller


class PageNoninteractive(PageBase):
    """Console frontend: values come from debconf, errors are re-prompted.

    After password_error() re-prompts on the console, the newly typed values
    take precedence over the debconf database in the getters below.
    """

    def __init__(self, controller, *args, **kwargs):
        PageBase.__init__(self, *args, **kwargs)
        self.controller = controller
        # Overrides for the debconf values; empty until set (passwords are
        # set by password_error()).
        self.password = ''
        self.verifiedpassword = ''
        self.device = ''
        self.console = self.controller._wizard.console

    def get_device(self):
        """Get the crypt device: the explicit override, else debconf.

        Without this override PageBase.get_device() raised
        NotImplementedError from Page.ok_handler.
        """
        if self.device:
            return self.device
        return self.controller.dbfilter.db.get('cryptpasswd/device')

    def get_password(self):
        """Get the user's password (re-entered value first, else debconf)."""
        if self.password:
            return self.password
        return self.controller.dbfilter.db.get('cryptpasswd/crypt-password')

    def get_verified_password(self):
        """Get the user's password confirmation (re-entered first, else debconf)."""
        if self.verifiedpassword:
            return self.verifiedpassword
        return self.controller.dbfilter.db.get('cryptpasswd/crypt-password-again')

    def password_error(self, msg):
        """The selected password was bad: report it and re-prompt."""
        print('\nBad password: %s' % msg, file=self.console)
        import getpass
        self.password = getpass.getpass('Password: ')
        self.verifiedpassword = getpass.getpass('Password again: ')

    def clear_errors(self):
        """Nothing to clear on the console."""
        pass


class Page(plugin.Plugin):
    """Debconf filter for the crypt passphrase page."""

    def prepare(self, unfiltered=False):
        """Return the helper command under ubiquity, else load templates.

        Under the debconf_ui frontend the helper script is returned as the
        command to run. Otherwise only its debconf templates are loaded,
        the UI is validated once, and None is returned.
        """
        if FrontendIsUbiquity():
            return ['/usr/share/nvidia/nv-crypt-passwd.sh']

        # Load debconf templates. A failure used to be discarded silently;
        # report it so later preseed errors can be traced back to it.
        (status, output) = subprocess.getstatusoutput(
            '/usr/share/nvidia/nv-crypt-passwd.sh load_only')
        if status != 0:
            sys.stderr.write("%s: loading debconf templates failed (%d): %s\n"
                             % (NAME, status, output))

        # We need to call info_loop as we switch to the page so the next button
        # gets disabled.
        self.ui.info_loop(None)

        # End here, don't return a command to fall through to
        # Page[Gtk|Kde] cases
        return

    def ok_handler(self):
        """Preseed the device and passphrase answers, then chain up."""
        self.ui.clear_errors()

        device = self.ui.get_device()
        password = self.ui.get_password()
        password_confirm = self.ui.get_verified_password()

        self.preseed('cryptpasswd/device', device)
        self.preseed('cryptpasswd/crypt-password', password)
        self.preseed('cryptpasswd/crypt-password-again', password_confirm)
        # NOTE(review): mirrors ubiquity's user-setup handling of
        # passwd/user-uid; confirm this page should preseed it as well.
        if self.ui.controller.oem_config:
            self.preseed('passwd/user-uid', '29999')
        else:
            self.preseed('passwd/user-uid', '')

        plugin.Plugin.ok_handler(self)

    def error(self, priority, question):
        """Route cryptpasswd/crypt-password-* errors to the password field."""
        if question.startswith('cryptpasswd/crypt-password-'):
            self.ui.password_error(self.extended_description(question))
        else:
            # NOTE(review): error_dialog is not defined by any UI class in
            # this file; presumably inherited from ubiquity's PluginUI -- verify.
            self.ui.error_dialog(
                self.description(question),
                self.extended_description(question))
        return plugin.Plugin.error(self, priority, question)

class Install(plugin.InstallPlugin):
    def install(self, target, progress, *args, **kwargs):
        """Apply the chosen crypt passphrase, then run the parent install.

        Reads cryptpasswd/device and cryptpasswd/crypt-password from debconf
        and, when both are non-empty, calls setCryptPassword(). Every outcome
        is logged to syslog. Failures never abort the install: the parent
        InstallPlugin.install() always runs.
        """
        import syslog

        # By the time we get to the install phase, console output in
        # GTK mode no longer gets saved to oem-config.log.  So instead
        # we'll use the following log function
        def log_to_syslog(msg):
            syslog.syslog("oem-config: %s: %s" % (NAME, msg))

        try:
            # Retrieve user specified password
            cryptdev = self.db.get('cryptpasswd/device')
            cryptpass = self.db.get('cryptpasswd/crypt-password')
        except Exception as e:
            log_to_syslog("Failed to retrieve crypt information: %s" % e)
        else:
            log_to_syslog("Setting password on %s" % (cryptdev))
            if not (cryptdev and cryptpass):
                log_to_syslog("Skipped setCryptPassword")
            else:
                # setCryptPassword reports details on stderr, which (per the
                # note above) is not preserved here, so log the outcome too.
                try:
                    ok = setCryptPassword(cryptdev, cryptpass)
                except Exception as e:
                    log_to_syslog("setCryptPassword raised: %s" % e)
                    ok = False
                if ok:
                    log_to_syslog("Crypt password updated on %s" % cryptdev)
                else:
                    log_to_syslog(
                        "Failed to set crypt password on %s" % cryptdev)

        return plugin.InstallPlugin.install(
            self, target, progress, *args, **kwargs)
