#!/usr/bin/env python3
"""This script queries the collector at a fixed interval, and takes the difference in some runtime totals to
determine the change in runtime for each performance category. It can then attempt to calculate the percentage
of CPU time each category is using via the DaemonCoreDutyCycle. These statistics are then either printed to the
console or displayed as a flame graph in a GUI window.

The structure of the script has some main parts:
  - Dict Operations: Pure functions that manipulate dictionaries (map, reduce, etc.)
  - Stats Operations: Functions that create and modify "diff dictionaries" which are the difference between two samples of runtimes
  - Print functions: Functions that print the statistics to the console
  - GUI functions: Functions that draw the flame graph in a GUI window; these are self-contained in the FlameGraph class
"""

import htcondor2 as htcondor
from classad2 import ClassAd
import argparse
import os
import sys
import copy
import math
import collections
import signal
import datetime as dt
from time import sleep
from collections import defaultdict
from typing import List, Dict, Callable, Tuple

# Conditionally import tkinter for GUI mode
TKINTER = False
try:
    import tkinter as tk
    from tkinter import ttk
    TKINTER = True
except ImportError:
    pass

COMMAND_PREFIX = "DIAG_CCS"  # Attribute-name prefix for command statistics
REAP_PREFIX = "DIAG_CRS"  # Attribute-name prefix for command reaper statistics
FSYNC_PREFIX = "DIAG_CFS"  # Attribute-name prefix for fsync statistics
TIMER_PREFIX = "DCTimer_"  # Attribute-name prefix for timer statistics
TIMER_SUFFIX = "Runtime"  # Attribute-name suffix for timer statistics
DANGER_THRESHOLD = 20  # Percentages at or above this value are highlighted in red
# Percentages at or above this value (and below DANGER_THRESHOLD) are highlighted in yellow
WARNING_THRESHOLD = 10


# Define terminal colors for printing
class colors:
    """ANSI escape sequences used to style the console output."""

    WARNING = "\033[93m"  # Used for values at or above WARNING_THRESHOLD
    DANGER = "\033[91m"  # Used for values at or above DANGER_THRESHOLD and for warnings
    ENDC = "\033[0m"  # Printed after a styled line to end the styling
    BOLD = "\033[1m"  # Used for category headers


# Cleanly exit on SIGINT (Ctrl+C). sys.exit() is used instead of the exit()
# helper because exit() is only installed by the site module and is not
# guaranteed to exist in every interpreter configuration.
signal.signal(signal.SIGINT, lambda signum, frame: sys.exit(0))

# Hardcoded kwargs for querying the schedd. They are unpacked into
# Collector.directQuery() by get_stats(); the "name" entry is added at
# startup once the schedd name is known.
kwargs = {
    "daemon_type": htcondor.DaemonTypes.Schedd,
    "statistics": "All:2",
}


def get_name() -> str:
    """Returns the name of the schedd, or exits if it can't be found.
    The collector is first asked for the schedd ad whose Machine attribute equals this
    host's FULL_HOSTNAME. If nothing matches, the query is repeated without a constraint
    and the first schedd ad returned is used.
    returns: The "Name" attribute of the selected schedd ad
    """
    collector = htcondor.Collector()
    machine = htcondor.param["FULL_HOSTNAME"]
    # The hostname has to be a quoted ClassAd string literal. Left bare it is parsed as an
    # attribute reference (or arithmetic, for names containing '-') and never matches.
    escaped = machine.replace("\\", "\\\\").replace('"', '\\"')
    # Query for the schedd ad
    ads = collector.query(
        htcondor.AdTypes["Schedd"],
        constraint=f'Machine == "{escaped}"',
        projection=["Name"],
    )
    # if that doesn't work, remove the constraint and try again
    if not ads:
        ads = collector.query(htcondor.AdTypes["Schedd"], projection=["Name"])
    # Still doesn't work -> we have an error
    if not ads:
        print("Error: Couldn't find hostname for schedd")
        sys.exit(1)
    return ads[0]["Name"]


