1 contributor
#!/usr/bin/env python3
import argparse
import fnmatch
from pathlib import Path
import yaml
# Directory one level above the folder that contains this script; the default
# inventory and output locations used by main() are resolved relative to it.
ROOT = Path(__file__).resolve().parents[1]
def load_inventory(path: Path) -> dict:
    """Load one inventory YAML file and check its schema version.

    Args:
        path: Location of the inventory file; it is read as UTF-8.

    Returns:
        The parsed top-level mapping.

    Raises:
        SystemExit: If the document is not a mapping (for example an empty
            file, which parses to ``None``) or its ``version`` is not ``1``.
    """
    with path.open("r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)
    # An empty, scalar or list document has no .get(); fail with a clean exit
    # message instead of an AttributeError traceback.
    if not isinstance(data, dict):
        raise SystemExit(f"{path}: inventory must be a YAML mapping")
    if data.get("version") != 1:
        raise SystemExit("unsupported inventory version")
    return data
def merge_inventories(paths: list) -> dict:
    """Merge multiple inventory files, with later ones overriding earlier ones.

    Only the known top-level sections are carried over. Each section is merged
    with a shallow ``dict.update``, so an entry from a later file replaces the
    same-named entry from an earlier file as a whole. The result always has a
    ``version`` key.
    """
    sections = (
        "facts",
        "ssh_options",
        "defaults",
        "entrypoints",
        "jumps",
        "groups",
        "company_managed",
        "access_policies",
    )
    combined = {}
    for inventory_path in paths:
        loaded = load_inventory(inventory_path)
        for section in sections:
            if section in loaded:
                combined.setdefault(section, {}).update(loaded[section])
    combined.setdefault("version", 1)
    return combined
def fmt_bool(value: bool) -> str:
    """Return ``"yes"`` for a truthy value and ``"no"`` otherwise."""
    if value:
        return "yes"
    return "no"
def fmt_option(value) -> str:
    """Format an option value for output.

    Booleans are rendered through ``fmt_bool``; every other value is
    converted with ``str``.
    """
    return fmt_bool(value) if isinstance(value, bool) else str(value)
def aliases_match_rule(aliases, rule):
    """Report whether every alias matches at least one of the rule's patterns.

    A rule with no ``patterns`` (missing or empty) never matches. Matching is
    done with ``fnmatch.fnmatch`` on the string form of each alias.
    """
    globs = rule.get("patterns", [])
    if not globs:
        return False
    for alias in aliases:
        name = str(alias)
        if not any(fnmatch.fnmatch(name, glob) for glob in globs):
            return False
    return True
def company_managed_rule(data, target, aliases, user, port):
    """Find the first company-managed default rule that covers a host.

    Rules are read from ``company_managed.jump_hosts.match_defaults`` and are
    only considered when ``target`` is listed in ``inherit_globals_on_targets``.
    A rule applies when its ``user`` and ``port`` equal the given ones and all
    ``aliases`` match its patterns.

    Returns:
        The matching rule dict, or ``None`` when nothing applies.
    """
    jump_hosts = data.get("company_managed", {}).get("jump_hosts", {})
    if target not in jump_hosts.get("inherit_globals_on_targets", []):
        return None
    candidates = (
        rule
        for rule in jump_hosts.get("match_defaults", [])
        if rule.get("user") == user and rule.get("port") == port
    )
    return next(
        (rule for rule in candidates if aliases_match_rule(aliases, rule)),
        None,
    )
def aliases_for_host(host):
    """Return the host's aliases as strings, followed by its hostname.

    The hostname is added at the end only when it is not already among the
    stringified aliases.
    """
    names = list(map(str, host["aliases"]))
    primary = host["hostname"]
    if primary in names:
        return names
    return names + [primary]
def host_block(aliases, hostname, user=None, port=None, extra=None):
    """Render one ``Host`` stanza as a list of config lines.

    Args:
        aliases: Names placed on the ``Host`` line, joined by spaces.
        hostname: Value for ``HostName``.
        user: Optional ``User`` value; omitted when falsy.
        port: Optional ``Port`` value; omitted when falsy.
        extra: Optional mapping of further options. The keys ``auth``,
            ``proxy_jump`` and ``route`` are interpreted specially; every
            other key/value pair is emitted verbatim as an option line.
            The mapping is not modified.

    Returns:
        The stanza lines, terminated by an empty string.
    """
    # Work on a copy: the special keys are popped below, and the caller's
    # dict must not silently lose them.
    options = dict(extra or {})
    auth = options.pop("auth", None)
    proxy_jump = options.pop("proxy_jump", None)
    route = options.pop("route", None)

    lines = [
        f"Host {' '.join(str(alias) for alias in aliases)}",
        f" HostName {hostname}",
    ]
    if user:
        lines.append(f" User {user}")
    if port:
        lines.append(f" Port {port}")
    if route:
        lines.append(f" SetEnv SSH_ROUTE={route}")
    if auth == "password_interactive":
        lines.append(" SetEnv NG_SSH_AUTH=password-interactive")
        lines.append(" BatchMode no")
        lines.append(" PreferredAuthentications keyboard-interactive,password")
        lines.append(" PubkeyAuthentication no")
    # The literal value "none" means no ProxyJump line is written.
    if proxy_jump and proxy_jump != "none":
        lines.append(f" ProxyJump {proxy_jump}")
    for key, value in options.items():
        lines.append(f" {key} {value}")
    lines.append("")
    return lines
def pattern_block(pattern, options):
    """Render a ``Host <pattern>`` stanza with its connection tuning options.

    Only ``connect_timeout`` and ``connection_attempts`` are read from
    ``options``; each one present yields one line. The result ends with an
    empty string.
    """
    directives = (
        ("connect_timeout", "ConnectTimeout"),
        ("connection_attempts", "ConnectionAttempts"),
    )
    lines = [f"Host {pattern}"]
    lines.extend(
        f" {keyword} {options[key]}"
        for key, keyword in directives
        if key in options
    )
    lines.append("")
    return lines
def generated_header(target, include_comments=True):
    """Return the "generated file" banner lines for ``target``.

    With ``include_comments`` false, an empty list is returned instead.
    """
    if include_comments:
        return [
            "# Generated by tools/generate-configs.py.",
            "# Do not edit this file directly; edit inventory/hosts.yaml.",
            f"# Target: {target}",
            "",
        ]
    return []
def emit_global_options(data, include_comments=True):
    """Render every ``ssh_options`` entry as its own ``Host *`` stanza.

    Returns an empty list when the inventory has no ``ssh_options``. With
    ``include_comments`` true, a section heading and one description comment
    per entry are added.
    """
    sections = data.get("ssh_options", {})
    if not sections:
        return []
    out = ["# Global SSH compatibility options", ""] if include_comments else []
    for label, section in sections.items():
        if include_comments:
            out.append(f"# {label}: {section.get('description', '')}")
        out.append("Host *")
        out.extend(
            f" {key} {fmt_option(value)}"
            for key, value in section.get("options", {}).items()
        )
        out.append("")
    return out
def inherit_globals(data, target):
    """Report whether ``target`` is listed in ``inherit_globals_on_targets``.

    The list is read from ``company_managed.jump_hosts``; missing levels are
    treated as empty.
    """
    jump_hosts = data.get("company_managed", {}).get("jump_hosts", {})
    inheriting = jump_hosts.get("inherit_globals_on_targets", [])
    return target in inheriting
def merged(defaults, group_defaults, host):
    """Layer three setting dicts into a new one.

    Precedence rises from ``defaults`` through ``group_defaults`` (which may
    be ``None``) to ``host``. None of the inputs is modified.
    """
    combined = {}
    for layer in (defaults, group_defaults or {}, host):
        combined.update(layer)
    return combined
def host_differs_from_defaults(host, defaults):
    """Report whether the host overrides ``user``, ``port`` or ``auth``.

    Only keys that are present on ``host`` are compared; a key absent from
    ``defaults`` is compared against ``None``.
    """
    return any(
        key in host and host[key] != defaults.get(key)
        for key in ("user", "port", "auth")
    )
def should_emit_host_on_target(data, target, group_defaults, host):
    """Decide whether a host gets its own stanza in the file for ``target``.

    For the ``j1`` and ``j2`` targets a host is emitted only when it differs
    from the baseline built from the ``final_host`` defaults and the group
    defaults. Every other target emits all hosts.
    """
    if target in ("j1", "j2"):
        baseline = merged(data["defaults"]["final_host"], group_defaults, {})
        return host_differs_from_defaults(host, baseline)
    return True
def emit_entrypoints(data, include_comments=True):
    """Render one stanza per entry in ``data["entrypoints"]``.

    ``identity_file`` (when truthy) and ``identities_only`` (when present) are
    passed on as ``IdentityFile`` and ``IdentitiesOnly`` options. A section
    heading is added first when ``include_comments`` is true.
    """
    out = []
    if include_comments:
        out += ["# Entrypoints", ""]
    for entry in data["entrypoints"].values():
        options = {}
        identity = entry.get("identity_file")
        if identity:
            options["IdentityFile"] = identity
        if "identities_only" in entry:
            options["IdentitiesOnly"] = fmt_bool(entry["identities_only"])
        out.extend(
            host_block(
                aliases_for_host(entry),
                entry["hostname"],
                entry.get("user"),
                entry.get("port"),
                options,
            )
        )
    return out
def emit_jumps(data, include_comments=True):
    """Render one stanza per entry in ``data["jumps"]``.

    Each jump is layered over ``data["defaults"]["jump"]`` before rendering.
    A section heading is added first when ``include_comments`` is true.
    """
    out = []
    if include_comments:
        out += ["# Jump hosts", ""]
    jump_defaults = data["defaults"]["jump"]
    for raw_jump in data["jumps"].values():
        resolved = merged(jump_defaults, {}, raw_jump)
        names = aliases_for_host(resolved)
        out.extend(
            host_block(names, resolved["hostname"], resolved.get("user"), resolved.get("port"))
        )
    return out
def emit_hosts_for_group(data, group, target, defaults):
    """Render the host stanzas and pattern stanzas of a single group.

    Hosts rejected by ``should_emit_host_on_target`` are skipped. For the
    rest, ``user`` and ``port`` are left out of the stanza when a
    company-managed rule matches the host. Pattern stanzas follow all host
    stanzas.
    """
    shared = group.get("defaults", {})
    out = []
    for raw_host in group.get("hosts", {}).values():
        if not should_emit_host_on_target(data, target, shared, raw_host):
            continue
        resolved = merged(defaults, shared, raw_host)
        names = aliases_for_host(resolved)
        # Only truthy auth/route values are forwarded to host_block.
        special = {
            key: resolved[key]
            for key in ("auth", "route")
            if resolved.get(key)
        }
        user, port = resolved.get("user"), resolved.get("port")
        if company_managed_rule(data, target, names, user, port):
            user = port = None
        out.extend(host_block(names, resolved["hostname"], user, port, special))
    for pattern, options in group.get("patterns", {}).items():
        out.extend(pattern_block(pattern, options))
    return out
def emit_groups(data, target=None, include_comments=True):
    """Render all groups, walking nested groups breadth-first.

    A nested group is any dict-valued key of a group that is not one of the
    reserved metadata keys and that itself contains ``hosts``; its name is the
    dotted path from the top-level group. Groups that produce no lines get no
    heading.
    """
    base_defaults = data["defaults"]["final_host"]
    reserved = {"description", "default_jump", "defaults", "patterns", "hosts"}
    out = []
    # The list is appended to while a cursor advances through it, which visits
    # groups in first-in, first-out order.
    pending = list(data["groups"].items())
    cursor = 0
    while cursor < len(pending):
        label, group = pending[cursor]
        cursor += 1
        body = emit_hosts_for_group(data, group, target, base_defaults)
        if body:
            if include_comments:
                out.append(f"# Group: {label}")
                out.append(f"# Description: {group.get('description', '')}")
                out.append("")
            out.extend(body)
        for key, value in group.items():
            if key not in reserved and isinstance(value, dict) and "hosts" in value:
                pending.append((f"{label}.{key}", value))
    return out
def write(path: Path, lines):
    """Write ``lines`` to ``path`` as UTF-8, creating parent directories.

    Lines are joined with newlines; trailing whitespace of the joined text is
    stripped and exactly one final newline is added.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    text = "\n".join(lines).rstrip()
    path.write_text(f"{text}\n", encoding="utf-8")
def generate(data, output_dir: Path):
    """Write all generated config files into ``output_dir``.

    Produces ``client.conf`` (globals, entrypoints, jumps, groups),
    ``is-jumper.conf`` (globals, jumps) and one comment-free file each for the
    ``j1`` and ``j2`` targets. A jump target listed as inheriting globals gets
    no global options block of its own.
    """
    group_lines = emit_groups(data)

    client_lines = [
        *generated_header("client"),
        *emit_global_options(data),
        *emit_entrypoints(data),
        *emit_jumps(data),
        *group_lines,
    ]
    write(output_dir / "client.conf", client_lines)

    jumper_lines = [
        *generated_header("is-jumper"),
        *emit_global_options(data),
        *emit_jumps(data),
    ]
    write(output_dir / "is-jumper.conf", jumper_lines)

    for target in ("j1", "j2"):
        target_lines = generated_header(target, include_comments=False)
        if not inherit_globals(data, target):
            target_lines.extend(emit_global_options(data, include_comments=False))
        target_lines.extend(emit_groups(data, target, include_comments=False))
        write(output_dir / f"{target}.conf", target_lines)
def main():
    """Parse command-line arguments, load the inventory and generate configs.

    Without ``--inventory`` the default ``inventory/hosts.yaml`` under ROOT is
    used. ``inventory/hosts-local.yaml`` is appended when it exists and is not
    already in the list. Several inventories are merged; a single one is
    loaded as-is.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--inventory",
        action="append",
        type=Path,
        help="Inventory file(s) to load (can be used multiple times)",
    )
    parser.add_argument("--output-dir", default=ROOT / "generated", type=Path)
    args = parser.parse_args()

    inventories = args.inventory or [ROOT / "inventory" / "hosts.yaml"]
    local_inventory = ROOT / "inventory" / "hosts-local.yaml"
    if local_inventory.exists() and local_inventory not in inventories:
        inventories.append(local_inventory)

    if len(inventories) > 1:
        data = merge_inventories(inventories)
    else:
        data = load_inventory(inventories[0])
    generate(data, args.output_dir)
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()