#!/usr/libexec/platform-python

# Copyright (C) 2011 Oracle. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, version 2.  This program is distributed in the hope that it will
# be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.  You should have received a copy of the GNU
# General Public License along with this program; if not, write to the Free
# Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
# 02111-1307, USA.

import glob
import json
import logging
import logging.handlers
import optparse
import os
import pprint
import select
import sys
import tempfile
import termios
import threading
import time
import xml.dom.minidom

from templateconfig.common import (run_cmd, error, msg, check_password,
         get_target_path, get_priority, check_priority, get_name, check_name)


# Parameters collected so far (key -> value). Updated by check_fd, check_iso,
# check_console and run_scripts.
PARAM = {}
# When True, the check_fd / check_iso / check_console loops stop polling.
PARAM_CHECK_STOP = False
# Set to True when a KeyboardInterrupt is caught while collecting parameters;
# main() reports an error if it is set.
USER_INTERRUPT = False
# Seconds to wait for input in read_input() and between two ISO scans.
TIMEOUT = 10.0


def init_log(logfile, loglevel):
    """Attach a size-rotated file handler to the root logger.

    logfile  -- path of the log file to write to.
    loglevel -- level applied to the root logger.

    The file is rotated at 5 MiB and five old files are kept.
    """
    root = logging.getLogger()
    root.propagate = 0
    rotating = logging.handlers.RotatingFileHandler(
        filename=logfile,
        maxBytes=5 * 1024 * 1024,
        backupCount=5)
    rotating.setFormatter(logging.Formatter(
        '[%(asctime)s %(process)d] %(levelname)s '
        '(%(module)s:%(lineno)d) %(message)s',
        '%Y-%m-%d %H:%M:%S'))
    root.addHandler(rotating)
    root.setLevel(loglevel)


def _write_param(output, key, value):
    """Write one ``{key: value}`` JSON line to *output*.

    Binary streams receive the UTF-8 encoded line. Text streams such as
    sys.stdout reject bytes with TypeError; they receive the same line as str.
    """
    line = json.dumps({key: value}) + '\n'
    try:
        output.write(line.encode('utf-8'))
    except TypeError:
        output.write(line)


def run_scripts(target, output, script=None):
    """Run the scripts of *target* and report every parameter they change.

    target -- target name; it selects the script directory through
              get_target_path() and is passed to each script as argument.
    output -- writable stream (binary or text) that receives one JSON object
              per line for each key whose value is new or changed.
    script -- optional script name; only files matching '??<script>' in the
              target directory are run. Without it all files are run.

    Scripts are run in sorted path order. Each one is given the JSON-encoded
    PARAM as second argument of run_cmd() (presumably its standard input --
    confirm against templateconfig.common) and its output is parsed as a JSON
    object that is merged back into PARAM.

    Raises Exception when *script* is given but no matching file exists;
    errors from run_cmd() or from JSON decoding propagate to the caller.
    """
    if script:
        scripts = glob.glob('%s/??%s' % (get_target_path(target), script))
        if not scripts:
            raise Exception('No such script found: %s' % script)
    else:
        scripts = glob.glob('%s/*' % get_target_path(target))
    for path in sorted(scripts):
        logging.info('run script: %s, target: %s', path, target)
        result = json.loads(run_cmd([path, target], json.dumps(PARAM)))
        for key, value in result.items():
            # Only report keys that are new or whose value differs.
            if key not in PARAM or PARAM[key] != value:
                _write_param(output, key, value)
        PARAM.update(result)


def do_enumerate(target, script=None):
    """Ask every script of *target* which parameters it expects.

    Each script is run with '--enumerate <target>' and its output is parsed
    as JSON. Scripts that fail, print nothing, print invalid JSON or print an
    empty value are skipped (failures are logged at debug level).

    Returns a list of (priority, name, params) tuples in sorted path order.
    Raises Exception when *script* is given but no matching file exists.
    """
    base = get_target_path(target)
    if script:
        pattern = '%s/??%s' % (base, script)
    else:
        pattern = '%s/*' % base
    found = glob.glob(pattern)
    if script and not found:
        raise Exception('No such script found: %s' % script)

    collected = []
    for path in sorted(found):
        priority = get_priority(path)
        check_priority(priority)
        name = get_name(path)
        check_name(name)
        try:
            raw = run_cmd([path, '--enumerate', target]).strip()
        except Exception as e:
            logging.debug('Error enumerating parameters: %s', e)
            continue
        if not raw:
            continue
        try:
            parsed = json.loads(raw)
        except Exception as e:
            logging.debug('Error loading parameters: %s %s', raw, e)
            continue
        if parsed:
            collected.append((priority, name, parsed))
    return collected


