#!/usr/libexec/platform-python
#
# Copyright (c) 2022, Oracle and/or its affiliates.
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
#
# This code is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 2 only, as
# published by the Free Software Foundation.
#
# This code is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
# version 2 for more details (a copy is included in the LICENSE file that
# accompanied this code).
#
# You should have received a copy of the GNU General Public License version
# 2 along with this work; if not, see <https://www.gnu.org/licenses/>.
#
# Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
# or visit www.oracle.com if you need additional information or have any
# questions.

"""
Tool to watch CPU utilization and execute user provided commands when
configured thresholds are reached.
"""
import argparse
import datetime
import logging
import logging.handlers
import os
import shlex
import signal
import subprocess
import sys
import time

from typing import Dict, List, Mapping, Sequence

# A CPU stats snapshot, as returned by take_snapshot().
#
# Maps a CPU identifier to the list of integer counters read from the
# matching "cpu*" line of /proc/stat:
#
# {
#   "all": [s1, s2, ...],   # from the aggregate "cpu" line
#   "0": [...],             # from the "cpu0" line
#    ...
#   "N": [...]              # from the "cpuN" line
# }
Snapshot = Dict[str, List[int]]


def take_snapshot() -> Snapshot:
    """Read the current CPU counters from /proc/stat.

    Returns a mapping from CPU id to the list of integer counters found on
    the corresponding line: key "all" holds the values of the aggregate "cpu"
    line and key "N" the values of line "cpuN".  See procfs(5) for a
    description of the format of the file.
    """
    with open("/proc/stat", "r") as stat_file:
        lines = stat_file.readlines()

    snapshot = {}

    for line in lines:
        # only the "cpu" and "cpuN" lines are of interest
        if not line.startswith("cpu"):
            continue

        # first token is the CPU label, the rest are the counters
        label, *counters = line.split()
        cpu_id = "all" if label == "cpu" else label[3:]
        snapshot[cpu_id] = [int(value) for value in counters]

    return snapshot


def compute_percent_util(ticks1: List[int], ticks2: List[int]) -> List[float]:
    """Compute CPU utilization (in %) between two lists of tick counters.

    ticks1 and ticks2 hold the counters of the same CPU sampled at two
    points in time (ticks1 being the older sample).  Each entry of the result
    is the increase of the counter at that position expressed as a percentage
    of the increase of the sum of all counters.

    If the sum of the counters did not increase between the two samples, no
    utilization can be derived and 0.0 is returned for every counter (instead
    of failing with ZeroDivisionError).  If ticks2 has fewer entries than
    ticks1, the result is truncated to the length of ticks2.
    """
    elapsed_ticks = sum(ticks2) - sum(ticks1)

    # Nothing elapsed (or the totals went backwards): avoid dividing by zero
    # and reporting meaningless percentages.
    if elapsed_ticks <= 0:
        return [0.0] * len(ticks1)

    return [
        ((new - old) * 100) / elapsed_ticks
        for old, new in zip(ticks1, ticks2)
    ]


# CPU utilization between two snapshots: maps a CPU id (same keys as in a
# Snapshot) to the list of utilization percentages, one per counter.
CpuUtil = Dict[str, List[float]]


def snap_percent_util(snap1: Snapshot, snap2: Snapshot) -> CpuUtil:
    """Compute CPU utilization (in %) between two snapshots.

    snap1 is the older snapshot and snap2 the newer one.  The result maps
    each CPU id present in both snapshots to its utilization percentages, as
    computed by compute_percent_util().
    """
    utilization = {}

    for cpu_id, old_ticks in snap1.items():
        # CPUs that are missing from the newer snapshot can't be compared
        if cpu_id not in snap2:
            continue

        utilization[cpu_id] = compute_percent_util(old_ticks, snap2[cpu_id])

    return utilization


# CPU stat IDs
#
# Each ID is the index of a stat within the per-CPU lists stored in Snapshot
# and CpuUtil values, i.e. its position among the numeric columns that follow
# the CPU label on a "cpu*" line of /proc/stat (see procfs(5)).
CpuStat = int
USER = 0
NICE = 1
SYSTEM = 2
IDLE = 3
IOWAIT = 4
IRQ = 5
SOFTIRQ = 6
STEAL = 7
GUEST = 8
GUEST_NICE = 9