def map_dict(d: Dict, lambda_: Callable) -> Dict:
    """Maps a lambda function over a (possibly nested) dictionary
    The input dictionary is left untouched; a new dictionary of the same mapping type
    and key order is built at every level.
    d: The dictionary to map over
    lambda_: The function to apply to each non-dictionary value
    returns: A new dictionary with lambda_ applied to every leaf value
    """
    # A shallow copy keeps the mapping type (e.g. defaultdict) and the key order; every
    # value is replaced below, so nothing is shared with the input's nested dictionaries.
    res = copy.copy(d)
    for key, value in d.items():
        if isinstance(value, dict):
            res[key] = map_dict(value, lambda_)
        else:
            res[key] = lambda_(value)
    return res


def bimap_dict(d1: Dict, d2: Dict, lambda_: Callable, fail_: Callable = None) -> Dict:
    """Combines two (possibly nested) dictionaries by walking the keys of the first one.
    Only keys of d1 are visited; keys that exist solely in d2 never appear in the result.
    d1: The dictionary whose keys drive the traversal
    d2: The dictionary providing the second value of each pair
    lambda_: Called as lambda_(d1_value, d2_value) for every non-dictionary value found in both
    fail_: Called as fail_(d1_value) when a key is missing from d2; if None the key is dropped
    returns: A new, combined dictionary
    """
    # Work on private copies so neither argument is shared with the result
    left = copy.deepcopy(d1)
    right = copy.deepcopy(d2)
    combined = {}
    for key in left:
        lhs = left[key]
        if key in right:
            rhs = right[key]
            # Nested dictionaries are combined recursively, plain values directly
            if isinstance(lhs, dict):
                combined[key] = bimap_dict(lhs, rhs, lambda_, fail_)
            else:
                combined[key] = lambda_(lhs, rhs)
        elif fail_ is not None:
            combined[key] = fail_(lhs)
    return combined


def reduce_dict(dicts: List[Dict], base: Dict, lambda_: Callable, fail_: Callable = None) -> Dict:
    """Folds a list of dictionaries into a single one using bimap_dict.
    Each dictionary is passed to bimap_dict as d1 and the accumulator as d2, so lambda_ is
    called as lambda_(value_from_dict, accumulated_value) and only the keys of the dictionary
    currently being folded in are carried over to the next accumulator.
    dicts: A list of dicts to reduce
    base: The starting value of the accumulated dict
    lambda_: A two-valued lambda function combining a dict value with the accumulated value
    fail_: A failsafe lambda function that is used when a key is missing from the accumulator
    returns: The accumulated dictionary
    """
    # Copies keep the caller's dictionaries isolated from the accumulator
    pending = copy.deepcopy(dicts)
    accumulated = copy.deepcopy(base)
    for current in pending:
        accumulated = bimap_dict(current, accumulated, lambda_, fail_)
    return accumulated


def add_totals_dict(d: Dict) -> Dict:
    """Adds a "total" key to a dictionary and to all of its sub-dictionaries.
    The total of a level is the sum of its plain values plus the totals of its
    sub-dictionaries, so levels that mix values and sub-dictionaries keep every entry.
    A pre-existing "total" key is ignored and recomputed, which makes the function safe
    to apply more than once.
    d: The dictionary to add totals to (not modified)
    returns: A new dictionary with a "total" key at every level
    """
    res = {}
    total = 0
    for key, value in d.items():
        if key == "total":
            # Stale total from an earlier pass; it is recomputed below
            continue
        if isinstance(value, dict):
            value = add_totals_dict(value)
            total += value["total"]
        else:
            total += value
        res[key] = value
    res["total"] = total
    return res


def invert_dict(d: Dict) -> Dict:
    """Flips the "Commands" table from (user -> command) to (command -> user).
    The dictionary is modified in place; every other top-level key is left alone.
    d: The dictionary to invert
    returns: The same dictionary object, with "Commands" replaced by its inverted table
    """
    if "Commands" in d:
        flipped = defaultdict(lambda: defaultdict(lambda: 0))
        for user, commands in d["Commands"].items():
            for command, value in commands.items():
                flipped[command][user] = value
        d["Commands"] = flipped
    return d