def check_param(param, expected):
    """Verify that *param* holds a value for every required key.

    param    -- mapping of collected parameters.
    expected -- iterable of (priority, script name, parameter descriptions);
                a description is a dict with a 'key' entry and an optional
                truthy 'required' entry.

    Raises Exception naming the first required key that is missing.
    """
    for _priority, script_name, descriptions in expected:
        for description in descriptions:
            if not description.get('required'):
                continue
            key = description['key']
            if key in param:
                continue
            raise Exception('missing value for key "%s" of script "%s"' % (key, script_name))


def required_param_complete(expected):
    """Return True when PARAM has a value for every required key.

    Any exception raised by check_param() counts as "not complete".
    """
    try:
        check_param(PARAM, expected)
    except Exception:
        complete = False
    else:
        complete = True
    return complete


def read_input(fd, timeout):
    r"""Read one line from *fd*, waiting at most *timeout* seconds.

    The return value is one of:

    - None: nothing became readable before the timeout
    - '\n': 'Enter' pressed
    - '': 'Ctrl-D' pressed (end of file)
    - '<str>': some characters followed by 'Ctrl-D'
    - '<str>\n': some characters followed by 'Enter'
    """
    readable = select.select([fd], [], [], timeout)[0]
    if not readable:
        return None
    return fd.readline()


def check_fd(fd, expected):
    """Collect parameters from JSON lines read on *fd*.

    Polls *fd* until every required parameter is present in PARAM, the
    descriptor reaches end of file, PARAM_CHECK_STOP is set, or an error
    occurs. Each non-empty line is decoded as JSON; dict values are merged
    into PARAM, anything else is ignored. A KeyboardInterrupt sets
    USER_INTERRUPT before the loop ends.
    """
    global USER_INTERRUPT
    fileno = fd.fileno()
    logging.info('check fd (%s) started', fileno)
    while not PARAM_CHECK_STOP:
        try:
            if required_param_complete(expected):
                break
            raw = read_input(fd, TIMEOUT)
            if isinstance(raw, (bytes, bytearray)):
                raw = raw.decode('utf-8')
            logging.debug('check fd (%s) get param: %s', fileno, repr(raw))
            if raw is None:
                # Timed out: poll again.
                continue
            if raw == '':
                logging.debug('check fd (%s) get EOF', fileno)
                break
            line = raw[:-1] if raw.endswith('\n') else raw
            if not line:
                continue
            try:
                decoded = json.loads(line)
            except Exception as e:
                logging.debug('check fd (%s) load param error: %s', fileno, e)
                continue
            if isinstance(decoded, dict):
                logging.debug('check fd (%s) update param: %s', fileno, decoded)
                PARAM.update(decoded)
        except KeyboardInterrupt:
            logging.debug('check fd (%s) user interrupt received', fileno)
            USER_INTERRUPT = True
            break
        except Exception as e:
            logging.debug('check fd (%s) error: %s', fileno, e)
            break
    logging.info('check fd (%s) finished', fileno)


def get_disks(partitions_file='/proc/partitions'):
    """Return the names of whole-disk block devices.

    partitions_file -- file in /proc/partitions format to read; each device
                       line has four whitespace-separated fields, the first
                       being the major number and the last the device name.

    Lines that do not have four fields or whose first field is not a number
    (header, blank line) are ignored. Devices whose name ends with a digit
    (treated as partitions) or starts with 'loop' or 'dm-' are left out.

    Returns the remaining names in file order.
    """
    disks = []
    # The context manager closes the file even if reading fails part-way.
    with open(partitions_file) as partitions:
        for entry in partitions:
            fields = entry.split()
            if len(fields) != 4:
                continue
            major, name = fields[0], fields[3]
            if not major.isdigit():
                continue
            if name[-1] in '0123456789':
                continue
            if name.startswith(('loop', 'dm-')):
                continue
            disks.append(name)
    return disks