def cpu_thresholds_reached(
        cpu_utils: Sequence[Sequence[float]],
        stat_thresholds: Mapping[CpuStat, int],
        nr_cpus: int) -> bool:
    """Tell whether enough CPUs reached all the specified thresholds.

    cpu_utils holds one sequence of utilization percentages per CPU, indexed
    by CPU stat ID.  A CPU matches when none of its stats listed in
    stat_thresholds is below the corresponding threshold.

    Returns True when at least nr_cpus CPUs match.
    """
    matching_cpus = sum(
        1
        for stats in cpu_utils
        if not any(
            stats[stat] < threshold
            for stat, threshold in stat_thresholds.items())
    )

    return matching_cpus >= nr_cpus


def check_fs_util(dir_path: str, max_fs_util: int) -> None:
    """Check that there is enough space in a filesystem.

    Exit program (status 1) if the filesystem containing dir_path surpasses
    max_fs_util (in %) disk usage.

    The usage is computed from the total and free block counts reported by
    os.statvfs().  When the filesystem reports zero total blocks its usage
    can't be computed; a warning is logged and the check is skipped instead
    of failing with ZeroDivisionError.

    Raises OSError if dir_path can't be inspected.
    """
    fs_info = os.statvfs(dir_path)

    # A filesystem may report no blocks at all (e.g. some pseudo
    # filesystems); there's no usage percentage to compare in that case.
    if fs_info.f_blocks == 0:
        logging.warning(
            "Cannot determine disk space usage of filesystem containing '%s'",
            dir_path)
        return

    space_used = (fs_info.f_blocks - fs_info.f_bfree) * 100 / fs_info.f_blocks

    if space_used > max_fs_util:
        logging.error(
            ("Disk space usage in filesystem containing '%s' has surpassed "
             "the maximum allowed.  current=%.2f; max_allowed=%d\n"
             "Aborting"),
            dir_path, space_used, max_fs_util)
        sys.exit(1)


def exec_action_cmds(cmds: Sequence[str], base_dir: str) -> None:
    """Execute action commands.

    The commands are executed in parallel (through the shell) and this
    function doesn't return until all commands finish.
    Each command is executed in its own directory under base_dir, named after
    the command, replacing every space and slash (/) by two underscores (__).
    E.g. command "echo hello" will be executed in directory "echo__hello/".
    The name is truncated to 200 bytes so that long commands still map to a
    valid file name, and underscores are appended to it until it is unique.
    The command's stdout/stderr are redirected to file "output" in its CWD.

    NOTE: The commands' exit status is not checked/logged/processed.
    """
    logging.info("Executing actions in directory '%s'...", base_dir)

    # Upper bound (in bytes) for the directory name derived from a command.
    # Kept below the usual 255-byte file name limit to leave room for the
    # underscores appended to make the name unique.
    max_name_len = 200

    def exec_cmd(cmd):
        dir_name = cmd.replace(" ", "__").replace("/", "__")[:max_name_len]

        # The limit applies to the encoded name; drop trailing characters
        # while multi-byte characters keep it over the limit.
        while len(os.fsencode(dir_name)) > max_name_len:
            dir_name = dir_name[:-1]

        cmd_dir = os.path.join(base_dir, dir_name)

        # Ensure we use a unique directory, just append underscores until we
        # form a unique directory name.  Existing directories can be found when
        # the user specifies repeated action commands.
        while os.path.isdir(cmd_dir):
            cmd_dir = cmd_dir + "_"

        os.makedirs(cmd_dir)

        # The child process gets its own copy of the descriptor, so the file
        # can be closed here as soon as the command has been started.
        with open(os.path.join(cmd_dir, "output"), "x") as log_fd:
            return subprocess.Popen(
                cmd, shell=True, close_fds=True, stdout=log_fd, stderr=log_fd,
                stdin=subprocess.DEVNULL, cwd=cmd_dir)

    processes = [exec_cmd(c) for c in cmds]

    for proc in processes:
        proc.wait()

    logging.info("done executing actions")


def format_cpu_util(cpu_util: CpuUtil) -> str:
    """Render CPU utilization as a text table.

    The result is a header line followed by one line per entry of cpu_util,
    showing the CPU id and its utilization percentages with two decimals.
    """
    # order in which the stats are laid out in the table columns
    column_order = (
        USER, NICE, SYSTEM, IOWAIT, IRQ, SOFTIRQ, STEAL, GUEST, GUEST_NICE,
        IDLE)

    header = "   CPU   %usr  %nice   %sys %iowait   %irq  %soft %steal %guest"\
             " %gnice  %idle"
    row_fmt = "{:>6} {:>6.2f} {:>6.2f} {:>6.2f} {:>7.2f} {:>6.2f} {:>6.2f} "\
              "{:>6.2f} {:>6.2f} {:>6.2f} {:>6.2f}"

    rows = [
        row_fmt.format(cpu, *(st[col] for col in column_order))
        for cpu, st in cpu_util.items()
    ]

    return '\n'.join((header, '\n'.join(rows)))


