SSH-Infrastructure / tools / generate-configs.py
1 contributor
271 lines | 9.289kb
#!/usr/bin/env python3
import argparse
import fnmatch
from collections import deque
from pathlib import Path

import yaml


# Repository root: this script lives one directory below it (tools/), so the
# parent of the script's directory is the root. Default inventory and output
# paths in main() are resolved relative to this.
ROOT = Path(__file__).resolve().parents[1]


def load_inventory(path: Path) -> dict:
    """Load a single YAML inventory file and validate its schema version.

    Args:
        path: Inventory file to read (UTF-8 encoded YAML).

    Returns:
        The parsed top-level mapping.

    Raises:
        SystemExit: if the document is not a mapping (for example an empty
            file, which ``yaml.safe_load`` returns as ``None``) or if its
            ``version`` key is not ``1``.
    """
    with path.open("r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)
    # Guard before calling .get(): an empty file or a top-level list/scalar
    # would otherwise crash with an AttributeError instead of a clear message.
    if not isinstance(data, dict):
        raise SystemExit(f"{path}: inventory must be a YAML mapping")
    if data.get("version") != 1:
        raise SystemExit("unsupported inventory version")
    return data


def merge_inventories(paths: list) -> dict:
    """Merge multiple inventory files, with later ones overriding earlier ones.

    Only the known top-level sections listed below are carried over; any other
    top-level keys are dropped. The merge is one level deep: within each
    section, a key from a later file replaces the earlier value for that key
    wholesale (e.g. a later ``groups.foo`` replaces the entire earlier
    ``groups.foo``; nested mappings are not merged recursively).

    Args:
        paths: Inventory files in increasing order of precedence.

    Returns:
        The merged inventory, always with ``version`` set to ``1``.

    Raises:
        SystemExit: propagated from ``load_inventory`` for invalid files.
    """
    sections = (
        "facts",
        "ssh_options",
        "defaults",
        "entrypoints",
        "jumps",
        "groups",
        "company_managed",
        "access_policies",
    )
    result = {}
    for path in paths:
        data = load_inventory(path)
        for key in sections:
            if key not in data:
                continue
            # A section written as `key:` with no body parses as None; treat it
            # as an empty mapping instead of crashing in dict.update(None).
            result.setdefault(key, {}).update(data[key] or {})
    # Each input was validated as version 1 by load_inventory.
    result.setdefault("version", 1)
    return result


def fmt_bool(value: bool) -> str:
    """Render a truthy/falsy value as the ssh_config keyword ``yes`` or ``no``."""
    if value:
        return "yes"
    return "no"


def fmt_option(value) -> str:
    """Render an inventory option value as ssh_config text.

    Real booleans become ``yes``/``no`` via ``fmt_bool``; every other value
    (including ints such as ``1``) is passed through ``str``.
    """
    if not isinstance(value, bool):
        return str(value)
    return fmt_bool(value)


def aliases_match_rule(aliases, rule):
    """Return True when every alias matches at least one of the rule's glob patterns.

    A rule whose ``patterns`` is missing or empty never matches. Aliases are
    converted with ``str`` before matching with ``fnmatch.fnmatch``.
    """
    globs = rule.get("patterns", [])
    if not globs:
        return False
    for alias in aliases:
        name = str(alias)
        if not any(fnmatch.fnmatch(name, glob) for glob in globs):
            return False
    return True


def company_managed_rule(data, target, aliases, user, port):
    """Find the company-managed default rule that covers a host on *target*.

    Returns the first entry of ``company_managed.jump_hosts.match_defaults``
    whose ``user`` and ``port`` equal the host's and whose patterns match every
    alias (see ``aliases_match_rule``), or ``None`` if there is none. Always
    returns ``None`` when *target* is not listed in
    ``company_managed.jump_hosts.inherit_globals_on_targets``.
    """
    jump_cfg = data.get("company_managed", {}).get("jump_hosts", {})
    if target in jump_cfg.get("inherit_globals_on_targets", []):
        same_identity = (
            candidate
            for candidate in jump_cfg.get("match_defaults", [])
            if candidate.get("user") == user and candidate.get("port") == port
        )
        return next(
            (candidate for candidate in same_identity if aliases_match_rule(aliases, candidate)),
            None,
        )
    return None


def aliases_for_host(host):
    """Return the names a host answers to, always including its hostname.

    Args:
        host: Host mapping with a required ``hostname`` and an optional
            ``aliases`` list (missing or null is treated as empty).

    Returns:
        The aliases as strings in inventory order, followed by the hostname
        (also as a string) if it was not already one of the aliases.
    """
    aliases = [str(alias) for alias in (host.get("aliases") or [])]
    # Compare as str: aliases are stringified above, so a raw non-str hostname
    # would never compare equal and would be appended as a duplicate.
    hostname = str(host["hostname"])
    if hostname not in aliases:
        aliases.append(hostname)
    return aliases


def host_block(aliases, hostname, user=None, port=None, extra=None):
    """Render one ``Host`` stanza as a list of ssh_config lines.

    Args:
        aliases: Names placed on the ``Host`` line; each is converted with ``str``.
        hostname: Value for ``HostName``.
        user: ``User`` value; omitted when falsy.
        port: ``Port`` value; omitted when falsy.
        extra: Optional mapping of additional directives, emitted verbatim as
            ``key value`` in insertion order. The special key ``"auth"`` is not
            emitted as-is: ``"password_interactive"`` expands to a fixed set of
            password-auth directives and any other value is dropped. The
            caller's mapping is not modified.

    Returns:
        The stanza lines, ending with an empty string as a blank separator.
    """
    # Work on a copy so popping "auth" does not mutate the caller's dict.
    options = dict(extra or {})
    auth = options.pop("auth", None)

    lines = [f"Host {' '.join(str(alias) for alias in aliases)}", f"    HostName {hostname}"]
    if user:
        lines.append(f"    User {user}")
    if port:
        lines.append(f"    Port {port}")
    if auth == "password_interactive":
        lines.extend([
            "    SetEnv NG_SSH_AUTH=password-interactive",
            "    BatchMode no",
            "    PreferredAuthentications keyboard-interactive,password",
            "    PubkeyAuthentication no",
        ])
    lines.extend(f"    {key} {value}" for key, value in options.items())
    lines.append("")
    return lines


def pattern_block(pattern, options):
    """Render a wildcard ``Host`` stanza carrying connection-tuning options.

    Only ``connect_timeout`` (-> ``ConnectTimeout``) and
    ``connection_attempts`` (-> ``ConnectionAttempts``) are recognised, emitted
    in that order; any other keys in *options* are ignored.
    """
    directive_for = (
        ("connect_timeout", "ConnectTimeout"),
        ("connection_attempts", "ConnectionAttempts"),
    )
    stanza = [f"Host {pattern}"]
    stanza.extend(
        f"    {directive} {options[key]}"
        for key, directive in directive_for
        if key in options
    )
    stanza.append("")
    return stanza


def generated_header(target, include_comments=True):
    """Return the "generated file" banner for *target*, or nothing.

    When *include_comments* is false the result is an empty list, so callers
    can always ``extend`` from it unconditionally.
    """
    if include_comments:
        banner = [
            "# Generated by tools/generate-configs.py.",
            "# Do not edit this file directly; edit inventory/hosts.yaml.",
        ]
        banner.append(f"# Target: {target}")
        banner.append("")
        return banner
    return []


def emit_global_options(data, include_comments=True):
    """Render each ``ssh_options`` entry as its own ``Host *`` stanza.

    Returns an empty list when ``ssh_options`` is missing or empty. Otherwise
    each named option set becomes one ``Host *`` block whose directives come
    from its ``options`` mapping (values formatted by ``fmt_option``),
    optionally preceded by a section banner and a ``# name: description`` line.
    """
    option_sets = data.get("ssh_options", {})
    rendered = []
    if not option_sets:
        return rendered

    if include_comments:
        rendered += ["# Global SSH compatibility options", ""]
    for set_name, spec in option_sets.items():
        if include_comments:
            rendered.append(f"# {set_name}: {spec.get('description', '')}")
        rendered.append("Host *")
        rendered.extend(
            f"    {key} {fmt_option(value)}"
            for key, value in spec.get("options", {}).items()
        )
        rendered.append("")
    return rendered


def inherit_globals(data, target):
    """Return True if *target* is listed in ``company_managed.jump_hosts.inherit_globals_on_targets``."""
    inheriting_targets = (
        data.get("company_managed", {})
        .get("jump_hosts", {})
        .get("inherit_globals_on_targets", [])
    )
    return target in inheriting_targets


def merged(defaults, group_defaults, host):
    """Layer host settings over group defaults over global defaults.

    Returns a new dict; later layers win on key collisions. *group_defaults*
    may be ``None``. None of the inputs is modified.
    """
    combined = {}
    for layer in (defaults, group_defaults or {}, host):
        combined.update(layer)
    return combined


def host_differs_from_defaults(host, defaults):
    """Return True if *host* explicitly sets ``user``, ``port`` or ``auth`` to a non-default value.

    Keys absent from *host* never count as a difference.
    """
    return any(
        field in host and host[field] != defaults.get(field)
        for field in ("user", "port", "auth")
    )


def should_emit_host_on_target(data, target, group_defaults, host):
    """Decide whether a host gets its own stanza in *target*'s config.

    For the jump targets ``j1``/``j2`` a host is emitted only when it overrides
    ``user``, ``port`` or ``auth`` relative to the final-host defaults layered
    with its group defaults. Every other target emits all hosts.
    """
    if target in ("j1", "j2"):
        baseline = merged(data["defaults"]["final_host"], group_defaults, {})
        return host_differs_from_defaults(host, baseline)
    return True


def emit_entrypoints(data, include_comments=True):
    """Render a ``Host`` stanza for each entry in ``data["entrypoints"]``.

    Each entrypoint may set ``identity_file`` (-> ``IdentityFile``, skipped when
    falsy) and ``identities_only`` (-> ``IdentitiesOnly yes/no``) in addition
    to ``hostname``, ``aliases``, ``user`` and ``port``. A missing or null
    ``entrypoints`` section yields no stanzas.

    Returns:
        The rendered lines, preceded by a ``# Entrypoints`` banner when
        *include_comments* is true.
    """
    lines = ["# Entrypoints", ""] if include_comments else []
    # Tolerate an absent/empty section, matching emit_global_options' use of .get().
    for host in (data.get("entrypoints") or {}).values():
        extra = {}
        if host.get("identity_file"):
            extra["IdentityFile"] = host["identity_file"]
        if "identities_only" in host:
            extra["IdentitiesOnly"] = fmt_bool(host["identities_only"])
        lines.extend(host_block(aliases_for_host(host), host["hostname"], host.get("user"), host.get("port"), extra))
    return lines


def emit_jumps(data, include_comments=True):
    """Render a ``Host`` stanza for each jump host in ``data["jumps"]``.

    Each jump is layered over ``defaults.jump`` before rendering. A missing or
    null ``jumps`` section (or ``defaults.jump``) is treated as empty.

    Returns:
        The rendered lines, preceded by a ``# Jump hosts`` banner when
        *include_comments* is true.
    """
    lines = ["# Jump hosts", ""] if include_comments else []
    defaults = (data.get("defaults") or {}).get("jump") or {}
    for jump in (data.get("jumps") or {}).values():
        item = merged(defaults, {}, jump)
        # NOTE(review): only hostname/user/port are rendered for jump hosts;
        # other merged keys (e.g. "auth") are not forwarded to host_block —
        # confirm that is intended.
        lines.extend(host_block(aliases_for_host(item), item["hostname"], item.get("user"), item.get("port")))
    return lines


def emit_hosts_for_group(data, group, target, defaults):
    """Render the host and pattern stanzas belonging to one group.

    Args:
        data: The full parsed inventory (needed for target filtering and
            company-managed rules).
        group: Group mapping with optional ``defaults``, ``hosts`` and
            ``patterns``; any of them may be missing or null.
        target: Output target name (``None`` for client.conf).
        defaults: Final-host defaults that group and host settings layer over.

    Returns:
        Host stanzas for the hosts that pass ``should_emit_host_on_target``,
        followed by one stanza per pattern. Patterns are emitted for every
        target.
    """
    # A YAML key written with no body (e.g. `hosts:`) parses as None.
    group_defaults = group.get("defaults") or {}
    lines = []
    for host in (group.get("hosts") or {}).values():
        if not should_emit_host_on_target(data, target, group_defaults, host):
            continue
        item = merged(defaults, group_defaults, host)
        aliases = aliases_for_host(item)
        extra = {"auth": item["auth"]} if item.get("auth") else {}
        user = item.get("user")
        port = item.get("port")
        if company_managed_rule(data, target, aliases, user, port):
            # A matching company-managed rule covers this user/port on a target
            # that inherits globals, so do not pin them per host.
            user = None
            port = None
        lines.extend(host_block(aliases, item["hostname"], user, port, extra))
    for pattern, options in (group.get("patterns") or {}).items():
        lines.extend(pattern_block(pattern, options))
    return lines


def emit_groups(data, target=None, include_comments=True):
    """Render every group, and nested sub-group, in ``data["groups"]``.

    Groups are visited breadth-first: all top-level groups in inventory order,
    then their children, and so on. A child is any non-metadata key of a group
    whose value is a mapping containing ``"hosts"``; it is labelled
    ``parent.child``. A ``# Group:`` / ``# Description:`` header is written only
    for groups that produced at least one line, and only when
    *include_comments* is true.

    Args:
        data: Parsed inventory. Missing or null ``groups`` or
            ``defaults.final_host`` are treated as empty.
        target: Output target name passed through to ``emit_hosts_for_group``
            (``None`` for client.conf).
        include_comments: Whether to emit group header comments.

    Returns:
        The rendered lines.
    """
    lines = []
    defaults = (data.get("defaults") or {}).get("final_host") or {}
    metadata_keys = {"description", "default_jump", "defaults", "patterns", "hosts"}
    # deque gives O(1) popleft; list.pop(0) shifts the whole list each time.
    queue = deque((data.get("groups") or {}).items())

    while queue:
        group_name, group = queue.popleft()
        group_lines = emit_hosts_for_group(data, group, target, defaults)
        if group_lines:
            if include_comments:
                lines.extend([f"# Group: {group_name}", f"# Description: {group.get('description', '')}", ""])
            lines.extend(group_lines)

        for child_name, child in group.items():
            if child_name in metadata_keys or not isinstance(child, dict):
                continue
            # NOTE(review): a mapping without a "hosts" key is not queued, so any
            # sub-groups nested beneath it are never visited — confirm intended.
            if "hosts" in child:
                queue.append((f"{group_name}.{child_name}", child))
    return lines


def write(path: Path, lines):
    """Write *lines* to *path* as UTF-8 text ending in exactly one newline.

    Lines are joined with ``\\n``. Trailing whitespace, including trailing
    blank lines, is stripped before the final newline is added. Missing parent
    directories are created first.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    body = "\n".join(lines).rstrip()
    path.write_text(f"{body}\n", encoding="utf-8")


def generate(data, output_dir: Path):
    """Render every generated ssh_config file into *output_dir*.

    Files written:
      * ``client.conf``: header, global options, entrypoints, jump hosts and
        all groups.
      * ``is-jumper.conf``: header, global options and jump hosts.
      * ``j1.conf`` / ``j2.conf``: group hosts filtered for that target (see
        ``should_emit_host_on_target``), without comments. Global options are
        omitted when the target is listed in
        ``company_managed.jump_hosts.inherit_globals_on_targets``.
    """
    client = generated_header("client")
    client.extend(emit_global_options(data))
    client.extend(emit_entrypoints(data))
    client.extend(emit_jumps(data))
    client.extend(emit_groups(data))
    write(output_dir / "client.conf", client)

    is_jumper = generated_header("is-jumper")
    is_jumper.extend(emit_global_options(data))
    is_jumper.extend(emit_jumps(data))
    write(output_dir / "is-jumper.conf", is_jumper)

    for target in ("j1", "j2"):
        # include_comments=False makes the header empty for jump-target configs.
        lines = generated_header(target, include_comments=False)
        if not inherit_globals(data, target):
            lines.extend(emit_global_options(data, include_comments=False))
        lines.extend(emit_groups(data, target, include_comments=False))
        write(output_dir / f"{target}.conf", lines)


def main():
    """Command-line entry point: load inventories and write generated configs.

    With no ``--inventory`` options, ``inventory/hosts.yaml`` under the repo
    root is used. If ``inventory/hosts-local.yaml`` exists and was not already
    given, it is appended last, so it has the highest override precedence.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--inventory", action="append", type=Path, help="Inventory file(s) to load (can be used multiple times)")
    parser.add_argument("--output-dir", default=ROOT / "generated", type=Path)
    args = parser.parse_args()

    inventories = args.inventory if args.inventory else [ROOT / "inventory" / "hosts.yaml"]
    local_inventory = ROOT / "inventory" / "hosts-local.yaml"
    # Compare resolved paths: a relative --inventory naming the local file would
    # otherwise not match the absolute ROOT-based path. It would be appended a
    # second time, moving it to the end and changing override precedence.
    listed = {path.resolve() for path in inventories}
    if local_inventory.exists() and local_inventory.resolve() not in listed:
        inventories.append(local_inventory)

    data = merge_inventories(inventories) if len(inventories) > 1 else load_inventory(inventories[0])
    generate(data, args.output_dir)


if __name__ == "__main__":
    main()