def get_mount_point(dev):
    """Return where *dev* is mounted according to /proc/mounts.

    The first entry whose device field equals *dev* wins. Returns None when
    no entry matches.
    """
    with open('/proc/mounts') as table:
        for record in table:
            fields = record.split()
            device, mount_point = fields[0], fields[1]
            if device == dev:
                return mount_point
    return None


def mount(dev, path, option=''):
    """Mount *dev* on *path* by running the 'mount' command.

    A non-empty *option* is passed as a single argument before the device.
    """
    command = ['mount']
    if option:
        command.append(option)
    command.extend([dev, path])
    run_cmd(command)


def umount(path):
    """Unmount *path* by running the 'umount' command."""
    command = ['umount', path]
    run_cmd(command)


def parse_ovf_env_file(filename):
    """Extract the properties of an OVF environment XML file.

    Every 'Property' element found below a 'PropertySection' contributes its
    'ovfenv:key' / 'ovfenv:value' attributes. Sections are visited in
    document order first, then the sections below each 'Entity' element are
    visited again, so for a duplicated key the value from an 'Entity' wins.

    Returns a dict mapping keys to values.
    """
    dom = xml.dom.minidom.parse(filename)
    dom.normalize()
    root = dom.documentElement
    sections = list(root.getElementsByTagNameNS('*', 'PropertySection'))
    for entity in root.getElementsByTagNameNS('*', 'Entity'):
        sections.extend(entity.getElementsByTagNameNS('*', 'PropertySection'))
    return {
        prop.getAttribute('ovfenv:key'): prop.getAttribute('ovfenv:value')
        for section in sections
        for prop in section.getElementsByTagNameNS('*', 'Property')
    }


def parse_iso():
    """Look for 'ovf-env.xml' on CD-ROM and disk devices and parse it.

    Candidate devices are /dev/cdrom, /dev/cdrom0 .. /dev/cdrom3 and every
    disk reported by get_disks(), in that order. A device that is already
    mounted is inspected at its mount point; otherwise it is mounted on a
    temporary directory for the duration of the check. The first
    'ovf-env.xml' found ends the search.

    Returns the dict produced by parse_ovf_env_file(), or an empty dict when
    no device carries the file. Errors from parse_ovf_env_file() and from
    removing the temporary directory propagate to the caller.
    """
    # Enumerate devices before creating the temporary directory so that a
    # failure here cannot leave the directory behind.
    devices = ['/dev/cdrom%s' % suffix for suffix in ['', 0, 1, 2, 3]]
    devices += ['/dev/%s' % disk for disk in get_disks()]
    param = {}
    tmpdir = tempfile.mkdtemp(prefix='cdrom.')
    try:
        for device in devices:
            if not os.path.exists(device):
                continue
            mount_point = get_mount_point(device)
            if mount_point:
                # Already mounted: look at it where it is, never remount.
                env_file = os.path.join(mount_point, 'ovf-env.xml')
                if os.path.exists(env_file):
                    param = parse_ovf_env_file(env_file)
                    break
                continue
            try:
                mount(device, tmpdir)
            except Exception:
                # Best effort: a device that cannot be mounted is skipped.
                continue
            try:
                env_file = os.path.join(tmpdir, 'ovf-env.xml')
                if os.path.exists(env_file):
                    param = parse_ovf_env_file(env_file)
                    break
            finally:
                try:
                    umount(tmpdir)
                except Exception as e:
                    # Still best effort, but leave a trace: a directory that
                    # stays mounted makes the rmdir below fail.
                    logging.debug('umount %s failed: %s', tmpdir, e)
    finally:
        os.rmdir(tmpdir)
    return param


def check_iso(expected):
    """Collect parameters from an OVF transport ISO.

    Calls parse_iso() every TIMEOUT seconds and merges what it finds into
    PARAM, until every required parameter is present, PARAM_CHECK_STOP is
    set, or an error occurs. A KeyboardInterrupt sets USER_INTERRUPT.
    """
    global USER_INTERRUPT
    logging.info('check ISO started')
    # Both handlers end the polling, so one try block wraps the whole loop.
    try:
        while not PARAM_CHECK_STOP:
            if required_param_complete(expected):
                break
            found = parse_iso()
            logging.debug('check ISO get param: %s', found)
            if found:
                PARAM.update(found)
            time.sleep(TIMEOUT)
    except KeyboardInterrupt:
        logging.debug('check ISO user interrupt received')
        USER_INTERRUPT = True
    except Exception as e:
        logging.debug('check ISO error: %s', e)
    logging.info('check ISO finished')


