SSH-Infrastructure / tools / generate-configs.py
Newer Older
302 lines | 10.65kb
Bogdan Timofte authored 2 weeks ago
1
#!/usr/bin/env python3
2
import argparse
3
import fnmatch
4
from pathlib import Path
5

            
6
import yaml
7

            
8

            
9
# Repository root: two levels above this script (the script lives in tools/).
ROOT = Path(__file__).resolve().parent.parent
10

            
11

            
12
def load_inventory(path: Path) -> dict:
    """Load and validate a single YAML inventory file.

    Args:
        path: Path to the inventory YAML file.

    Returns:
        The parsed inventory mapping.

    Raises:
        SystemExit: If the file's top level is not a mapping (for example the
            file is empty) or if its ``version`` is not ``1``.
    """
    with path.open("r", encoding="utf-8") as handle:
        # safe_load: inventories are plain data, never arbitrary Python objects.
        data = yaml.safe_load(handle)
    # An empty file parses to None and a list/scalar root has no .get();
    # report it as a clear error instead of an AttributeError traceback.
    if not isinstance(data, dict):
        raise SystemExit(f"{path}: inventory must be a YAML mapping")
    if data.get("version") != 1:
        raise SystemExit("unsupported inventory version")
    return data
18

            
19

            
Bogdan Timofte authored 2 weeks ago
20
def merge_inventories(paths: list) -> dict:
21
    """Merge multiple inventory files, with later ones overriding earlier ones."""
22
    merged = {}
23
    for path in paths:
24
        data = load_inventory(path)
25
        for key in ("facts", "ssh_options", "defaults", "entrypoints", "jumps", "groups", "company_managed", "access_policies"):
26
            if key not in data:
27
                continue
28
            if key not in merged:
29
                merged[key] = {}
30
            if key in ("facts", "defaults", "company_managed"):
31
                merged[key].update(data[key])
32
            elif key in ("ssh_options", "entrypoints", "jumps", "groups", "access_policies"):
33
                merged[key].update(data[key])
34
    if "version" not in merged:
35
        merged["version"] = 1
36
    return merged
37

            
38

            
Bogdan Timofte authored 2 weeks ago
39
def fmt_bool(value: bool) -> str:
    """Render a truthy/falsy value as the ssh_config keyword ``yes`` or ``no``."""
    if value:
        return "yes"
    return "no"
41

            
42

            
43
def fmt_option(value) -> str:
    """Render an ssh_config option value as text.

    Booleans become ``yes``/``no`` via :func:`fmt_bool`; anything else is
    passed through ``str()``.
    """
    if not isinstance(value, bool):
        return str(value)
    return fmt_bool(value)
47

            
48

            
49
def aliases_match_rule(aliases, rule):
    """Return True when every alias matches at least one of the rule's glob patterns.

    A rule whose ``patterns`` are missing or empty never matches. An empty
    *aliases* sequence matches any rule that has patterns.
    """
    globs = rule.get("patterns", [])
    if not globs:
        return False
    for alias in aliases:
        name = str(alias)
        if not any(fnmatch.fnmatch(name, glob) for glob in globs):
            return False
    return True
54

            
55

            
56
def company_managed_rule(data, target, aliases, user, port):
    """Find the company-managed jump-host rule that covers a host on *target*.

    Rules are read from ``company_managed.jump_hosts.match_defaults``. They
    apply only when *target* is listed in ``inherit_globals_on_targets``. A
    rule matches when its ``user`` and ``port`` equal the host's and every
    alias matches one of its ``patterns``.

    Returns:
        The first matching rule dict, or None.
    """
    jump_hosts = data.get("company_managed", {}).get("jump_hosts", {})
    if target in jump_hosts.get("inherit_globals_on_targets", []):
        same_login = (
            rule
            for rule in jump_hosts.get("match_defaults", [])
            if rule.get("user") == user and rule.get("port") == port
        )
        return next((rule for rule in same_login if aliases_match_rule(aliases, rule)), None)
    return None
67

            
68

            
69
def aliases_for_host(host):
    """Return the ``Host``-line names for *host*: its aliases plus its hostname.

    Aliases are stringified, because YAML may parse bare numbers as ints. The
    hostname is appended when it is not already one of them. A missing or
    empty ``aliases`` entry yields just the hostname.

    Args:
        host: Host mapping with a required ``hostname`` and an optional
            ``aliases`` list.

    Returns:
        A new list of strings.
    """
    aliases = [str(alias) for alias in host.get("aliases") or []]
    # Compare in str form: the aliases were stringified above, so a non-str
    # hostname would otherwise never match and would be listed twice.
    hostname = str(host["hostname"])
    if hostname not in aliases:
        aliases.append(hostname)
    return aliases
