#!/usr/bin/env python3

import os
import re
import time
import signal
from functools import partial
from multiprocessing import Process
import multiprocessing
try:
    import queue as Queue
except ImportError:
    import Queue

# credmon library in libexec
import htcondor2 as htcondor
import sys
sys.path.append(htcondor.param.get('libexec', '/usr/libexec/condor'))

try:
    from credmon.CredentialMonitors.OAuthCredmon import OAuthCredmon
except ImportError:
    OAuthCredmon = None
try:
    from credmon.CredentialMonitors.LocalCredmon import LocalCredmon
except ImportError:
    LocalCredmon = None
try:
    from credmon.CredentialMonitors.ClientCredmon import ClientCredmon
except ImportError:
    ClientCredmon = None
try:
    from credmon.CredentialMonitors.VaultCredmon import VaultCredmon
except ImportError:
    VaultCredmon = None

from credmon.utils import parse_args, setup_logging, get_cred_dir, drop_pid, credmon_incomplete, credmon_complete, create_credentials


def signal_handler(logger, send_queue, signum, frame):
    """
    Translate a received signal into a message on ``send_queue``.

    SIGHUP acts as a sleep interrupt: ``False`` is queued so the reader
    wakes up and rescans the credential directory, and the handler returns.
    Every other signal queues ``True`` (the "quit" flag) and then exits the
    current process with status 0.

    ``logger`` and ``send_queue`` are expected to be pre-bound (e.g. with
    ``functools.partial``) so that the remaining ``(signum, frame)``
    parameters match the signature required by ``signal.signal``.
    """

    def _log_queue_depth():
        # qsize() can raise NotImplementedError on some platforms; the depth
        # is purely informational, so silently skip it there.
        try:
            depth = send_queue.qsize()
            logger.debug(f"Approximate size of queue is now {depth}")
        except NotImplementedError:
            pass

    if signum != signal.SIGHUP:
        # Termination path: queue the quit flag before logging, then exit.
        exit_msg = f"Got signal {signum} at frame {frame}, terminating."
        send_queue.put(True, block=False)
        logger.info(exit_msg)
        _log_queue_depth()
        sys.exit(0)

    # SIGHUP path: wake the scanner without asking it to quit.
    logger.info('Got SIGHUP: Triggering READ of Credential Directory')
    send_queue.put(False, block=False)
    _log_queue_depth()

def main():
    """
    Entry point of the credmon daemon.

    Sets up logging and the credential directory, installs the signal
    handlers, forks the token-scanning child process, and then waits (for up
    to 20 minutes, polling once per second) for the ``CREDMON_COMPLETE``
    marker file to appear in the credential directory.  Once it appears the
    ready state is reported via ``htcondor.set_ready_state``.  Finally the
    function blocks until the child process exits.

    If the child process dies before the marker file shows up, the wait is
    abandoned immediately instead of polling for the remaining time.
    """

    # As of python 3.14, the default start method for multiprocessing
    # on Unix systems is 'forkserver', which breaks the credmon,
    # because we assume that environment variables and the
    # htcondor config knob settings are inherited by child processes.
    # So we explicitly set it back to 'fork'.
    multiprocessing.set_start_method('fork')

    # Set up logging and the credential directory
    args = parse_args()
    htcondor.set_subsystem(args.daemon_name, htcondor.SubsystemType.Daemon)
    logger = setup_logging(**vars(args))
    cred_dir = get_cred_dir(cred_dir = args.cred_dir)
    logger.info('Starting condor_credmon and registering signals')

    # Try to create the signing credentials for the local credmon.
    create_credentials(args)

    # create queue to send alerts to child to renew
    send_queue = multiprocessing.Queue()

    # catch signals; all of them share one handler that talks to the queue
    # (registered before the child process is started)
    for signum in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGQUIT):
        signal.signal(signum, partial(signal_handler, logger, send_queue))
    drop_pid(cred_dir)

    # start the scan tokens loop
    cred_process = Process(target=start_credmon_process, args=(send_queue,cred_dir,args,))
    cred_process.start()

    # wait for CREDMON_COMPLETE and signal the condor_master
    complete_marker = os.path.join(cred_dir, "CREDMON_COMPLETE")
    for _ in range(20*60):
        if os.path.exists(complete_marker):
            try:
                logger.info("Sending DC_SET_READY message to master")
                htcondor.set_ready_state("Ready")
            except Exception:
                # Best effort: failing to report readiness must not stop the credmon.
                logger.warning("Could not send DC_SET_READY message to master")
            break
        if not cred_process.is_alive():
            # The marker is written by the scanning process; once that
            # process is gone there is nothing left to wait for.
            logger.error("Token scanning process exited before CREDMON_COMPLETE was written")
            break
        time.sleep(1)

    # join scan tokens loop
    cred_process.join()