def watch_cpu_utilization(
        interval: int,
        stat_thresholds: Mapping[CpuStat, int],
        actions: Sequence[str],
        nr_cpus: int,
        working_dir: str,
        max_fs_util: int,
        only_one_event: bool) -> None:
    """Watch CPU utilization and run the actions when thresholds are reached.

    Every interval seconds a new CPU stats snapshot is taken and compared
    against the previous one.  When the resulting utilization reaches
    stat_thresholds, the filesystem usage of working_dir is checked against
    max_fs_util and then the actions are executed in a new, timestamped
    directory under working_dir.

    The function returns after the first such event when only_one_event is
    set; otherwise it keeps watching.
    """
    # pylint: disable=too-many-arguments

    logging.info("CPU utilization watch started...")
    prev_time = datetime.datetime.now()
    prev_snap = take_snapshot()

    # When no specific # of CPUs is provided (i.e. nr_cpus <= 0), we watch
    # the cumulative stats of all CPUs in the system, which count as a single
    # "CPU" that must reach the thresholds.
    system_wide = nr_cpus <= 0
    required_cpus = 1 if system_wide else nr_cpus

    while True:
        time.sleep(interval)
        cur_time = datetime.datetime.now()
        cur_snap = take_snapshot()

        # CPU utilization between the previous and the current snapshot
        util = snap_percent_util(prev_snap, cur_snap)

        if system_wide:
            # only the cumulative stats (key "all") are watched
            watched = (util["all"], )
        else:
            # watch every individual CPU, discarding the cumulative stats
            watched = tuple(
                stats for cpu, stats in util.items() if cpu != "all")

        if not cpu_thresholds_reached(
                watched, stat_thresholds, required_cpus):
            # nothing to do; current snapshot becomes the new baseline
            prev_time, prev_snap = cur_time, cur_snap
            continue

        # CPU utilization reached
        logging.info(
            ("Reached CPU utilization thresholds:\n"
             "Snapshot: %s - %s\n"
             "Stats:\n%s"),
            prev_time.strftime("%Y-%m-%dT%H:%M:%S"),
            cur_time.strftime("%Y-%m-%dT%H:%M:%S"),
            format_cpu_util(util))

        check_fs_util(working_dir, max_fs_util)

        cmds_dir = os.path.join(
            working_dir, cur_time.strftime("%Y-%m-%dT%H-%M-%S"))
        exec_action_cmds(actions, cmds_dir)

        if only_one_event:
            break

        logging.info("continue watching...")

        # the baseline for the next comparison is a snapshot taken after the
        # actions finished executing
        prev_time = datetime.datetime.now()
        prev_snap = take_snapshot()

    logging.info("Stop watching")


def parse_args() -> argparse.Namespace:
    """Parse CLI arguments.

    Returns the namespace produced by argparse.  The program exits (through
    argparse's error handling) when an option has an invalid value or when
    none of the threshold options (-u, -n, -s, -i) is given.
    """
    # The validators below raise argparse.ArgumentTypeError rather than
    # ValueError: argparse replaces the message of a ValueError with a generic
    # "invalid ... value" one, hiding the accepted range from the user.

    def percentage(value):
        """Validate a percentage value (must be >= 1 and <= 100)"""
        value = int(value)

        if value < 1 or value > 100:
            raise argparse.ArgumentTypeError(
                "Percentage must be in range [1, 100]")

        return value

    def interval(value):
        """Validate snapshot interval"""
        value = int(value)

        if value < 1:
            raise argparse.ArgumentTypeError(
                "Snapshot interval must be >= 1 seconds")

        return value

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-b", dest="run_in_background", action='store_true',
        help="Whether to run continuously in background.")
    parser.add_argument(
        "-C", dest="nr_cpus", type=int, default=0,
        help=("# of CPUs to apply match criteria.  A value <= 0 means apply "
              "system wide."))
    parser.add_argument(
        "-u", dest="user_percent", type=percentage,
        help="Triggers if %%user reaches or exceeds this argument.")
    parser.add_argument(
        "-n", dest="nice_percent", type=percentage,
        help="Triggers if %%nice reaches or exceeds this argument.")
    parser.add_argument(
        "-s", dest="sys_percent", type=percentage,
        help="Triggers if %%sys reaches or exceeds this argument.")
    parser.add_argument(
        "-i", dest="iowait_percent", type=percentage,
        help="Triggers if %%iowait reaches or exceeds this argument.")
    parser.add_argument(
        "-t", dest="target_dir", default="/var/oled/syswatch",
        help=("Target directory.  The program would cd to this directory "
              "before performing any actions."))
    parser.add_argument(
        "-M", dest="max_fs_util", type=percentage, default=85,
        help=("Max Filesystem Utilization.  If the filesystem is at or above "
              "the set %%, then exit and don't take any action."))
    parser.add_argument(
        "-I", dest="interval", type=interval, default=5,
        help="Interval, in seconds, between CPU utilization snapshots.")
    parser.add_argument(
        "-a", dest="actions", required=True, action="append",
        help=("The action to perform.  This is mandatory.  If a more complex "
              "action is needed this can point to a script.  This option can "
              "be specified multiple times, in which case all the action "
              "commands will be executed in parallel."))

    args = parser.parse_args()

    # Valid percentages are >= 1, so a falsy value means "option not given".
    if not (args.user_percent or args.nice_percent or args.sys_percent
            or args.iowait_percent):
        parser.error("At least one of -u, -n, -s, -i must be specified")

    return args