74

            
75

            
76
def host_block(aliases, hostname, user=None, port=None, extra=None):
    """Render one ssh_config ``Host`` stanza as a list of lines.

    Args:
        aliases: Names for the ``Host`` line (stringified and space-joined).
        hostname: Value for ``HostName``.
        user: Optional ``User``; omitted when falsy.
        port: Optional ``Port``; omitted when falsy.
        extra: Optional mapping of additional settings.
            - ``auth``, ``proxy_jump``, ``route``, ``identity_file`` and
              ``identities_only`` are translated into ssh_config directives.
            - ``proxy_jump == "none"`` suppresses ``ProxyJump``.
            - Every other key/value pair is emitted verbatim as ``Key value``.
            - The caller's mapping is not modified.

    Returns:
        The stanza lines, terminated by an empty line.
    """
    # Work on a copy so the caller's dict is not consumed by the pops below.
    options = dict(extra or {})
    auth = options.pop("auth", None)
    proxy_jump = options.pop("proxy_jump", None)
    route = options.pop("route", None)
    identity_file = options.pop("identity_file", None)
    identities_only = options.pop("identities_only", None)
    password_interactive = auth == "password_interactive"

    lines = [f"Host {' '.join(str(alias) for alias in aliases)}", f"    HostName {hostname}"]
    if user:
        lines.append(f"    User {user}")
    if port:
        lines.append(f"    Port {port}")

    # Put every environment variable on a single SetEnv line. ssh_config
    # generally keeps the first value obtained for a keyword, so a second
    # SetEnv line in the same stanza may be ignored.
    env = []
    if route:
        env.append(f"SSH_ROUTE={route}")
    if password_interactive:
        env.append("NG_SSH_AUTH=password-interactive")
    if env:
        lines.append(f"    SetEnv {' '.join(env)}")

    if identity_file:
        lines.append(f"    IdentityFile {identity_file}")
    if identities_only:
        lines.append(f"    IdentitiesOnly {identities_only}")
    if password_interactive:
        lines.append("    BatchMode no")
        lines.append("    PreferredAuthentications keyboard-interactive,password")
        lines.append("    PubkeyAuthentication no")
    if proxy_jump and proxy_jump != "none":
        lines.append(f"    ProxyJump {proxy_jump}")
    for key, value in options.items():
        lines.append(f"    {key} {value}")
    lines.append("")
    return lines
105

            
106

            
107
def pattern_block(pattern, options):
    """Render a ``Host <pattern>`` stanza carrying connection-tuning options.

    Only ``connect_timeout`` and ``connection_attempts`` are read from
    *options*; other keys are ignored.
    """
    directives = (
        ("connect_timeout", "ConnectTimeout"),
        ("connection_attempts", "ConnectionAttempts"),
    )
    body = [f"    {keyword} {options[key]}" for key, keyword in directives if key in options]
    return [f"Host {pattern}", *body, ""]
115

            
116

            
117
def generated_header(target, include_comments=True):
    """Return the banner lines that mark a config file as generated.

    Returns an empty list when *include_comments* is false.
    """
    if include_comments:
        return [
            "# Generated by tools/generate-configs.py.",
            "# Do not edit this file directly; edit inventory/hosts.yaml.",
            f"# Target: {target}",
            "",
        ]
    return []
126

            
127

            
128
def emit_global_options(data, include_comments=True):
    """Render every ``ssh_options`` entry as its own ``Host *`` stanza.

    Each entry may carry an ``options`` mapping, whose values are rendered
    with :func:`fmt_option`, and a ``description``. When *include_comments*
    is true, the description is emitted as a comment and a section heading
    is added. Returns ``[]`` when there are no global options.
    """
    option_sets = data.get("ssh_options", {})
    if not option_sets:
        return []

    out = ["# Global SSH compatibility options", ""] if include_comments else []
    for set_name, option_set in option_sets.items():
        if include_comments:
            out.append(f"# {set_name}: {option_set.get('description', '')}")
        out.append("Host *")
        out.extend(
            f"    {key} {fmt_option(value)}"
            for key, value in option_set.get("options", {}).items()
        )
        out.append("")
    return out