def console_echo(on):
    """Turn terminal echo on standard input on or off.

    on -- truthy to enable echo, falsy to disable it.

    Does nothing when standard input is not a terminal.
    """
    fd = sys.stdin.fileno()
    if not os.isatty(fd):
        return
    attr = termios.tcgetattr(fd)
    # Index 3 of the attribute list holds the local-mode flags.
    local_flags = attr[3]
    attr[3] = (local_flags | termios.ECHO) if on else (local_flags & ~termios.ECHO)
    termios.tcsetattr(fd, termios.TCSADRAIN, attr)


def check_console(expected):
    """Prompt on the console for parameters until the required ones are set.

    expected -- list of (priority, script name, parameter descriptions) as
                returned by do_enumerate(). The description keys read here
                are 'key', 'description', 'required', 'hidden', 'password',
                'choices', 'depends' and 'requires'.

    The parameters are visited round-robin. Accepted values are stored in
    PARAM under the parameter's 'key'. The loop ends when every required
    parameter is present in PARAM or when PARAM_CHECK_STOP becomes true.
    A KeyboardInterrupt sets USER_INTERRUPT and PARAM_CHECK_STOP; any other
    exception sets PARAM_CHECK_STOP. In both cases echo is switched back on.
    """
    logging.info('check console started')
    if not expected:
        logging.info('check console finished: no parameter is needed')
        return
    global PARAM_CHECK_STOP
    global USER_INTERRUPT
    # Flatten into (script name, parameter description) pairs.
    param_list = [(script, param)
                  for (_, script, params) in expected
                  for param in params]
    index = 0
    param_length = len(param_list)
    # Last value read; None means the previous read timed out, in which case
    # the prompt is not printed again.
    inpt = ''
    while not PARAM_CHECK_STOP:
        try:
            if required_param_complete(expected):
                break
            script, param = param_list[index]
            # Hidden parameters are only asked for when they are required.
            if param.get('hidden') and not param.get('required'):
                index = (index + 1) % param_length
                continue
            # 'depends' names another key that must already have a value.
            depends = param.get('depends')
            if depends and PARAM.get(depends) is None:
                index = (index + 1) % param_length
                continue
            # 'requires' is (key, accepted values): skip unless the current
            # value of that key is one of the accepted values.
            # NOTE(review): if every entry is skipped by the three checks
            # above, this loop spins without reading input -- confirm that
            # cannot happen with real parameter sets.
            requires = param.get('requires')
            if requires and PARAM.get(requires[0]) not in requires[1]:
                index = (index + 1) % param_length
                continue
            if inpt is not None:
                msg('%s: %s: ', script, param.get('description', ''))
            # Do not echo passwords while they are typed.
            if param.get('password'):
                console_echo(False)
            inpt = read_input(sys.stdin, TIMEOUT)
            if(isinstance(inpt, (bytes, bytearray))):
                inpt = inpt.decode('utf-8')
            if param.get('password'):
                console_echo(True)
            logging.debug('check console get value: %s for %s', repr(inpt), param['key'])
            if inpt is None:
                # Timed out: loop again on the same parameter.
                continue
            if inpt in ['\n', '']:
                # Empty answer: ask again for a required parameter, move on
                # for an optional one.
                # NOTE(review): '' also means end of file, so a required
                # parameter on a closed stdin appears to be re-prompted
                # forever -- confirm.
                if param.get('required'):
                    msg('\nInvalid input. Please input again.\n')
                    continue
                if param.get('password'):
                    msg('\n')
            if inpt.endswith('\n'):
                inpt = inpt[:-1]
            if inpt:
                if param.get('password'):
                    try:
                        check_password(inpt)
                    except ValueError as err:
                        msg('\nInvalid password: %s. Please input again.\n', err)
                        continue
                    # Echo was off, so the user's newline was not displayed.
                    msg('\n')
                choices = param.get('choices')
                if choices and inpt not in choices:
                    msg('Invalid input: should be one of %s. Please input again.\n', choices)
                    continue
                PARAM[param['key']] = inpt
            logging.debug('check console update param: %s = %s', param['key'], PARAM.get(param['key']))
            index = (index + 1) % param_length
        except KeyboardInterrupt:
            logging.debug('check console user interrupt received')
            USER_INTERRUPT = True
            # Setting the stop flag ends this loop and the other check loops.
            PARAM_CHECK_STOP = True
            console_echo(True)
        except Exception as e:
            logging.debug('check console error: %s', e)
            PARAM_CHECK_STOP = True
            console_echo(True)
    logging.info('check console finished')


