#!/usr/bin/env python3

# Copyright 2021 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.


import argparse
import collections
import logging
import os
import re
import shlex
import subprocess
import sys

# Module-level logger; its format and level are set by configure_logger()
LOG = logging.getLogger(__name__)

# Number of bytes in each absolute size unit accepted in <amount><unit>
UNIT_BYTES = {
    'MB': 1000000,
    'MiB': 1048576,
    'GB': 1000000000,
    'GiB': 1073741824
}
# Every accepted unit: '%' (percentage) followed by the absolute units above
UNITS = ['%']
UNITS.extend(UNIT_BYTES.keys())
# Matches a complete <amount><unit> string; group 1 is the integer amount
# and group 2 is the unit
AMOUNT_UNIT_RE = re.compile('^([0-9]+)(%s)$' % '|'.join(UNITS))

# Only create growth partition if there is at least 1GiB available
MIN_DISK_SPACE_BYTES = UNIT_BYTES['GiB']

# Default LVM physical extent size is 4MiB
PHYSICAL_EXTENT_BYTES = 4 * UNIT_BYTES['MiB']


class Command(object):
    """A command to run, optionally paired with an explanatory comment.

    ``cmd`` is the argument list of the command and ``comment`` is a short
    description of what the command does. The representation of the object
    is the shell-quoted command, preceded by a ``# comment`` line when a
    comment is set.
    """

    # Class-level defaults; __init__ replaces both on every instance
    cmd = None
    comment = None

    def __init__(self, cmd, comment=None):
        self.comment = comment
        self.cmd = cmd

    def __repr__(self):
        rendered = printable_cmd(self.cmd)
        if not self.comment:
            return rendered
        return "\n# %s\n%s" % (self.comment, rendered)


