#!/usr/bin/python3

# Copyright (C) 2011 Oracle. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, version 2.  This program is distributed in the hope that it will
# be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.  You should have received a copy of the GNU
# General Public License along with this program; if not, write to the Free
# Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
# 02111-1307, USA.

### BEGIN PLUGIN INFO
# name: ssh
# configure: 70
# cleanup: 30
# description: Script to configure template ssh.
### END PLUGIN INFO

import json
import os
import pwd
import shutil

from templateconfig.common import get_entry_list
from templateconfig.cli import main


def configure_ssh_user_keys(user, filename, key, mode='a'):
    """Write *key* (plus a newline) into ~user/.ssh/*filename*.

    On return ~/.ssh is mode 0700, the file is mode 0600, and both are
    owned by the user's uid/gid from the password database.

    :param user: login name to configure.  Nothing is done (and None is
        returned) if the user is unknown or its home directory is missing.
    :param filename: name of the file inside ~/.ssh.
    :param key: text to write; a trailing newline is appended.
    :param mode: 'a' (default) appends to an existing file, 'w' replaces it.
    :raises ValueError: if *mode* is neither an append nor a write mode.
    """
    try:
        userinfo = pwd.getpwnam(user)
    except KeyError:
        return
    uid = userinfo.pw_uid
    gid = userinfo.pw_gid
    homedir = userinfo.pw_dir
    if not os.path.exists(homedir):
        return
    append = mode.startswith('a')
    if append:
        # O_RDWR (not O_WRONLY) so the current last byte can be inspected.
        flags = os.O_RDWR | os.O_CREAT | os.O_APPEND
    elif mode.startswith('w'):
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    else:
        raise ValueError('Unsupported file mode: %r' % mode)
    sshdir = os.path.join(homedir, '.ssh')
    if not os.path.exists(sshdir):
        # Create the directory private from the start instead of relying
        # solely on the chmod further down.
        os.mkdir(sshdir, 0o700)
    filepath = os.path.join(sshdir, filename)
    data = '%s\n' % key
    # Create the file as 0600 and tighten a pre-existing file *before*
    # writing, so a private key is never exposed through a lax umask.
    fd = os.open(filepath, flags, 0o600)
    with os.fdopen(fd, mode) as fileobj:
        os.fchmod(fileobj.fileno(), 0o600)
        os.fchown(fileobj.fileno(), uid, gid)
        if append:
            size = os.fstat(fileobj.fileno()).st_size
            # If existing content does not end in a newline, start a new
            # line; otherwise the new key would be glued onto the last one.
            if size and os.pread(fileobj.fileno(), 1, size - 1) != b'\n':
                data = '\n' + data
        fileobj.write(data)
    os.chmod(sshdir, 0o700)
    os.chown(sshdir, uid, gid)


def _write_host_key_file(filepath, content, perm):
    """Write *content* plus a newline to *filepath* as root:root with *perm*."""
    # Pass *perm* at creation and fchmod before writing so the file is never
    # looser than *perm*, even when it already existed.
    fd = os.open(filepath, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, perm)
    with os.fdopen(fd, 'w') as fileobj:
        os.fchmod(fileobj.fileno(), perm)
        os.fchown(fileobj.fileno(), 0, 0)
        fileobj.write('%s\n' % content)


def configure_ssh_host_keys(private_key, public_key, key_type):
    """Install an SSH host key pair into /etc/ssh.

    :param private_key: private key text, written with mode 0600.
    :param public_key: public key text, written with mode 0644 to the
        same file name plus '.pub'.
    :param key_type: 'rsa' (ssh_host_rsa_key), 'dsa' (ssh_host_dsa_key)
        or 'rsa1' (ssh_host_key).
    :raises ValueError: for any other key type; nothing is written then.
    """
    basenames = {'rsa': 'ssh_host_rsa_key',
                 'dsa': 'ssh_host_dsa_key',
                 'rsa1': 'ssh_host_key'}
    keyfile = basenames.get(key_type) if isinstance(key_type, str) else None
    if keyfile is None:
        raise ValueError('Unknown host key type: %s' % key_type)
    keypath = os.path.join('/etc/ssh', keyfile)
    _write_host_key_file(keypath, private_key, 0o600)
    _write_host_key_file(keypath + '.pub', public_key, 0o644)