def start_credmon_process(q, cred_dir, args):
    """
    Body of the token-scanning process.

    Builds the list of available credential monitors from the HTCondor
    configuration, warns about provider names claimed by more than one
    credmon, and then loops forever: run ``scan_tokens()`` on every credmon,
    call ``credmon_complete(cred_dir)``, and wait up to 60 seconds for a
    message on ``q``.  A truthy message ends the function; a falsy message
    (or a timeout) starts the next scan.

    :param q: queue delivering wake-up (falsy) and quit (truthy) messages
    :param cred_dir: credential directory passed on to every credmon
    :param args: parsed command-line arguments; also used to set up logging
    """
    credmon_incomplete(cred_dir)
    logger = setup_logging(**vars(args))

    # Provider lists may be separated by whitespace and/or commas.
    separators = re.compile(r"[\s,]+")

    def _provider_names(knob, default):
        # Read a config knob and return its non-empty entries as a set.
        configured = htcondor.param.get(knob, default)
        return {name for name in separators.split(configured) if name}

    # Each credmon needs to "stay in its lane" and be told which service providers to use.
    # - Administrators must spell out the provider configuration for the OAuth2 and Local credmons.
    # - Users can specify arbitrary provider names and configurations for the Vault credmon
    # For backward compat, if the providers aren't explicitly set, we use the following algorithm:
    # - If it's `scitokens`, we assume it is a local credmon.
    # - If config knob `FOO_CLIENT_ID` is non-empty, then we assume it's the OAuth2 credmon.
    # - Otherwise, it's the Vault credmon's.

    # The local credmon historically had a single provider; the Vault and
    # OAuth credmon assumed they provided all tokens.
    local_provider_names = _provider_names("LOCAL_CREDMON_PROVIDER_NAMES", "")
    if not local_provider_names:
        # Backward compat: no list given, so fall back to the singular knob.
        local_provider_names = {htcondor.param.get("LOCAL_CREDMON_PROVIDER_NAME", "scitokens")}
    oauth2_provider_names = _provider_names("OAUTH2_CREDMON_PROVIDER_NAMES", "*")
    client_provider_names = _provider_names("CLIENT_CREDMON_PROVIDER_NAMES", "")

    # Instantiate whichever credmon classes could be imported.
    credmons = []
    if OAuthCredmon is not None:
        credmons.append(OAuthCredmon(providers=oauth2_provider_names, cred_dir=cred_dir, args=args))
    if LocalCredmon is not None:
        credmons.extend(
            LocalCredmon(provider=name, cred_dir=cred_dir, args=args)
            for name in local_provider_names
        )
    if ClientCredmon is not None:
        credmons.extend(
            ClientCredmon(provider=name, cred_dir=cred_dir, args=args)
            for name in client_provider_names
        )

    vault_provider_names = _provider_names("VAULT_CREDMON_PROVIDER_NAMES", "*")
    if VaultCredmon is not None:
        credmons.append(VaultCredmon(providers=vault_provider_names, cred_dir=cred_dir, args=args))
    elif vault_provider_names and vault_provider_names != {"*"}:
        # Explicit (non-wildcard) Vault providers were requested, but the
        # Vault credmon could not be imported.
        logger.warning("Vault credmon is configured via `VAULT_CREDMON_PROVIDER_NAMES` but the Vault credmon is not installed")
        vault_provider_names = set()

    # Check for overlapping credmons: the first credmon kind to list a
    # provider name "owns" it; any later kind listing it gets a warning.
    names_by_kind = (
        ("local", local_provider_names),
        ("oauth2", oauth2_provider_names),
        ("client", client_provider_names),
        ("vault", vault_provider_names),
    )
    owner_of = {}
    for kind, names in names_by_kind:
        for provider_name in names:
            if provider_name in owner_of:
                first_kind = owner_of[provider_name]
                logger.warning(f"Credential providers named {provider_name} overlap between {first_kind} and {kind} credmons, continuing anyway...")
            else:
                owner_of[provider_name] = kind

    # set up scan tokens loop
    while True:
        logger.debug("Starting token scan loop")
        for monitor in credmons:
            try:
                logger.debug(f"Running {monitor.__class__.__name__}.scan_tokens()")
                monitor.scan_tokens()
            except Exception:
                # One failing credmon must not prevent the others from running.
                logger.exception("Fatal error while scanning for tokens")
        logger.debug("Token scan loop complete, writing CREDMON_COMPLETE")
        credmon_complete(cred_dir)
        logger.info("Sleeping for 60 seconds unless interrupted")
        try:
            credmon_quit = q.get(block=True, timeout=60)
        except Queue.Empty:
            # Nothing arrived within the timeout: just start the next scan.
            continue
        logger.debug(f"Got {credmon_quit} from queue, waking up")
        if credmon_quit:
            return

# Start the credmon only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