144

            
145

            
146
def inherit_globals(data, target):
    """Return True if *target* appears in
    ``company_managed.jump_hosts.inherit_globals_on_targets``."""
    jump_hosts = data.get("company_managed", {}).get("jump_hosts", {})
    inheriting = jump_hosts.get("inherit_globals_on_targets", [])
    return target in inheriting
149

            
150

            
151
def merged(defaults, group_defaults, host):
    """Layer settings: *defaults*, then *group_defaults* (may be None), then *host*.

    Later layers win on key conflicts. Returns a new dict; the inputs are
    not modified.
    """
    combined = {}
    for layer in (defaults, group_defaults or {}, host):
        combined.update(layer)
    return combined
156

            
157

            
158
def host_differs_from_defaults(host, defaults):
    """Return True if *host* explicitly sets ``user``, ``port`` or ``auth`` to a
    value other than the one in *defaults*.

    Keys absent from *host* count as inheriting the default.
    """
    return any(
        key in host and host[key] != defaults.get(key)
        for key in ("user", "port", "auth")
    )
163

            
164

            
165
def should_emit_host_on_target(data, target, group_defaults, host):
    """Decide whether *host* gets a stanza in the config for *target*.

    Targets other than ``j1``/``j2`` receive every host. On ``j1``/``j2``, a
    host is emitted only when it overrides ``user``, ``port`` or ``auth``
    relative to ``defaults.final_host`` layered with *group_defaults*.
    """
    if target in ("j1", "j2"):
        baseline = merged(data["defaults"]["final_host"], group_defaults, {})
        return host_differs_from_defaults(host, baseline)
    return True
171

            
172

            
173
def emit_entrypoints(data, include_comments=True):
    """Render a ``Host`` stanza for every entry in ``data["entrypoints"]``.

    Besides hostname/user/port, an entrypoint may set ``identity_file`` and
    ``identities_only``. These are passed to :func:`host_block` under the
    same snake_case keys :func:`emit_jumps` uses. A ``# Entrypoints`` heading
    is prepended when *include_comments* is true.
    """
    out = []
    if include_comments:
        out += ["# Entrypoints", ""]
    for entry in data["entrypoints"].values():
        options = {}
        if entry.get("identity_file"):
            options["identity_file"] = entry["identity_file"]
        if "identities_only" in entry:
            options["identities_only"] = fmt_bool(entry["identities_only"])
        out += host_block(
            aliases_for_host(entry),
            entry["hostname"],
            entry.get("user"),
            entry.get("port"),
            options,
        )
    return out
183

            
184

            
185
def emit_jumps(data, include_comments=True):
    """Render a ``Host`` stanza for every jump host.

    Each jump is layered over ``defaults.jump`` (see :func:`merged`) and may
    set ``proxy_jump``, ``identity_file`` and ``identities_only``. A
    ``# Jump hosts`` heading is prepended when *include_comments* is true.
    """
    out = ["# Jump hosts", ""] if include_comments else []
    jump_defaults = data["defaults"]["jump"]
    for raw in data["jumps"].values():
        jump = merged(jump_defaults, {}, raw)
        options = {key: jump[key] for key in ("proxy_jump", "identity_file") if jump.get(key)}
        if "identities_only" in jump:
            options["identities_only"] = fmt_bool(jump["identities_only"])
        out.extend(
            host_block(aliases_for_host(jump), jump["hostname"], jump.get("user"), jump.get("port"), options)
        )
    return out
199

            
200

            
201
def emit_hosts_for_group(data, group, target, defaults):
    """Render the host stanzas and pattern stanzas of a single group.

    Each host is layered as *defaults* -> ``group["defaults"]`` -> host. Hosts
    rejected by :func:`should_emit_host_on_target` are skipped. A truthy
    ``proxy_jump`` on the layered host wins over the group's ``default_jump``.
    When :func:`company_managed_rule` finds a rule for the host on *target*,
    ``User`` and ``Port`` are omitted from its stanza. Stanzas for the
    group's ``patterns`` follow the host stanzas.
    """
    group_defaults = group.get("defaults", {})
    fallback_jump = group.get("default_jump")
    out = []
    for host in group.get("hosts", {}).values():
        if not should_emit_host_on_target(data, target, group_defaults, host):
            continue
        item = merged(defaults, group_defaults, host)
        aliases = aliases_for_host(item)

        options = {key: item[key] for key in ("auth", "route", "identity_file") if item.get(key)}
        if "identities_only" in item:
            options["identities_only"] = fmt_bool(item["identities_only"])
        jump = item.get("proxy_jump") or fallback_jump
        if jump:
            options["proxy_jump"] = jump

        user, port = item.get("user"), item.get("port")
        if company_managed_rule(data, target, aliases, user, port):
            user = port = None
        out.extend(host_block(aliases, item["hostname"], user, port, options))

    for pattern, pattern_options in group.get("patterns", {}).items():
        out.extend(pattern_block(pattern, pattern_options))
    return out