def do_enumerate(target):
    """Describe the parameters this plugin accepts for *target*.

    Only the 'configure' target declares parameters, and all of them are
    hidden.  The result is a JSON-encoded list of objects with 'key',
    'description' and 'hidden' fields; any other target yields '[]'.
    """
    if target != 'configure':
        return json.dumps([])
    prefix = 'com.oracle.linux.ssh.'
    # (key suffix, description) pairs, in the order they are reported.
    catalogue = (
        ('host-key', 'Host private rsa1 key for protocol version 1.'),
        ('host-key-pub', 'Host public rsa1 key for protocol version 1.'),
        ('host-rsa-key', 'Host private rsa key.'),
        ('host-rsa-key-pub', 'Host public rsa key.'),
        ('host-dsa-key', 'Host private dsa key.'),
        ('host-dsa-key-pub', 'Host public dsa key.'),
        ('user.0', 'Name of the user to add a key.'),
        ('authorized-keys.0', 'Authorized public keys.'),
        ('private-key.0', 'Private key for authentication.'),
        ('private-key-type.0', 'Private key type: rsa, dsa or rsa1.'),
        ('known-hosts.0', 'Known hosts.'),
    )
    described = [{'key': prefix + suffix, 'description': text, 'hidden': True}
                 for suffix, text in catalogue]
    return json.dumps(described)


def do_configure(param):
    """Apply the SSH settings contained in the JSON-encoded *param* mapping.

    Host key pairs (rsa1, rsa, dsa) are installed only when both the private
    and the public half are present.  Then, for each user entry returned by
    get_entry_list (presumably one per com.oracle.linux.ssh.user.N key —
    verify against templateconfig.common), authorized keys and known hosts
    are appended to the user's ~/.ssh files, and a private key, if given,
    replaces ~/.ssh/id_rsa, id_dsa or identity according to its type.

    Returns the decoded parameters re-encoded as JSON.
    Raises ValueError when a private key is given with a missing or
    unrecognised type.
    """
    param = json.loads(param)
    # (key type, parameter stem) in the order the keys are installed.
    for key_type, stem in (('rsa1', 'host-key'),
                           ('rsa', 'host-rsa-key'),
                           ('dsa', 'host-dsa-key')):
        host_key = param.get('com.oracle.linux.ssh.%s' % stem)
        host_key_pub = param.get('com.oracle.linux.ssh.%s-pub' % stem)
        if host_key and host_key_pub:
            configure_ssh_host_keys(host_key, host_key_pub, key_type)
    private_key_files = {'rsa': 'id_rsa', 'dsa': 'id_dsa', 'rsa1': 'identity'}
    for (user, index) in get_entry_list(param, 'com.oracle.linux.ssh.user'):
        authorized_keys = param.get('com.oracle.linux.ssh.authorized-keys.%s' % index)
        # Backward compatibility: older parameter sets used
        # com.oracle.linux.ssh.key.N; scheduled for removal on 2013-02-05.
        if not authorized_keys:
            authorized_keys = param.get('com.oracle.linux.ssh.key.%s' % index)
        if authorized_keys:
            configure_ssh_user_keys(user, 'authorized_keys', authorized_keys)
        private_key = param.get('com.oracle.linux.ssh.private-key.%s' % index)
        private_key_type = param.get('com.oracle.linux.ssh.private-key-type.%s' % index)
        if private_key:
            # JSON may carry non-string (even unhashable) values here; treat
            # them as unknown rather than failing in the dict lookup.
            filename = (private_key_files.get(private_key_type)
                        if isinstance(private_key_type, str) else None)
            if filename is None:
                raise ValueError('Unknown private key type: %s' % private_key_type)
            configure_ssh_user_keys(user, filename, private_key, 'w')
        known_hosts = param.get('com.oracle.linux.ssh.known-hosts.%s' % index)
        if known_hosts:
            configure_ssh_user_keys(user, 'known_hosts', known_hosts)
    return json.dumps(param)


def do_cleanup(param):
    """Remove the SSH host keys and every account's ~/.ssh directory.

    Deletes the rsa1, rsa and dsa host key pairs from /etc/ssh (missing
    files are ignored) and the .ssh entry under the home directory of every
    account in the password database.  Returns the decoded parameters
    re-encoded as JSON.
    """
    param = json.loads(param)
    for filename in ['ssh_host_key', 'ssh_host_key.pub',
                     'ssh_host_rsa_key', 'ssh_host_rsa_key.pub',
                     'ssh_host_dsa_key', 'ssh_host_dsa_key.pub']:
        try:
            os.unlink(os.path.join('/etc/ssh', filename))
        except FileNotFoundError:
            pass
    # Many accounts can share a home directory; visit each one only once
    # (dict.fromkeys keeps the password-database order).
    for homedir in dict.fromkeys(userinfo.pw_dir for userinfo in pwd.getpwall()):
        # An empty or relative home would make the path below resolve against
        # the current working directory; never delete anything there.
        if not os.path.isabs(homedir):
            continue
        sshdir = os.path.join(homedir, '.ssh')
        if os.path.islink(sshdir) or os.path.isfile(sshdir):
            # rmtree refuses symlinks; drop the link itself, never its target.
            os.unlink(sshdir)
        elif os.path.isdir(sshdir):
            shutil.rmtree(sshdir)
    return json.dumps(param)


if __name__ == '__main__':
    # Script entry point: hand the parameter enumerator and the per-target
    # handlers to the templateconfig CLI driver.
    handlers = dict(configure=do_configure, cleanup=do_cleanup)
    main(do_enumerate, handlers)
