SSH-Infrastructure / tools / generate-configs.py
Newer Older
280 lines | 9.627kb
Bogdan Timofte authored 2 weeks ago
1
#!/usr/bin/env python3
2
import argparse
3
import fnmatch
4
from pathlib import Path
5

            
6
import yaml
7

            
8

            
9
# Repository root: parents[1] of the resolved script path, i.e. the directory
# that contains tools/ (and, per main(), inventory/ and generated/).
ROOT = Path(__file__).resolve().parents[1]
10

            
11

            
12
def load_inventory(path: Path) -> dict:
    """Load one YAML inventory file and check its schema version.

    Args:
        path: Inventory file to read (UTF-8 YAML).

    Returns:
        The parsed top-level mapping.

    Raises:
        SystemExit: if the document is empty or not a mapping, or if its
            ``version`` key is not 1.
    """
    with path.open("r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)
    # safe_load returns None for an empty file and may return a list or a
    # scalar for other documents; check before calling .get() on the result.
    if not isinstance(data, dict):
        raise SystemExit(f"{path}: inventory must be a YAML mapping")
    if data.get("version") != 1:
        raise SystemExit("unsupported inventory version")
    return data
18

            
19

            
Bogdan Timofte authored 2 weeks ago
20
def _deep_merge(base: dict, overlay: dict) -> dict:
    """Recursively merge *overlay* into *base* in place and return *base*.

    Nested mappings are merged key by key; any other value from *overlay*
    (scalar, list, None) replaces the corresponding value in *base*. Mappings
    taken from *overlay* are copied, so *base* never aliases *overlay*.
    """
    for name, value in overlay.items():
        if isinstance(value, dict):
            current = base.get(name)
            if not isinstance(current, dict):
                current = base[name] = {}
            _deep_merge(current, value)
        else:
            base[name] = value
    return base


def merge_inventories(paths: list) -> dict:
    """Merge multiple inventory files, with later ones overriding earlier ones.

    Only the known top-level sections are merged. Within a section, nested
    mappings are merged recursively. An overlay file (for example a local
    inventory) can therefore add a host to an existing group, or override a
    single default, without restating the whole section. Non-mapping values
    from later files replace earlier ones.

    Args:
        paths: Inventory files, in increasing order of precedence.

    Returns:
        The merged inventory. ``version`` is always 1, because every file is
        validated as version 1 by ``load_inventory``.

    Raises:
        SystemExit: if a file fails ``load_inventory`` validation, or a known
            section is present but is not a mapping.
    """
    combined = {}
    for path in paths:
        data = load_inventory(path)
        for key in ("facts", "ssh_options", "defaults", "entrypoints", "jumps", "groups", "company_managed", "access_policies"):
            if key not in data:
                continue
            target = combined.setdefault(key, {})
            section = data[key]
            if section is None:
                # An empty YAML section (e.g. ``groups:`` with no body) parses to None.
                continue
            if not isinstance(section, dict):
                raise SystemExit(f"{path}: section '{key}' must be a mapping")
            _deep_merge(target, section)
    combined["version"] = 1
    return combined
37

            
38

            
Bogdan Timofte authored 2 weeks ago
39
def fmt_bool(value: bool) -> str:
    """Render a truthy/falsy value as the ssh_config keyword ``yes``/``no``."""
    if value:
        return "yes"
    return "no"
41

            
42

            
43
def fmt_option(value) -> str:
    """Render an inventory option value as ssh_config text.

    Booleans become ``yes``/``no``. Everything else is converted with ``str()``.
    """
    if not isinstance(value, bool):
        return str(value)
    return "yes" if value else "no"
47

            
48

            
49
def aliases_match_rule(aliases, rule):
    """Tell whether every alias matches at least one glob in ``rule["patterns"]``.

    A rule with no (or empty) ``patterns`` never matches. Aliases are compared
    as strings using ``fnmatch`` shell-style globbing.
    """
    globs = rule.get("patterns", [])
    if not globs:
        return False
    for alias in aliases:
        name = str(alias)
        if not any(fnmatch.fnmatch(name, glob) for glob in globs):
            return False
    return True
54

            
55

            
56
def company_managed_rule(data, target, aliases, user, port):
    """Find the company-managed defaults rule that covers a host on *target*.

    Only targets listed in
    ``company_managed.jump_hosts.inherit_globals_on_targets`` are considered.
    A rule under ``match_defaults`` matches when three conditions hold: its
    ``user`` equals *user*, its ``port`` equals *port*, and every alias
    matches one of its ``patterns``.

    Returns:
        The first matching rule mapping, or None.
    """
    jump_hosts = data.get("company_managed", {}).get("jump_hosts", {})
    if target in jump_hosts.get("inherit_globals_on_targets", []):
        candidates = (
            rule
            for rule in jump_hosts.get("match_defaults", [])
            if rule.get("user") == user and rule.get("port") == port
        )
        for rule in candidates:
            if aliases_match_rule(aliases, rule):
                return rule
    return None
67

            
68

            
69
def aliases_for_host(host):
    """Return the ``Host`` patterns for an inventory host entry.

    The host's ``aliases`` (stringified) come first. The ``hostname`` is
    appended unless it is already one of the aliases. A missing or empty
    ``aliases`` key yields just the hostname.
    """
    aliases = [str(alias) for alias in host.get("aliases") or []]
    hostname = str(host["hostname"])
    # Compare as strings: the aliases were stringified above, so a non-string
    # hostname (e.g. a bare number in YAML) would otherwise never match and
    # would be appended as a duplicate.
    if hostname not in aliases:
        aliases.append(hostname)
    return aliases
74

            
75

            
76
def host_block(aliases, hostname, user=None, port=None, extra=None):
    """Render one ``Host`` stanza as a list of ssh_config lines.

    Args:
        aliases: Patterns for the ``Host`` line (stringified, space-joined).
        hostname: Value for ``HostName``.
        user: Optional ``User``; omitted when falsy.
        port: Optional ``Port``; omitted when falsy.
        extra: Optional mapping of additional settings. Three pseudo-keys are
            translated into ssh options:
            ``route`` becomes ``SSH_ROUTE`` via SetEnv;
            ``auth == "password_interactive"`` becomes ``NG_SSH_AUTH`` plus
            password-only authentication options;
            ``proxy_jump`` becomes ``ProxyJump`` unless it is ``"none"``.
            Every other key/value pair is emitted verbatim as ``Key value``.
            The caller's mapping is not modified.

    Returns:
        The stanza lines, terminated by an empty line.
    """
    # Work on a copy: popping the pseudo-keys out of the caller's dict would
    # silently alter it for any later use.
    options = dict(extra or {})
    auth = options.pop("auth", None)
    proxy_jump = options.pop("proxy_jump", None)
    route = options.pop("route", None)

    lines = [f"Host {' '.join(str(alias) for alias in aliases)}", f"    HostName {hostname}"]
    if user:
        lines.append(f"    User {user}")
    if port:
        lines.append(f"    Port {port}")

    # Put all environment variables in a single SetEnv directive. ssh_config
    # keeps the first value it obtains for an option, and some OpenSSH
    # versions apply that to SetEnv as a whole, so a second SetEnv line may
    # be ignored.
    env = []
    if route:
        env.append(f"SSH_ROUTE={route}")
    if auth == "password_interactive":
        env.append("NG_SSH_AUTH=password-interactive")
    if env:
        lines.append(f"    SetEnv {' '.join(env)}")

    if auth == "password_interactive":
        lines.append("    BatchMode no")
        lines.append("    PreferredAuthentications keyboard-interactive,password")
        lines.append("    PubkeyAuthentication no")
    if proxy_jump and proxy_jump != "none":
        lines.append(f"    ProxyJump {proxy_jump}")
    for key, value in options.items():
        lines.append(f"    {key} {value}")
    lines.append("")
    return lines
99

            
100

            
101
def pattern_block(pattern, options):
    """Render a ``Host <pattern>`` stanza carrying connection tuning options.

    Only ``connect_timeout`` and ``connection_attempts`` are recognised, in
    that order. Any other keys in *options* are ignored.
    """
    stanza = [f"Host {pattern}"]
    for key, directive in (("connect_timeout", "ConnectTimeout"), ("connection_attempts", "ConnectionAttempts")):
        if key in options:
            stanza.append(f"    {directive} {options[key]}")
    stanza.append("")
    return stanza
109

            
110

            
111
def generated_header(target, include_comments=True):
    """Return the banner lines placed at the top of a generated config.

    Returns an empty list when *include_comments* is false.
    """
    if include_comments:
        return [
            "# Generated by tools/generate-configs.py.",
            "# Do not edit this file directly; edit inventory/hosts.yaml.",
            f"# Target: {target}",
            "",
        ]
    return []
120

            
121

            
122
def emit_global_options(data, include_comments=True):
    """Render every ``ssh_options`` entry as its own ``Host *`` stanza.

    Each entry may carry a ``description``, which is emitted as a comment
    when *include_comments* is true. It may also carry an ``options`` mapping,
    whose values are formatted with ``fmt_option``. Returns an empty list
    when there are no entries.
    """
    option_sets = data.get("ssh_options", {})
    if not option_sets:
        return []

    out = ["# Global SSH compatibility options", ""] if include_comments else []
    for set_name, option_set in option_sets.items():
        if include_comments:
            out.append(f"# {set_name}: {option_set.get('description', '')}")
        out.append("Host *")
        out.extend(f"    {key} {fmt_option(value)}" for key, value in option_set.get("options", {}).items())
        out.append("")
    return out
138

            
139

            
140
def inherit_globals(data, target):
    """Tell whether *target* is listed in
    ``company_managed.jump_hosts.inherit_globals_on_targets``."""
    targets = (
        data.get("company_managed", {})
        .get("jump_hosts", {})
        .get("inherit_globals_on_targets", [])
    )
    return target in targets
143

            
144

            
145
def merged(defaults, group_defaults, host):
    """Layer settings into a new dict.

    The layers are *defaults*, then *group_defaults* (which may be None), then
    *host*. Later layers win on key conflicts, and the inputs are left
    untouched.
    """
    result = dict(defaults)
    for layer in (group_defaults or {}, host):
        result.update(layer)
    return result
150

            
151

            
152
def host_differs_from_defaults(host, defaults):
    """Tell whether *host* explicitly sets ``user``, ``port`` or ``auth`` to a
    value different from *defaults*. Keys absent from *host* are ignored."""
    return any(
        key in host and host[key] != defaults.get(key)
        for key in ("user", "port", "auth")
    )
157

            
158

            
159
def should_emit_host_on_target(data, target, group_defaults, host):
    """Decide whether a final host belongs in the config for *target*.

    Every target other than ``j1``/``j2`` receives all hosts. For ``j1``/``j2``,
    a host is emitted only when it overrides ``user``, ``port`` or ``auth``.
    The comparison baseline is ``defaults.final_host`` layered with the
    group's defaults.
    """
    if target in ("j1", "j2"):
        baseline = merged(data["defaults"]["final_host"], group_defaults, {})
        return host_differs_from_defaults(host, baseline)
    return True
165

            
166

            
167
def emit_entrypoints(data, include_comments=True):
    """Render a ``Host`` stanza for each entry in ``data["entrypoints"]``.

    ``identity_file`` (when non-empty) is passed through as ``IdentityFile``.
    ``identities_only`` (when present) is passed through as ``IdentitiesOnly``.
    """
    out = []
    if include_comments:
        out += ["# Entrypoints", ""]
    for entry in data["entrypoints"].values():
        options = {}
        if entry.get("identity_file"):
            options["IdentityFile"] = entry["identity_file"]
        if "identities_only" in entry:
            options["IdentitiesOnly"] = fmt_bool(entry["identities_only"])
        out += host_block(
            aliases_for_host(entry),
            entry["hostname"],
            entry.get("user"),
            entry.get("port"),
            options,
        )
    return out
177

            
178

            
179
def emit_jumps(data, include_comments=True):
    """Render a ``Host`` stanza for each jump host, layered over ``defaults.jump``."""
    jump_defaults = data["defaults"]["jump"]
    out = ["# Jump hosts", ""] if include_comments else []
    for entry in data["jumps"].values():
        settings = merged(jump_defaults, {}, entry)
        out.extend(
            host_block(
                aliases_for_host(settings),
                settings["hostname"],
                settings.get("user"),
                settings.get("port"),
            )
        )
    return out
186

            
187

            
188
def emit_hosts_for_group(data, group, target, defaults):
    """Render the host and pattern stanzas of a single group (not its subgroups).

    Hosts rejected by ``should_emit_host_on_target`` are skipped. Each remaining
    host is layered over *defaults* and the group's own ``defaults``. Its
    ``auth`` and ``route`` values are forwarded to ``host_block``. ``User`` and
    ``Port`` are omitted when ``company_managed_rule`` matches the host on
    *target*.
    """
    group_defaults = group.get("defaults", {})
    out = []
    for host in group.get("hosts", {}).values():
        if not should_emit_host_on_target(data, target, group_defaults, host):
            continue
        settings = merged(defaults, group_defaults, host)
        aliases = aliases_for_host(settings)
        # NOTE(review): only auth/route are forwarded; a host's ``proxy_jump``
        # never reaches host_block from here, so no ProxyJump is emitted —
        # confirm this is intentional.
        passthrough = {name: settings[name] for name in ("auth", "route") if settings.get(name)}
        user, port = settings.get("user"), settings.get("port")
        if company_managed_rule(data, target, aliases, user, port):
            user = port = None
        out.extend(host_block(aliases, settings["hostname"], user, port, passthrough))
    for pattern, options in group.get("patterns", {}).items():
        out.extend(pattern_block(pattern, options))
    return out
210

            
211

            
212
def emit_groups(data, target=None, include_comments=True):
    """Render every group and nested subgroup, breadth-first.

    Top-level groups come from ``data["groups"]``. Inside a group, a key is
    treated as a subgroup (named ``parent.child``) when it is not a metadata
    key and its value is a mapping that contains ``hosts``. Groups that
    produce no lines get no header comment.
    """
    final_defaults = data["defaults"]["final_host"]
    reserved = {"description", "default_jump", "defaults", "patterns", "hosts"}
    pending = list(data["groups"].items())
    out = []

    cursor = 0
    while cursor < len(pending):
        name, group = pending[cursor]
        cursor += 1
        body = emit_hosts_for_group(data, group, target, final_defaults)
        if body:
            if include_comments:
                out += [f"# Group: {name}", f"# Description: {group.get('description', '')}", ""]
            out += body
        # Nested groups go to the back of the queue, giving breadth-first order.
        pending.extend(
            (f"{name}.{key}", value)
            for key, value in group.items()
            if key not in reserved and isinstance(value, dict) and "hosts" in value
        )
    return out
232

            
233

            
234
def write(path: Path, lines):
    """Write *lines* to *path* as UTF-8.

    The lines are newline-joined, trailing whitespace is trimmed, and exactly
    one final newline is added. Parent directories are created as needed.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    body = "\n".join(lines).rstrip()
    path.write_text(f"{body}\n", encoding="utf-8")
237

            
238

            
239
def generate(data, output_dir: Path):
    """Write client.conf, is-jumper.conf, j1.conf and j2.conf into *output_dir*.

    - client.conf: header, global options, entrypoints, jump hosts, all groups.
    - is-jumper.conf: header, global options, jump hosts.
    - j1.conf / j2.conf: no header; global options only when the target does
      not inherit them (``inherit_globals``), then the target-filtered groups.
    """
    # NOTE(review): ssh_config uses the first value obtained for each option,
    # so the ``Host *`` global blocks emitted ahead of the host stanzas take
    # precedence over any per-host setting of the same option — confirm that
    # is intended.
    final_groups = emit_groups(data)

    client = [
        *generated_header("client"),
        *emit_global_options(data),
        *emit_entrypoints(data),
        *emit_jumps(data),
        *final_groups,
    ]
    write(output_dir / "client.conf", client)

    is_jumper = [
        *generated_header("is-jumper"),
        *emit_global_options(data),
        *emit_jumps(data),
    ]
    write(output_dir / "is-jumper.conf", is_jumper)

    for target in ("j1", "j2"):
        lines = generated_header(target, include_comments=False)
        if not inherit_globals(data, target):
            lines.extend(emit_global_options(data, include_comments=False))
        lines.extend(emit_groups(data, target, include_comments=False))
        write(output_dir / f"{target}.conf", lines)
262

            
263

            
264
def main():
    """Command-line entry point: load (and merge) inventories, then generate configs.

    ``--inventory`` may be repeated and defaults to inventory/hosts.yaml under
    the repository root. inventory/hosts-local.yaml is appended automatically
    when it exists and is not already in the list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--inventory", action="append", type=Path, help="Inventory file(s) to load (can be used multiple times)")
    parser.add_argument("--output-dir", default=ROOT / "generated", type=Path)
    args = parser.parse_args()

    inventory_dir = ROOT / "inventory"
    inventories = args.inventory or [inventory_dir / "hosts.yaml"]
    local_inventory = inventory_dir / "hosts-local.yaml"
    if local_inventory.exists() and local_inventory not in inventories:
        inventories.append(local_inventory)

    if len(inventories) > 1:
        data = merge_inventories(inventories)
    else:
        data = load_inventory(inventories[0])
    generate(data, args.output_dir)
Bogdan Timofte authored 2 weeks ago
277

            
278

            
279
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()