232

            
233

            
234
def emit_groups(data, target=None, include_comments=True):
235
    lines = []
236
    defaults = data["defaults"]["final_host"]
237
    metadata_keys = {"description", "default_jump", "defaults", "patterns", "hosts"}
238
    queue = [(name, group) for name, group in data["groups"].items()]
239

            
240
    while queue:
241
        group_name, group = queue.pop(0)
242
        group_lines = emit_hosts_for_group(data, group, target, defaults)
243
        if group_lines:
244
            if include_comments:
245
                lines.extend([f"# Group: {group_name}", f"# Description: {group.get('description', '')}", ""])
246
            lines.extend(group_lines)
247

            
248
        for child_name, child in group.items():
249
            if child_name in metadata_keys or not isinstance(child, dict):
250
                continue
251
            if "hosts" in child:
252
                queue.append((f"{group_name}.{child_name}", child))
253
    return lines
254

            
255

            
256
def write(path: Path, lines):
    """Write *lines* to *path* as newline-joined UTF-8 text.

    Missing parent directories are created. Trailing whitespace is stripped,
    and the file always ends with exactly one newline.
    """
    directory = path.parent
    directory.mkdir(parents=True, exist_ok=True)
    body = "\n".join(lines).rstrip()
    path.write_text(f"{body}\n", encoding="utf-8")
259

            
260

            
261
def generate(data, output_dir: Path):
    """Write every generated ssh_config file into *output_dir*.

    Produces:
      * ``client.conf``: global options, entrypoints, jump hosts and all groups.
      * ``is-jumper.conf``: global options and jump hosts.
      * ``j1.conf`` / ``j2.conf``: groups filtered for that target, without
        comments. Global options are omitted when the target is listed in
        ``company_managed.jump_hosts.inherit_globals_on_targets``.
    """
    # NOTE(review): global "Host *" stanzas are written before the per-host
    # stanzas. ssh_config generally keeps the first value obtained for a
    # keyword, so a conflicting global option would win over a per-host one.
    # Confirm that ordering is intended.
    all_groups = emit_groups(data)

    client = generated_header("client")
    client += emit_global_options(data)
    client += emit_entrypoints(data)
    client += emit_jumps(data)
    client += all_groups
    write(output_dir / "client.conf", client)

    jumper = generated_header("is-jumper") + emit_global_options(data) + emit_jumps(data)
    write(output_dir / "is-jumper.conf", jumper)

    for target in ("j1", "j2"):
        content = generated_header(target, include_comments=False)
        if not inherit_globals(data, target):
            content += emit_global_options(data, include_comments=False)
        content += emit_groups(data, target, include_comments=False)
        write(output_dir / f"{target}.conf", content)
284

            
285

            
286
def main():
    """Command-line entry point: load the inventories and write generated configs.

    Without ``--inventory``, ``inventory/hosts.yaml`` under ROOT is used.
    ``inventory/hosts-local.yaml`` is appended when it exists and is not
    already listed, so it is merged last.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--inventory",
        action="append",
        type=Path,
        help="Inventory file(s) to load (can be used multiple times)",
    )
    parser.add_argument("--output-dir", default=ROOT / "generated", type=Path)
    args = parser.parse_args()

    inventory_dir = ROOT / "inventory"
    inventories = args.inventory or [inventory_dir / "hosts.yaml"]
    local_inventory = inventory_dir / "hosts-local.yaml"
    if local_inventory.exists() and local_inventory not in inventories:
        inventories.append(local_inventory)

    if len(inventories) == 1:
        data = load_inventory(inventories[0])
    else:
        data = merge_inventories(inventories)
    generate(data, args.output_dir)
Bogdan Timofte authored 2 weeks ago
299

            
300

            
301
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())