def get_fields(
    d: Dict,
    ad: ClassAd,
    title: str,
    prefix: str = "",
    suffix: str = "",
    splituser: bool = False,
) -> Dict:
    """Copies every attribute of the ad that matches a prefix and suffix into d[title], in-place.
    The prefix is stripped from the attribute name to form the key; the suffix is kept.
    d: The dictionary to populate (d[title] must be subscriptable to the needed depth)
    ad: The ad to pull statistics from
    title: The top-level key to add to the dictionary
    prefix: The prefix an attribute name must start with
    suffix: The suffix an attribute name must end with
    splituser: If set, the stripped name is split at its first underscore into a user and a
               value name, and the statistic is stored under d[title][user][value]
    returns: The modified dictionary
    """
    skip = len(prefix)
    matches = [attr[skip:] for attr in ad.keys() if attr.startswith(prefix) and attr.endswith(suffix)]
    for name in matches:
        if splituser:
            # If there is user info on the field we need to split it off
            user, value = name.split("_", 1)
            d[title][user][value] = ad[prefix + name]
        else:
            d[title][name] = ad[prefix + name]
    return d


def get_stats() -> Tuple[Dict, dt.datetime]:
    """Fetches the schedd ad through the collector and extracts the enabled statistic categories.
    Exits the program if the query fails.
    returns: A tuple of the statistics (plus the "CPU" and "DCDC" metadata keys) and the
             time the statistics were taken
    """
    try:
        ad = htcondor.Collector().directQuery(**kwargs)
    except Exception as e:
        print(f"Error: Couldn't query collector: {e}")
        sys.exit(1)

    # Get useful metadata
    duty_cycle = ad["DaemonCoreDutyCycle"]
    cpu_seconds = duty_cycle * args.window
    sampled_at = dt.datetime.fromtimestamp(ad["MyCurrentTime"])

    # Lets us make entries arbitrarily deep
    def nested():
        return defaultdict(nested)

    stats = nested()
    # (command line switch, top-level key, keyword arguments for get_fields)
    categories = (
        (args.c, "Commands", {"prefix": COMMAND_PREFIX, "splituser": True}),
        (args.r, "Reapers", {"prefix": REAP_PREFIX}),
        (args.t, "Timers", {"prefix": TIMER_PREFIX, "suffix": TIMER_SUFFIX}),
        (args.f, "Fsync", {"prefix": FSYNC_PREFIX}),
    )
    for enabled, title, options in categories:
        if enabled or args.all:
            get_fields(stats, ad, title, **options)

    # Invert table before we add extra top-level keys
    if args.invert:
        stats = invert_dict(stats)

    # Shove ancillary data in with the command statistics, kinda messy but we can filter it later
    stats["CPU"] = cpu_seconds
    stats["DCDC"] = duty_cycle
    return stats, sampled_at


def get_diffs(old: Dict, new: Dict) -> Tuple[Dict, Dict]:
    """Calculates the averaged difference between two sets of statistics as percentages.
    The raw difference is appended to the global diffs window (and written to the prometheus
    file when requested) before the window is averaged. The global CPU_OVERFLOW flag is set
    when the summed runtimes exceed the CPU time, in which case new["CPU"] is raised to the
    summed runtime.
    old: The previous set of statistics, or None if this is the first sample
    new: The current set of statistics
    returns: A tuple of the new statistics and the percentage diff ({} for the first sample)
    """
    global CPU_OVERFLOW
    CPU_OVERFLOW = False
    if old is None:
        print("Waiting on second sample...")
        return new, {}

    # Filter out metadata
    old_stats = {key: value for key, value in old.items() if isinstance(value, dict)}
    new_stats = {key: value for key, value in new.items() if isinstance(value, dict)}

    d = bimap_dict(old_stats, new_stats, lambda x, y: y - x, fail_=lambda x: x)
    if args.prometheus:
        write_to_prometheus(d)
    diffs.append(d)
    diff = average_diffs(diffs)
    # Renormalize the totals if they go over 100% of CPU time
    if diff["total"] > new["CPU"]:
        new["CPU"] = diff["total"]
        CPU_OVERFLOW = True
    # Use the total CPU time to calculate the "Other" category then discard
    diff["Other"] = {"total": new["CPU"] - diff["total"]}
    del diff["total"]
    # Percentages are relative to the CPU time with --rel, otherwise to the sample window
    denominator = new["CPU"] if args.rel else args.window
    if denominator:
        diff = map_dict(diff, lambda x: x / denominator * 100)
    else:
        # No time to relate the runtimes to (e.g. a duty cycle of 0 on an idle daemon);
        # report 0% everywhere instead of dividing by zero.
        diff = map_dict(diff, lambda x: 0.0)
    return new, diff