def setup_logging(working_dir: str) -> None:
    """Configure the root logger of the application.

    Messages of level INFO and above are written to file "syswatch.log" in
    working_dir (truncating it if it exists) and, when stdout is a tty, to
    stdout as well.
    """
    log_file = os.path.join(working_dir, "syswatch.log")

    handlers = [logging.FileHandler(log_file, mode="w")]

    # also log to stdout if it's a tty
    if sys.stdout.isatty():
        handlers.append(logging.StreamHandler(sys.stdout))

    # every handler shares the same message format
    formatter = logging.Formatter(
        fmt="%(asctime)s %(levelname)s - %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z")
    root_logger = logging.getLogger()

    for handler in handlers:
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)

    root_logger.setLevel(logging.INFO)

    logging.info("Log file: %s", log_file)


def main() -> None:
    """Program entry point.

    Parses the CLI arguments, creates a fresh working directory under the
    target directory, sets up logging there and starts watching the CPU
    utilization.  Exits with status 1 when not run as root.
    """
    args = parse_args()

    if os.getuid() != 0:
        logging.error("This script must be run as root.  Aborting.")
        sys.exit(1)

    # create working dir and change to it
    started_at = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    working_dir = os.path.join(
        args.target_dir, f"syswatch_{started_at}_{os.getpid()}")
    os.makedirs(working_dir)
    os.chdir(working_dir)

    setup_logging(working_dir)

    # record the command line and the effective configuration
    logging.info(" ".join(shlex.quote(arg) for arg in sys.argv))

    config_str = (
        f"\tRun continuously: {args.run_in_background}\n"
        f"\t# CPUs: {'ALL' if args.nr_cpus <= 0 else args.nr_cpus}\n"
        f"\t%user: {args.user_percent}\n"
        f"\t%nice: {args.nice_percent}\n"
        f"\t%sys: {args.sys_percent}\n"
        f"\t%iowait: {args.iowait_percent}\n"
        f"\tWorking dir: {working_dir}\n"
        f"\tMax FS utilization: {args.max_fs_util}\n"
        f"\tSnapshot interval (sec): {args.interval}\n"
        f"\tActions: {args.actions}"
    )
    logging.info("Config:\n %s", config_str)

    check_fs_util(args.target_dir, args.max_fs_util)

    # only the stats for which a threshold was given are watched
    requested = (
        (USER, args.user_percent),
        (NICE, args.nice_percent),
        (SYSTEM, args.sys_percent),
        (IOWAIT, args.iowait_percent),
    )
    cpu_stat_thresholds = {stat: pct for stat, pct in requested if pct}

    watch_cpu_utilization(
        args.interval, cpu_stat_thresholds, args.actions, args.nr_cpus,
        working_dir, args.max_fs_util,
        only_one_event=(not args.run_in_background))


def exit_signal_handler(*_args, **_kwargs) -> None:
    """Log the interruption and terminate the program with status 1.

    Meant to be installed as a signal handler; the arguments it is called
    with are ignored.
    """
    logging.error("Interrupted")
    raise SystemExit(1)


if __name__ == "__main__":
    # exit through exit_signal_handler on the common termination signals
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, exit_signal_handler)

    main()
    logging.info("Finished")