def parse_opts(argv):
    """Parse the command line into an options namespace.

    Defaults for the positional arguments, ``--group`` and ``--device`` are
    read from the environment variables GROWVOLS_ARGS, GROWVOLS_GROUP and
    GROWVOLS_DEVICE when this function is called.

    :param argv: full argument vector; the first element (the program
                 name) is skipped
    :return: the argparse namespace with attributes grow_vols, group,
             device, exit_on_no_grow, debug, verbose, noop and yes
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description='''
Grow the logical volumes in an LVM group to take available device space.

The positional arguments specify how available space is allocated. They
have the format <volume>=<amount><unit> where:

<volume> is the label or the mountpoint of the logical volume
<amount> is an integer growth amount in the specified unit
<unit> is one of the supported units

Supported units are:

% percentage of available device space before any changes are made
MiB mebibyte (1048576 bytes)
GiB gibibyte (1073741824 bytes)
MB megabyte (1000000 bytes)
GB gigabyte (1000000000 bytes)

Each argument is processed in order and the requested amount is allocated
to each volume until the disk is full. This means that if space is
overallocated, the last volumes may only grow by the remaining space, or
not grow at all, and a warning will be printed. When space is underallocated
the remaining space will be given to the root volume (mounted at /).

The currently supported partition layout is:
- Exactly one of the partitions containing an LVM group
- The disk having unpartitioned space to grow with
- The LVM logical volumes being formatted with XFS filesystems

Example usage:

growvols /var=80% /home=20GB

growvols --device sda --group vg img-rootfs=20% fs_home=20% fs_var=60%

''')
    # argparse applies %-formatting to help strings (but not to the
    # description above), so a literal percent sign must be written as
    # '%%' or rendering the help fails.
    parser.add_argument('grow_vols',
                        nargs="*",
                        metavar="<volume>=<amount><unit>",
                        default=os.environ.get('GROWVOLS_ARGS', '').split(),
                        help='A label or mountpoint, and the proportion to '
                             'grow it by. Defaults to $GROWVOLS_ARGS. '
                             'For example: /home=80%% img-rootfs=20%%')
    parser.add_argument('--group', metavar='GROUP',
                        default=os.environ.get('GROWVOLS_GROUP'),
                        help='The name of the LVM group to extend. Defaults '
                             'to $GROWVOLS_GROUP or the discovered group '
                             'if only one exists.')
    parser.add_argument('--device', metavar='DEVICE',
                        default=os.environ.get('GROWVOLS_DEVICE'),
                        help='The name of the disk block device to grow the '
                             'volumes in (such as "sda"). Defaults to the '
                             'disk containing the root mount.')

    parser.add_argument(
        '--exit-on-no-grow',
        action='store_true',
        help="Exit with error code 2 if no volume could be grown. ",
        default=False)

    parser.add_argument(
        '-d', '--debug',
        dest="debug",
        action='store_true',
        help="Print debugging output.",
        required=False)
    parser.add_argument(
        '-v', '--verbose',
        dest="verbose",
        action='store_true',
        help="Print verbose output.",
        required=False)
    parser.add_argument(
        '--noop',
        dest="noop",
        action='store_true',
        help="Return the configuration commands, without applying them.",
        required=False)
    parser.add_argument(
        '-y', '--yes',
        help='Skip yes/no prompt (assume yes).',
        default=False,
        action="store_true")

    return parser.parse_args(argv[1:])


def configure_logger(verbose=False, debug=False):
    """Set up basic logging with a level chosen from the two flags.

    :param verbose: when true (and debug is not), log at INFO level
    :param debug: when true, log at DEBUG level; takes priority over verbose
    """
    if debug:
        level = logging.DEBUG
    elif verbose:
        level = logging.INFO
    else:
        level = logging.WARN

    logging.basicConfig(level=level, format='[%(levelname)s] %(message)s')


def printable_cmd(cmd):
    """Render a command list as one shell-quoted, space-separated string"""
    return ' '.join(map(shlex.quote, cmd))


def convert_bytes(num):
    """Format a bytes amount with units MB, GB etc

    The amount is divided by 1000 until it is below 1000 and is then
    printed as a whole number (fraction dropped) followed by the unit.
    Amounts of 1000TB or more are still reported in TB rather than
    returning nothing.

    :param num: amount in bytes
    :return: string such as "20GB"
    """
    step_unit = 1000.0
    units = ('B', 'KB', 'MB', 'GB', 'TB')

    for unit in units[:-1]:
        if num < step_unit:
            return "%d%s" % (num, unit)
        num /= step_unit

    # TB is the largest known unit, so use it for everything that remains
    return "%d%s" % (num, units[-1])


def execute(cmd):
    """Run a command and return what it wrote to stdout.

    :param cmd: the command as a list of arguments
    :return: the captured stdout text
    :raises Exception: when the command exits with a non-zero status; the
                       message includes the command, stdout and stderr
    """
    LOG.info('Running: %s', printable_cmd(cmd))

    proc = subprocess.Popen(
        cmd,
        universal_newlines=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    stdout_text, stderr_text = proc.communicate()

    if proc.returncode == 0:
        LOG.debug('Result: %s', stdout_text)
        return stdout_text

    template = ('Running command failed: cmd "{}", stdout "{}",'
                ' stderr "{}"')
    error_msg = template.format(' '.join(cmd), stdout_text, stderr_text)
    LOG.error(error_msg)
    raise Exception(error_msg)


def parse_shell_vars(output):
    """Parse lines of shell-style variable assignments.

    Each non-blank line has the form KEY="VALUE" KEY2="VALUE2" and is split
    with shell quoting rules, so quoted values may contain spaces.

    :param output: the text to parse, one record per line
    :return: a generator yielding one dict of key/value pairs per line
    """
    for raw_line in output.strip().split('\n'):
        if not raw_line.strip():
            continue
        record = {}
        for token in shlex.split(raw_line):
            name, value = token.split('=', 1)
            record[name] = value
        yield record


def find_device(devices, keys, value):
    """Return the first device where any of the given keys equals value.

    :param devices: list of device dicts
    :param keys: a single key name, or a list of key names to try
    :param value: the value to look for
    :return: the first matching device dict, or None when nothing matches
    """
    wanted_keys = [keys] if isinstance(keys, str) else keys
    matches = (
        candidate for candidate in devices
        if any(candidate.get(k) == value for k in wanted_keys)
    )
    return next(matches, None)


def find_disk(opts, devices):
    """Locate the disk device to grow volumes in.

    When ``opts.device`` is set the device with that KNAME or NAME is used.
    Otherwise the device mounted at / is looked up and its PKNAME parent
    chain is followed upwards until a device of TYPE "disk" is reached.

    :param opts: parsed options; only ``device`` is read
    :param devices: list of device dicts
    :return: the device dict of the disk
    :raises Exception: when no device can be found or the result is not
                       of TYPE "disk"
    """
    if opts.device:
        disk = find_device(devices, ['KNAME', 'NAME'], opts.device)
        if not disk:
            raise Exception('Could not find specified --device: %s'
                            % opts.device)
    else:
        disk = find_device(devices, 'MOUNTPOINT', '/')
        if not disk:
            raise Exception('Could not find mountpoint for /')

        # climb from the root mount towards its parent devices, stopping at
        # a disk, at a device without a parent, or when the chain is broken
        while True:
            disk = find_device(devices, 'KNAME', disk['PKNAME'])
            if not disk or disk['TYPE'] == 'disk' or not disk['PKNAME']:
                break
        if not disk:
            raise Exception('Could not detect disk device')

    found_type = disk['TYPE']
    if found_type != 'disk':
        raise Exception('Expected a device with TYPE="disk", got: %s'
                        % found_type)
    return disk


def find_space(disk_name):
    """Ask sgdisk for the largest block of free space on a disk.

    :param disk_name: name of the disk device under /dev
    :return: tuple of (start sector, end sector, end minus start) taken
             from the first two lines of the sgdisk output
    """
    LOG.info('Finding spare space to grow into')
    sgdisk_output = execute([
        'sgdisk',
        '--first-aligned-in-largest',
        '--end-of-largest',
        '/dev/%s' % disk_name])
    lines = sgdisk_output.strip().split('\n')
    sector_start, sector_end = int(lines[0]), int(lines[1])
    return sector_start, sector_end, sector_end - sector_start


def find_devices():
    """List all block devices reported by lsblk.

    :return: list of dicts, one per lsblk output line, keyed by the
             column names printed by lsblk
    """
    LOG.info('Finding all block devices')
    columns = 'kname,pkname,name,label,type,fstype,mountpoint'
    lsblk_output = execute(['lsblk', '-Po', columns])
    return list(parse_shell_vars(lsblk_output))


def find_group(opts):
    """Determine the name of the LVM volume group to extend.

    The group given in ``opts.group`` is used when it exists. Without that
    option there has to be exactly one volume group, which is returned.

    :param opts: parsed options; only ``group`` is read
    :return: the volume group name
    :raises Exception: when the requested group does not exist, when no
                       group exists, or when the choice is ambiguous
    """
    LOG.info('Finding LVM volume group')
    vgs_output = execute(['vgs', '--noheadings', '--options', 'vg_name'])
    vg_names = {
        line.strip() for line in vgs_output.split('\n') if line.strip()
    }

    requested = opts.group
    if requested:
        if requested in vg_names:
            return requested
        raise Exception('Could not find specified --group: %s'
                        % requested)

    if not vg_names:
        raise Exception('No volume groups found')
    if len(vg_names) > 1:
        raise Exception('More than one volume group, specify one to '
                        'use with --group: %s' % ', '.join(sorted(vg_names)))
    return vg_names.pop()


def find_next_partnum(devices, disk_name):
    """Return one more than the number of devices whose PKNAME is the disk"""
    children = sum(1 for dev in devices if dev['PKNAME'] == disk_name)
    return children + 1


def find_next_device_name(devices, disk_name, partnum):
    """Work out the device name the new partition will get.

    The naming scheme is detected by looking for the partition numbered
    one below ``partnum`` under each known scheme.

    :param devices: list of device dicts
    :param disk_name: name of the disk the partition is created on
    :param partnum: number of the partition about to be created
    :return: device name of the new partition
    :raises Exception: when the previous partition is not found under any
                       known naming scheme
    """
    previous_partnum = partnum - 1

    # try partition scheme for SATA etc, then NVMe
    for scheme in ('%s%d', '%sp%d'):
        candidate = scheme % (disk_name, previous_partnum)
        LOG.debug('Looking for device %s', candidate)
        if not find_device(devices, 'KNAME', candidate):
            continue
        return scheme % (disk_name, partnum)

    raise Exception('Could not find partition naming scheme for %s'
                    % disk_name)


def prompt_user_for_confirmation(message):
    """Ask the user a yes/no question on the terminal.

    The message is written as-is, so it should already contain a hint such
    as [y/N]. One line is read from stdin and any answer starting with "y"
    (in either case) counts as confirmation. Everything else, including an
    empty read, is treated as a refusal. When stdin is not a terminal no
    prompt is shown and the answer is no.

    :param message: Confirmation string prompt
    :return: boolean true for valid confirmation, false for all others
    """
    try:
        if sys.stdin.isatty():
            sys.stdout.write(message)
            sys.stdout.flush()
            answer = sys.stdin.readline().lower()
            if answer.startswith('y'):
                LOG.info('User confirmed action.')
                return True
            print('Taking no action.')
            return False

        LOG.error('User interaction required, cannot confirm.')
        return False
    except KeyboardInterrupt:  # ctrl-c
        print('User did not confirm action (ctrl-c) so taking no action.')
    except EOFError:  # ctrl-d
        print('User did not confirm action (ctrl-d) so taking no action.')
    return False


def amount_unit_to_extent(amount_unit, total_size_bytes, remaining_bytes):
    """Convert an <amount><unit> request into an extent-aligned byte count.

    Absolute units are converted with UNIT_BYTES, and a percentage is taken
    of ``total_size_bytes``. The result is capped at ``remaining_bytes``
    and then rounded down to a multiple of PHYSICAL_EXTENT_BYTES.

    :param amount_unit: string such as "20GB" or "80%"
    :param total_size_bytes: size that percentages are calculated from
    :param remaining_bytes: bytes that have not been allocated yet
    :return: tuple of (bytes to allocate, bytes remaining afterwards)
    :raises Exception: when amount_unit is not a valid <amount><unit>
    """
    m = AMOUNT_UNIT_RE.match(amount_unit)
    if not m:
        raise Exception('Value for <amount><unit> not valid: %s'
                        % amount_unit)

    amount = int(m.group(1))
    unit = m.group(2)
    if unit in UNIT_BYTES:
        size_bytes = amount * UNIT_BYTES[unit]
        LOG.debug('%s is %s', amount_unit, convert_bytes(size_bytes))
    else:
        # Integer arithmetic on purpose: a float product such as
        # 0.29 * total can land just below an exact extent boundary and,
        # once truncated and aligned, lose a whole physical extent.
        size_bytes = int(amount * total_size_bytes // 100)
        LOG.debug('%s of %s is %s',
                  amount_unit,
                  convert_bytes(total_size_bytes),
                  convert_bytes(size_bytes))

    if remaining_bytes <= 0:
        LOG.debug('No remaining space available to allocate %s', amount_unit)
        size_bytes = 0
    elif size_bytes > remaining_bytes:
        LOG.debug('Cannot allocate all of %s, only %s remaining space',
                  convert_bytes(size_bytes), convert_bytes(remaining_bytes))
        size_bytes = remaining_bytes

    # reduce to align on physical extent size
    size_bytes -= size_bytes % PHYSICAL_EXTENT_BYTES

    return size_bytes, remaining_bytes - size_bytes


def find_grow_vols(opts, devices, group, total_size_bytes):
    """Work out how many bytes each logical volume should grow by.

    The ``opts.grow_vols`` arguments are processed in order, followed by an
    implicit "/=100%" which hands any unallocated space to the root volume.
    A volume that is named more than once (for example an explicit "/=50%"
    plus the implicit root entry) receives the sum of its amounts.

    :param opts: parsed options; only ``grow_vols`` is read
    :param devices: list of device dicts
    :param group: name of the LVM volume group
    :param total_size_bytes: size of the space available to grow into
    :return: ordered dict mapping /dev/<group>/<volume> paths to the number
             of bytes to grow by; volumes which get no space are left out
    :raises Exception: when an argument is malformed, or refers to a device
                       that is missing, is not of type lvm or is not xfs
    """
    grow_vols = collections.OrderedDict()
    remaining_bytes = total_size_bytes

    arg_error = ('<grow_vols> must be of the format '
                 '<volume>=<amount><unit>, '
                 'for example: /home=100% or fs_home=100%')
    grow_vols_opt = [gv for gv in opts.grow_vols if gv]
    # append an implicit /=100% to grow root by any underallocation
    grow_vols_opt.append('/=100%')

    for gv in grow_vols_opt:
        if '=' not in gv:
            raise Exception(arg_error)
        devname, amount_unit = gv.split('=', 1)
        device = find_device(devices, ['LABEL', 'MOUNTPOINT'], devname)
        if not device:
            raise Exception('Could not find device %s for argument: %s'
                            % (devname, gv))

        if device['TYPE'] != 'lvm':
            raise Exception('Device %(NAME)s is not of type lvm: %(TYPE)s'
                            % device)

        if device['FSTYPE'] != 'xfs':
            raise Exception('Device %(NAME)s is not of fstype xfs: %(FSTYPE)s'
                            % device)

        size_bytes, remaining_bytes = amount_unit_to_extent(
            amount_unit, total_size_bytes, remaining_bytes)

        if size_bytes == 0:
            continue

        # remove the group- prefix from the name
        # NOTE(review): assumes NAME is exactly '<group>-<volume>'; group or
        # volume names containing '-' may be reported differently - confirm
        name = device['NAME'][len(group) + 1:]
        volume_path = '/dev/%s/%s' % (group, name)

        # accumulate rather than overwrite, so an earlier allocation to the
        # same volume is not lost when the volume is named again
        grow_vols[volume_path] = grow_vols.get(volume_path, 0) + size_bytes

    return grow_vols


def find_sector_size(disk_name):
    """Read the logical block size of a disk from sysfs.

    :param disk_name: name of the disk under /sys/block
    :return: the sector size in bytes as an integer
    """
    sector_path = '/sys/block/%s/queue/logical_block_size' % disk_name

    LOG.info('Reading sector size from %s' % sector_path)
    with open(sector_path) as sector_file:
        raw_value = sector_file.read()

    size = int(raw_value.strip())
    LOG.info('Sector size of %s is %s'
             % (disk_name, convert_bytes(size)))
    return size


def main(argv):
    """Plan the growth of the logical volumes and optionally carry it out.

    The free space on the disk is measured first. When it is smaller than
    MIN_DISK_SPACE_BYTES nothing is done. Otherwise the commands to create
    a partition, add it to the volume group, extend the volumes and grow
    their filesystems are built and printed. With --noop they are never
    run; otherwise they run after confirmation (or straight away with
    --yes).

    :param argv: full argument vector including the program name
    :return: exit code; 2 when there is too little space and
             --exit-on-no-grow was given, otherwise 0
    """
    opts = parse_opts(argv)
    configure_logger(opts.verbose, opts.debug)

    devices = find_devices()
    disk_name = find_disk(opts, devices)['KNAME']

    sector_bytes = find_sector_size(disk_name)
    sector_start, sector_end, size_sectors = find_space(disk_name)

    min_required_sectors = MIN_DISK_SPACE_BYTES // sector_bytes
    if size_sectors < min_required_sectors:
        msg = ('Need at least %s sectors to expand into, there '
               'is only: %s' % (min_required_sectors, size_sectors))
        if opts.exit_on_no_grow:
            LOG.error(msg)
            return 2

        LOG.info(msg)
        return 0

    available_bytes = size_sectors * sector_bytes

    group = find_group(opts)
    partnum = find_next_partnum(devices, disk_name)
    devname = find_next_device_name(devices, disk_name, partnum)
    dev_path = '/dev/%s' % devname
    grow_vols = find_grow_vols(opts, devices, group, available_bytes)

    # fixed preparation steps: new partition, then hand it to LVM
    commands = [
        Command([
            'sgdisk',
            '--new=%d:%d:%d' % (partnum, sector_start, sector_end),
            '--change-name=%d:growvols' % partnum,
            '/dev/%s' % disk_name
        ], 'Create new partition %s' % devname),
        Command(
            ['partprobe'],
            'Inform the OS of partition table changes'),
        Command([
            'pvcreate',
            dev_path
        ], 'Initialize %s for use by LVM' % devname),
        Command([
            'vgextend',
            group,
            dev_path
        ], 'Add physical volume %s to group %s' % (devname, group)),
    ]

    # only volumes which actually receive space get any commands
    growing = [(path, amount) for path, amount in grow_vols.items()
               if amount > 0]

    # all volumes are extended first, then all filesystems are grown
    commands.extend(
        Command([
            'lvextend',
            '--size',
            '+%sB' % amount,
            path,
            dev_path
        ], 'Add %s to logical volume %s' % (convert_bytes(amount), path))
        for path, amount in growing)
    commands.extend(
        Command([
            'xfs_growfs',
            path
        ], 'Grow XFS filesystem for %s' % path)
        for path, _ in growing)

    print('The following commands would have run:' if opts.noop
          else 'The following commands will be run:')
    for command in commands:
        print(command)
    sys.stdout.flush()

    if opts.noop:
        return 0

    confirmed = opts.yes or prompt_user_for_confirmation(
        '\nAre you sure you want to run these commands? [y/N] ')
    if confirmed:
        for command in commands:
            execute(command.cmd)

    return 0


# When run as a script, pass the full argument vector to main() and use
# its return value as the process exit status
if __name__ == '__main__':
    sys.exit(main(sys.argv))