def build_param(inputfd, expected):
    """Collect parameters from *inputfd*, the transport ISO and the console.

    check_fd() and check_iso() run in background threads while the console
    is handled in the calling thread. Returns once all three have finished.
    """
    workers = [
        threading.Thread(target=check_fd, args=(inputfd, expected)),
        threading.Thread(target=check_iso, args=(expected,)),
    ]
    for worker in workers:
        worker.start()
    # Let the two background sources work for 3 seconds first: when OVMAPI or
    # the transport ISO already provide every required parameter, the console
    # input screen does not need to be shown.
    time.sleep(3)
    check_console(expected)
    for worker in workers:
        worker.join()


def main():
    """Command-line entry point.

    Parses the options, enumerates the parameters expected by the target's
    scripts, then either prints them (--enumerate) or collects values from
    the selected source, checks that the required ones are present and runs
    the scripts.
    """
    usage_text = ('''%prog [option] target

Targets:
  configure, unconfigure, reconfigure, cleanup, suspend, resume, migrate, shutdown

Examples:
  %prog --enumerate configure
  %prog --enumerate --script network configure
  %prog --stdin configure
  %prog --stdin --script network configure
  %prog --console-input configure
  %prog --ovf-transport-iso configure
  %prog --input <infd> --output <outfd> configure''')
    cli = optparse.OptionParser(usage_text)
    cli.add_option('-e', '--enumerate', action='store_true',
                   help='enumerate parameters for target')
    cli.add_option('--human-readable', action='store_true',
                   help='print in human readable format when enumerate parameters')
    cli.add_option('-i', '--input', type='int',
                   help='input parameters from this file descriptor')
    cli.add_option('-o', '--output', type='int',
                   help='output parameters to this file descriptor')
    cli.add_option('--stdin', action='store_true',
                   help='build parameters from stdin')
    cli.add_option('--console-input', action='store_true',
                   help='build parameters from console input')
    cli.add_option('--ovf-transport-iso', action='store_true',
                   help='build parameters from OVF transport ISO')
    cli.add_option('-s', '--script',
                   help='specify script')
    cli.add_option('--logfile', default='/var/log/ovm-template-config.log',
                   help='specify log file')
    cli.add_option('--loglevel', default='INFO',
                   help='specify log level')
    options, positional = cli.parse_args()

    level = logging.getLevelName(options.loglevel.upper())
    init_log(options.logfile, level)
    logging.info('ovm-template-config started: %s', sys.argv)

    # Exactly one positional argument, the target, is accepted.
    if len(positional) != 1:
        cli.print_help(sys.stderr)
        sys.exit(1)
    target = positional[0]

    expected = do_enumerate(target, options.script)
    logging.info('expected parameters: %s', expected)
    if options.enumerate:
        if options.human_readable:
            pprint.pprint(expected)
        else:
            print(json.dumps(expected))
        sys.exit(0)

    # Pick the parameter source; without an explicit source option all of
    # them are used at once on the given file descriptors.
    output = sys.stdout
    if options.stdin:
        check_fd(sys.stdin, expected)
    elif options.ovf_transport_iso:
        check_iso(expected)
    elif options.console_input:
        check_console(expected)
    else:
        if options.input is None or options.output is None:
            error('input fd or output fd is not specified')
        inputfd = os.fdopen(options.input, 'rb', 0)
        output = os.fdopen(options.output, 'wb', 0)
        build_param(inputfd, expected)

    if USER_INTERRUPT:
        error('\nuser interrupt')

    try:
        check_param(PARAM, expected)
    except Exception as problem:
        logging.error('error checking parameters: %s', problem)
        error(str(problem))

    try:
        run_scripts(target, output, options.script)
    except Exception as problem:
        logging.error('error running scripts.')
        logging.debug(problem)
        error(str(problem))

    logging.info('ovm-template-config finished')


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