def write_to_prometheus(diff: Dict):
    """Writes a diff dictionary to the file named by args.prometheus in prometheus format. See
    https://github.com/prometheus/docs/blob/main/content/docs/instrumenting/exposition_formats.md
    Totals are added to a copy of the diff first; every key is written as a gauge, with
    sub-dictionaries represented by their total followed by their own entries.
    NOTE(review): metric names are built from the key alone, so the same name (for example
    "total") is written once per level -- confirm the consumer accepts repeated names.
    diff: The dictionary to write
    """
    with_totals = add_totals_dict(copy.deepcopy(diff))
    # Prometheus doesn't like special characters in metric names
    sanitize = str.maketrans({"@": "_", " ": "_", "-": "_"})

    def emit(node, out):
        for key, value in node.items():
            metric = key.translate(sanitize)
            out.write(f"# TYPE {metric} gauge\n")
            if isinstance(value, dict):
                out.write(f'{metric} {value["total"]}\n')
                emit(value, out)
            else:
                out.write(f"{metric} {value}\n")

    with open(args.prometheus, "w") as f:
        emit(with_totals, f)


def average_diffs(diffs: List[Dict]) -> Dict:
    """Computes the weighted average of a list of raw diff dictionaries.
    The newest (last) diff has weight 1 and each older one is multiplied by a further
    factor of args.p.
    diffs: The diffs to average, oldest first
    returns: The averaged dict with "total" keys added, or {} if there are no diffs
    """
    count = len(diffs)
    if count == 0:
        return {}
    weights = [args.p ** (count - 1 - position) for position in range(count)]
    scaled = [
        map_dict(sample, lambda x, w=weight: x * w)
        for sample, weight in zip(diffs, weights)
    ]
    summed = reduce_dict(scaled, {}, lambda x, y: x + y, fail_=lambda x: x)
    weight_sum = sum(weights)
    averaged = map_dict(summed, lambda x: x / weight_sum)
    # Add metadata
    return add_totals_dict(averaged)


def print_stat(key: str, value: float, spacing: int, bullet: str, color: bool = False):
    """Prints one statistic line, unless its value is at or below args.limit
    key: The key to print
    value: The value to print (a percentage)
    spacing: The indentation level of the line
    bullet: The bullet to print before the key
    color: Whether to color code the output using the warning/danger thresholds
    """
    if value <= args.limit:
        return

    if color:
        if value >= DANGER_THRESHOLD:
            highlight = colors.DANGER
        elif value >= WARNING_THRESHOLD:
            highlight = colors.WARNING
        else:
            highlight = ""
        indent = "  " * spacing
        print(highlight, f"{indent} {bullet} {key}: {value:.2f}%", colors.ENDC)
    else:
        print(f"\t{' ' * spacing} {bullet} {key}: {value:.2f}%")


def print_dict(
    d: Dict,
    depth: int,
    bullets: List[str],
    color_flags: List[bool],
    sort_ascending: bool = False,
):
    """Recursively prints a dictionary of statistics
    A level that contains sub-dictionaries prints each of them (by its "total") followed
    by its contents; a flat level prints every entry except "total".
    d: The dictionary to print
    depth: The current depth of the recursion
    bullets: A list of bullets to cycle through, one per depth
    color_flags: A list of flags to determine whether to color code the output at that depth;
                 the last flag is reused for deeper levels
    sort_ascending: Sort entries by ascending value instead of descending, at every depth
    """
    bullet = bullets[depth % len(bullets)]
    use_color = color_flags[min(depth, len(color_flags) - 1)]
    subdicts = [(key, value) for key, value in d.items() if isinstance(value, dict)]

    if not subdicts:
        leaves = [(key, value) for key, value in d.items() if key != "total"]
        for key, value in sorted(leaves, key=lambda x: x[1], reverse=not sort_ascending):
            print_stat(key, value, depth, bullet, use_color)
        return

    for key, value in sorted(subdicts, key=lambda x: x[1]["total"], reverse=not sort_ascending):
        print_stat(key, value["total"], depth, bullet, use_color)
        # Pass the sort order down so nested levels honor it as well
        print_dict(value, depth + 1, bullets, color_flags, sort_ascending)


def print_stats(diff: Dict, stats: Dict, time: dt.datetime):
    """Prints the header and the per-category statistics to the console
    diff: The percentage diff between the two sets of statistics; nothing is printed if empty
    stats: The current set of statistics (used for CPU time, DCDC, etc.)
    time: The time the current set of statistics were taken
    """
    BULLETS = ["*", "-", ">"]
    if not diff:
        return

    # Print header
    if not args.silent:
        os.system("clear")
    print(f"{time}" + "-" * 61)  # Timestamp is 19 characters until 10000AD
    print(f"Sample Window: {args.window:7.0f} s")
    print(f"CPU Time: {stats['CPU']:12.3f} s")
    if CPU_OVERFLOW:
        print(f"CPU Usage: {(stats['CPU'] / args.window * 100):12.2f}%")
    else:
        print(f"CPU Usage: {(stats['DCDC'] * 100):12.2f}%")
    print(f"Windows Averaged: {len(diffs):3}/{args.m}")
    print("-" * 80)
    if CPU_OVERFLOW:
        print(
            colors.DANGER,
            "WARNING: Summed statistics exceeded CPU time! These values are renormalized!",
            colors.ENDC,
        )
        print(
            colors.DANGER,
            "This typically happens when traffic spikes faster than CPU Usage can update.",
            colors.ENDC,
        )

    # The percentages are relative to the CPU time with --rel and to the sample window
    # otherwise, so the same base has to be used to convert them back into a runtime.
    runtime_base = stats["CPU"] if args.rel else args.window
    order = sorted(
        [(k, v) for (k, v) in diff.items() if k != "Other"],
        key=lambda x: x[1]["total"],
        reverse=not args.ascending,
    )
    for key, table in order:
        if key == "total" or table["total"] <= args.limit:
            continue
        print(
            colors.BOLD,
            f"{key}: {(table['total'] * runtime_base / 100 * 1e9):.1e}ns ({table['total']:.2f}%)",
            colors.ENDC,
        )
        print_dict(
            table,
            depth=0,
            bullets=BULLETS,
            color_flags=[True, True, False],
            sort_ascending=args.ascending,
        )
    # Always put "Other" on bottom
    print(colors.BOLD, f"Other ({diff['Other']['total']:.2f}%)", colors.ENDC)
    print("-" * 80)


def parse_args() -> argparse.Namespace:
    """Parses and validates the command line arguments
    Exits with status 1 if the penalty factor, the window length or the number of averaged
    windows is out of range.
    returns: The parsed arguments, with args.all set when no category was selected and
             args.rel forced on in GUI mode
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-w",
        "--window",
        default=30,
        type=int,
        help="Seconds between updates (default 30)",
    )
    parser.add_argument("-n", "--name", default=None, help="Query a specific hostname (default auto)")
    parser.add_argument(
        "-l",
        "--limit",
        default=0.001,
        type=float,
        help="Minimum percentage to print statistics (default 0.001)",
    )
    parser.add_argument(
        "-s",
        "--silent",
        default=None,
        action="store_true",
        help="Don't print stats to stdout (default off)",
    )
    parser.add_argument(
        "-i",
        "--invert",
        action="store_true",
        help="Invert the user/command ordering (default off)",
    )
    parser.add_argument(
        "-a",
        "--ascending",
        action="store_true",
        help="Sort values in ascending order (default off)",
    )
    parser.add_argument("-c", action="store_true", help="Show command statistics")
    parser.add_argument("-r", action="store_true", help="Show reaper statistics")
    parser.add_argument("-t", action="store_true", help="Show timer statistics")
    parser.add_argument("-f", action="store_true", help="Show fsync statistics")
    parser.add_argument("--all", action="store_true", help="Show all statistics (default on)")
    parser.add_argument(
        "--rel",
        action="store_true",
        help="Show percentages relative to CPU time (default off)",
    )
    parser.add_argument("-g", action="store_true", help="Use a GUI version of the monitor")
    parser.add_argument("-m", default=20, type=int, help="How many windows to average over (default 20)")
    parser.add_argument(
        "-p",
        default=0.9,
        type=float,
        help="Factor to penalize old data (default 0.9)\n Ex. 3rd last window has weight 0.9^3",
    )
    parser.add_argument(
        "--prometheus",
        type=str,
        help="Output to the specified file in prometheus format (default none)",
    )
    args = parser.parse_args()
    # If no categories are specified, default to all
    if not any([args.c, args.r, args.t, args.f, args.all]):
        args.all = True
    # The GUI should default to relative percentages
    if args.g:
        args.rel = True
    # Penalty factor must be in (0, 1)
    if args.p <= 0 or args.p >= 1:
        print("Invalid penalty factor! Must be in range (0, 1)")
        sys.exit(1)
    # The window is used as a divisor and as the GUI countdown, so it must be positive
    if args.window <= 0:
        print("Invalid window time! Must be a positive integer.")
        sys.exit(1)
    # At least one window has to be kept, otherwise there is nothing to average
    if args.m <= 0:
        print("Invalid number of windows! Must be a positive integer.")
        sys.exit(1)
    return args


def wait(seconds: float):
    """Sleeps for the given number of seconds, printing a countdown on a single console line.
    Exits with status 1 if seconds is not positive.
    seconds: How long to wait, counted down one second at a time
    """
    if seconds <= 0:
        print("Invalid window time! Must be a positive integer.")
        sys.exit(1)
    for remaining in range(seconds, 0, -1):
        # Spaced so carriage return works on 10->9, etc.
        print(f"Waiting {remaining} seconds    ", end="\r")
        sleep(1)

if TKINTER:
    class FlameGraph:
        """GUI window that samples the schedd once per window and draws the averaged
        runtime percentages as a clickable flame graph on a tkinter canvas.
        """

        def __init__(self, root):
            """Builds the widgets, takes the first sample and starts the countdown
            root: The tk root window to place the widgets in
            """
            self.root = root
            self.root.title("condor_diagnostics")

            # Set up styles
            style = ttk.Style()
            style.theme_use('clam')
            style.configure('TButton', font=('Arial', 10), padding=5)
            style.configure('TLabel', font=('Arial', 10))

            # Main frame
            main_frame = ttk.Frame(root, padding="10")
            main_frame.pack(fill=tk.BOTH, expand=True)

            # Canvas
            self.canvas = tk.Canvas(main_frame, width=800, height=600, bg="white")
            self.canvas.pack(fill=tk.BOTH, expand=True)

            # Control frame
            control_frame = ttk.Frame(main_frame)
            control_frame.pack(fill=tk.X, pady=10)

            # Labels
            self.time_label = ttk.Label(control_frame, text="0s until next sample")
            self.time_label.pack(side=tk.LEFT, padx=(0, 10))

            self.diffs_label = ttk.Label(control_frame, text=f"{len(diffs)}/{args.m} windows averaged")
            self.diffs_label.pack(side=tk.LEFT, padx=(0, 10))

            self.tooltip_label = ttk.Label(control_frame, text="")
            self.tooltip_label.pack(side=tk.LEFT)

            # Button frame
            button_frame = ttk.Frame(main_frame)
            button_frame.pack(fill=tk.X, pady=10)

            # Buttons
            self.button = ttk.Button(button_frame, text="Stop", command=self.toggle)
            self.button.pack(side=tk.LEFT, padx=(0, 5))

            self.back_button = ttk.Button(button_frame, text="Back", command=self.back_one)
            self.back_button.pack(side=tk.LEFT, padx=5)

            self.top_button = ttk.Button(button_frame, text="Top", command=self.start_draw)
            self.top_button.pack(side=tk.LEFT, padx=5)

            # Other attributes
            self.max_width = root.winfo_screenwidth() * 0.9
            self.max_height = root.winfo_screenheight() * 0.75
            self.win_width = root.winfo_screenwidth()
            self.win_height = root.winfo_screenheight()
            self.colors = ["chocolate1", "firebrick1", "goldenrod1"]
            self.old = None
            self.running = True
            self.window = args.window
            self.time_remaining = args.window
            # Drawing state; set up front so the buttons are safe to press before the
            # first diff is available
            self.diff = {}  # The diff currently displayed, wrapped in a "Total" root
            self.current_path = []  # Keys leading to the box the view is zoomed into
            self.old_color = None  # Fill color of the highlighted box, to restore on leave
            self.max_depth = 1
            self.box_height = self.max_height / 1.75

            # Initial updates
            self.update()
            self.update_time()

        def toggle(self):
            """Toggles the running state of the flame graph"""
            self.running = not self.running
            # Clear old data if we're restarting, otherwise runtimes will scale with time paused.
            if self.running:
                self.old = None
                self.time_remaining = self.window
            self.button.config(text="Stop" if self.running else "Start")

        def update(self):
            """Updates the flame graph with the latest statistics"""
            if not self.running:
                return
            stats, _ = get_stats()
            self.old, diff = get_diffs(self.old, stats)
            if not diff:
                # First sample after a (re)start: keep the displayed data so the boxes
                # still on screen stay clickable
                return
            total = copy.deepcopy(diff)
            total["total"] = 100
            total["Other"] = total["Other"]["total"]  # Other is a dict so we need to flatten it
            self.diff = {"Total": total, "total": 100}
            self.diffs_label.config(text=f"{len(diffs)}/{args.m} windows averaged")
            self.max_depth = self.get_depth(self.diff)
            self.box_height = self.max_height / self.max_depth / 1.75  # Cool magic number
            self.start_draw([])

        def update_time(self):
            """Updates the time remaining until the next sample every second"""
            if self.running:
                self.time_remaining -= 1
                self.time_label.config(text=f"{self.time_remaining}s until next sample")
                if self.time_remaining <= 0:
                    self.update()
                    self.time_remaining = self.window
            self.root.after(1000, self.update_time)

        def get_depth(self, data) -> int:
            """Recursively calculates the depth of a dictionary
            data: The dictionary to calculate the depth of
            returns: 1 for a plain value or an empty dictionary, otherwise 1 + the deepest child
            """
            if isinstance(data, dict):
                return 1 + max((self.get_depth(value) for value in data.values()), default=0)
            else:
                return 1

        def enter_box(self, rect):
            """Displays a tooltip and highlights the box on mouse enter
            rect: The box entered, used for the associated tag data
            """
            tags = self.canvas.gettags(rect)
            if len(tags) == 0:
                return
            path = tags[0].split(":")
            name, val = path[-2], path[-1]
            self.tooltip_label.config(text=f"{name}: {float(val):.2f}%")
            self.old_color = self.canvas.itemcget(rect, "fill")
            self.canvas.itemconfig(rect, fill="cyan")

        def leave_box(self, rect):
            """Clears the tooltip and resets the color of the box on mouse leave
            rect: The box left, used to restore its color
            """
            self.tooltip_label.config(text="")
            if self.old_color is not None:
                self.canvas.itemconfig(rect, fill=self.old_color)

        def click_box(self, rect):
            """Redraws the flame graph from the clicked box
            rect: The boxed clicked, used for its tag data
            """
            tags = self.canvas.gettags(rect)
            if len(tags) == 0:
                return
            path = tags[0].split(":")[:-1]
            d = self.diff
            for k in path:
                # The tag may not describe a path in the current data; ignore the click then
                if not isinstance(d, dict) or k not in d:
                    return
                d = d[k]
            if not isinstance(d, dict):
                return
            self.start_draw(path)

        def back_one(self):
            """Goes back one layer in the flame graph"""
            if self.current_path:
                self.start_draw(self.current_path[:-1])

        def draw_box(self, x, y, width, height, keystring, value, depth):
            """Draws a single box representing a statistic
            x: The x-coordinate to draw at
            y: The y-coordinate to draw at
            width: The width of the box
            height: The height of the box
            keystring: The keystring data for this box, used for redrawing on click
            value: The statistic value to display in the box
            depth: The depth of the box, used to pick its color
            """
            name = keystring.split(":")[-1]
            rect = self.canvas.create_rectangle(
                x,
                y,
                x + width,
                y + height,
                fill=self.colors[depth % len(self.colors)],
                # A one-element tuple keeps the whole string as a single tag even if a key
                # contains whitespace
                tags=(f"{keystring}:{value}",),
            )
            self.canvas.tag_bind(rect, "<Button-1>", lambda _: self.click_box(rect))
            self.canvas.tag_bind(rect, "<Enter>", lambda _: self.enter_box(rect))
            self.canvas.tag_bind(rect, "<Leave>", lambda _: self.leave_box(rect))

            if width > 45:
                chars_to_print = math.floor(width / 15)
                text_to_print = name[:chars_to_print] + ".." if chars_to_print < len(name) else name
                name_text = self.canvas.create_text(
                    x + width / 2,
                    y + height / 2,
                    anchor=tk.CENTER,
                    text=text_to_print,
                    fill="black",
                )
                # Kind of hacky but we need text to be bound to the same event as the box
                self.canvas.tag_bind(name_text, "<Enter>", lambda _: self.enter_box(rect))
                self.canvas.tag_bind(name_text, "<Leave>", lambda _: self.leave_box(rect))
                self.canvas.tag_bind(name_text, "<Button-1>", lambda _: self.click_box(rect))
                if chars_to_print > 6:
                    val_text = self.canvas.create_text(
                        x + width / 2,
                        y + height / 2 + height / 6,
                        anchor=tk.CENTER,
                        text=f"{value:.2f}%",
                        fill="black",
                    )
                    # Same here
                    self.canvas.tag_bind(val_text, "<Enter>", lambda _: self.enter_box(rect))
                    self.canvas.tag_bind(val_text, "<Leave>", lambda _: self.leave_box(rect))
                    self.canvas.tag_bind(val_text, "<Button-1>", lambda _: self.click_box(rect))

        def start_draw(self, keys=None):
            """Clears the canvas and draws the flame graph from the given "base" box
            keys: A list of dictionary keys that give the path of the "base" box;
                  None or an empty list draws from the top
            """
            keys = list(keys) if keys else []
            # Nothing to draw until a diff is available
            if not self.diff:
                return
            self.current_path = keys
            self.canvas.delete("all")
            # Have to deepcopy here to preserve original diff
            d = copy.deepcopy(self.diff)
            ks = ""
            for k in keys[:-1]:
                d = d[k]
                ks += k + ":"
            # If we're not at top level we need to renormalize to the "base"
            if keys:
                d = {
                    keys[-1]: d[keys[-1]],
                }
                d["total"] = d[keys[-1]]["total"]
            self.draw(
                d,
                (self.win_width - self.max_width) / 2,
                self.max_height,
                self.max_width,
                keystring=ks,
            )

        def draw(self, data, x, y, width, depth=0, keystring=""):
            """Recursively draws a flame graph from a dictionary of statistics
            data: The dictionary to draw
            x: The x-coordinate to start drawing at
            y: The y-coordinate to start drawing at
            width: The width of the drawable space, should decrease with each iteration
            depth: The current depth of the recursion
            keystring: The current keys that have been traversed to get to this point delimited by ':'
            """

            # Draw the box and further sub-boxes
            total = data["total"]
            if not total:
                # Widths are proportions of the total, so there is nothing to draw
                return
            for name, value in data.items():
                # Item is a dict so we need to draw a box and recurse
                if isinstance(value, dict):
                    if value["total"] < args.limit:
                        continue
                    cur_width = width * (value["total"] / total)  # Proportion of total width
                    self.draw_box(
                        x,
                        y - self.box_height,
                        cur_width,
                        self.box_height,
                        keystring + name,
                        value["total"] / total * 100,
                        depth,
                    )
                    self.draw(
                        value,
                        x,
                        y - self.box_height,
                        cur_width,
                        depth + 1,
                        keystring + name + ":",
                    )
                # Item is a value so we just draw a box
                else:
                    if name == "total" or total < args.limit:
                        continue
                    cur_width = width * (value / total)
                    self.draw_box(
                        x,
                        y - self.box_height,
                        cur_width,
                        self.box_height,
                        keystring + name,
                        value / total * 100,
                        depth,
                    )
                # Move the x-coordinate over by the width of the box we just drew
                x += cur_width


if __name__ == "__main__":
    args = parse_args()
    if args.silent:
        sys.stdout = open(os.devnull, "w")

    # Use the schedd name given on the command line, otherwise look it up
    kwargs["name"] = args.name if args.name else get_name()
    print(f"Monitoring {kwargs['name']}...")

    old = None
    diffs = collections.deque(maxlen=args.m)

    if not args.g:
        # Console mode: sample, print, wait, repeat until interrupted
        while True:
            stats, time = get_stats()
            old, diff = get_diffs(old, stats)
            print_stats(diff, old, time)
            wait(args.window)
    else:
        # Graphical mode
        if not TKINTER:
            print("Error: tkinter not installed! Please install it to use the GUI.")
            sys.exit(1)
        root = tk.Tk()
        root.title("condor_diagnostics")
        root.geometry(f"{root.winfo_screenwidth()}x{root.winfo_screenheight()}")
        graph = FlameGraph(root)
        root.mainloop()
