SSH-Infrastructure / tools / generate-configs.py
Newer Older
271 lines | 9.289kb
Bogdan Timofte authored 2 weeks ago
1
#!/usr/bin/env python3
2
import argparse
3
import fnmatch
4
from pathlib import Path
5

            
6
import yaml
7

            
8

            
9
# Repository root: the parent of the directory that contains this script
# (the script lives in tools/, so this is the directory above tools/).
ROOT = Path(__file__).resolve().parents[1]
10

            
11

            
12
def load_inventory(path: Path) -> dict:
    """Load one YAML inventory file and check its schema version.

    Args:
        path: path to the inventory YAML file.

    Returns:
        The parsed inventory mapping.

    Raises:
        SystemExit: if the document is empty or not a mapping, or if its
            top-level ``version`` is not 1.
    """
    with path.open("r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)
    # safe_load returns None for an empty file and may return a list/scalar
    # for malformed input; without this check data.get() would die with an
    # AttributeError that does not say which file is broken.
    if not isinstance(data, dict):
        raise SystemExit(f"{path}: inventory must be a YAML mapping")
    if data.get("version") != 1:
        raise SystemExit("unsupported inventory version")
    return data
18

            
19

            
Bogdan Timofte authored 2 weeks ago
20
def merge_inventories(paths: list) -> dict:
    """Merge multiple inventory files, with later ones overriding earlier ones.

    Each file is loaded (and version-checked) with ``load_inventory``. Only the
    known top-level sections listed below are carried over; any other
    top-level key is dropped. Sections are merged one level deep: an entry
    (e.g. one group, one jump host, one defaults profile) defined in a later
    file replaces the same-named entry from an earlier file as a whole; its
    contents are not merged recursively.

    Args:
        paths: inventory file paths, in increasing order of precedence.

    Returns:
        The merged inventory dict, with ``version`` set to 1.
    """
    sections = ("facts", "ssh_options", "defaults", "entrypoints", "jumps", "groups", "company_managed", "access_policies")
    result = {}
    for path in paths:
        data = load_inventory(path)
        for key in sections:
            if key not in data:
                continue
            # A section written with no value ("groups:") parses as None;
            # treat it as empty instead of crashing in dict.update(None).
            result.setdefault(key, {}).update(data[key] or {})
    # Every input already passed load_inventory's version check.
    result["version"] = 1
    return result
37

            
38

            
Bogdan Timofte authored 2 weeks ago
39
def fmt_bool(value: bool) -> str:
    """Translate a truthy/falsy value into the ssh_config words ``yes``/``no``."""
    if value:
        return "yes"
    return "no"
41

            
42

            
43
def fmt_option(value) -> str:
    """Render an ssh_config option value as text.

    Python booleans become ``yes``/``no`` via ``fmt_bool``; every other value
    is passed through ``str()`` unchanged.
    """
    return fmt_bool(value) if isinstance(value, bool) else str(value)
47

            
48

            
49
def aliases_match_rule(aliases, rule):
    """Return True when every alias matches at least one of the rule's glob patterns.

    A rule with no ``patterns`` (missing or empty) never matches. An empty
    *aliases* sequence matches any rule that does have patterns.
    """
    patterns = rule.get("patterns", [])
    if not patterns:
        return False
    for alias in aliases:
        name = str(alias)
        if not any(fnmatch.fnmatch(name, pattern) for pattern in patterns):
            return False
    return True
54

            
55

            
56
def company_managed_rule(data, target, aliases, user, port):
    """Find the company-managed ``match_defaults`` rule that applies to a host.

    Only targets listed in
    ``company_managed.jump_hosts.inherit_globals_on_targets`` are considered.
    A rule applies when its ``user`` and ``port`` equal the host's and all of
    the host's aliases match its patterns (see ``aliases_match_rule``).

    Returns:
        The first applicable rule dict, or None.
    """
    jump_hosts = data.get("company_managed", {}).get("jump_hosts", {})
    if target in jump_hosts.get("inherit_globals_on_targets", []):
        same_account = (
            rule
            for rule in jump_hosts.get("match_defaults", [])
            if rule.get("user") == user and rule.get("port") == port
        )
        return next((rule for rule in same_account if aliases_match_rule(aliases, rule)), None)
    return None
67

            
68

            
69
def aliases_for_host(host):
    """Return the names for a host's ``Host`` line: its aliases plus its hostname.

    Aliases are converted to strings and keep their order; the hostname is
    appended (as a string) unless it already appears among them. A missing or
    empty ``aliases`` entry yields just the hostname.
    """
    aliases = [str(alias) for alias in host.get("aliases") or []]
    # Compare as strings: YAML may parse a bare value as a number, and an int
    # hostname would never equal its str()-converted alias, causing duplicates.
    hostname = str(host["hostname"])
    if hostname not in aliases:
        aliases.append(hostname)
    return aliases
74

            
75

            
76
def host_block(aliases, hostname, user=None, port=None, extra=None):
    """Render one ``Host`` stanza for an ssh_config file.

    Args:
        aliases: names placed on the ``Host`` line (each converted with str()).
        hostname: value for ``HostName``.
        user: optional ``User``; omitted when falsy.
        port: optional ``Port``; omitted when falsy.
        extra: optional mapping of additional directives, written as
            ``key value`` in insertion order. The pseudo-key ``"auth"`` is not
            written verbatim: the value ``"password_interactive"`` expands into
            a fixed set of password-authentication directives, and any other
            value is dropped. The caller's mapping is not modified.

    Returns:
        The stanza as a list of lines, terminated by an empty line.
    """
    lines = [f"Host {' '.join(str(alias) for alias in aliases)}", f"    HostName {hostname}"]
    if user:
        lines.append(f"    User {user}")
    if port:
        lines.append(f"    Port {port}")
    # Work on a copy: popping "auth" from the caller's dict would silently
    # strip it from any later use of that dict.
    options = dict(extra or {})
    auth = options.pop("auth", None)
    if auth == "password_interactive":
        lines.extend(
            [
                "    SetEnv NG_SSH_AUTH=password-interactive",
                "    BatchMode no",
                "    PreferredAuthentications keyboard-interactive,password",
                "    PubkeyAuthentication no",
            ]
        )
    for key, value in options.items():
        lines.append(f"    {key} {value}")
    lines.append("")
    return lines
92

            
93

            
94
def pattern_block(pattern, options):
    """Render a ``Host <pattern>`` stanza carrying connection-tuning options.

    Only ``connect_timeout`` and ``connection_attempts`` are recognised; they
    map to ``ConnectTimeout`` and ``ConnectionAttempts``. Other keys are ignored.
    """
    directives = (
        ("connect_timeout", "ConnectTimeout"),
        ("connection_attempts", "ConnectionAttempts"),
    )
    body = [f"    {directive} {options[key]}" for key, directive in directives if key in options]
    return [f"Host {pattern}", *body, ""]
102

            
103

            
104
def generated_header(target, include_comments=True):
    """Return the banner lines placed at the top of a generated config.

    The banner names *target* and ends with a blank line. When
    *include_comments* is false an empty list is returned instead.
    """
    if include_comments:
        return [
            "# Generated by tools/generate-configs.py.",
            "# Do not edit this file directly; edit inventory/hosts.yaml.",
            f"# Target: {target}",
            "",
        ]
    return []
113

            
114

            
115
def emit_global_options(data, include_comments=True):
    """Render each ``ssh_options`` entry as its own ``Host *`` stanza.

    Every entry's ``options`` mapping is written as ``key value`` lines, with
    values formatted by ``fmt_option``. With *include_comments*, a section
    banner plus one ``# name: description`` line per entry is added.

    Returns:
        Config lines, or an empty list when there are no ``ssh_options``.
    """
    option_sets = data.get("ssh_options", {})
    if not option_sets:
        return []

    lines = ["# Global SSH compatibility options", ""] if include_comments else []
    for set_name, option_set in option_sets.items():
        if include_comments:
            lines.append(f"# {set_name}: {option_set.get('description', '')}")
        lines.append("Host *")
        lines.extend(f"    {key} {fmt_option(value)}" for key, value in option_set.get("options", {}).items())
        lines.append("")
    return lines
131

            
132

            
133
def inherit_globals(data, target):
    """Return True if *target* is listed in ``company_managed.jump_hosts.inherit_globals_on_targets``."""
    jump_hosts = data.get("company_managed", {}).get("jump_hosts", {})
    inheriting_targets = jump_hosts.get("inherit_globals_on_targets", [])
    return target in inheriting_targets
136

            
137

            
138
def merged(defaults, group_defaults, host):
    """Layer host settings: *defaults*, then *group_defaults*, then *host*.

    *group_defaults* may be None or empty. Later layers win on key conflicts.
    A new dict is returned and none of the inputs are modified.
    """
    return {**defaults, **(group_defaults or {}), **host}
143

            
144

            
145
def host_differs_from_defaults(host, defaults):
    """Return True if *host* explicitly sets user, port or auth to a non-default value.

    Keys absent from *host* never count as a difference.
    """
    return any(key in host and host[key] != defaults.get(key) for key in ("user", "port", "auth"))
150

            
151

            
152
def should_emit_host_on_target(data, target, group_defaults, host):
    """Decide whether *host* gets a stanza in the config generated for *target*.

    Every target other than ``j1``/``j2`` gets every host. For ``j1``/``j2`` a
    host is emitted only when it overrides user, port or auth relative to the
    inventory's ``defaults.final_host`` layered with *group_defaults*.
    """
    if target in ("j1", "j2"):
        baseline = merged(data["defaults"]["final_host"], group_defaults, {})
        return host_differs_from_defaults(host, baseline)
    return True
158

            
159

            
160
def emit_entrypoints(data, include_comments=True):
    """Render a ``Host`` stanza for every entry of the ``entrypoints`` section.

    Each entry supplies ``hostname`` and optionally ``aliases``, ``user``,
    ``port``, ``identity_file`` (written as ``IdentityFile``) and
    ``identities_only`` (written as ``IdentitiesOnly yes|no``). A missing or
    empty ``entrypoints`` section yields no stanzas, matching how
    ``emit_global_options`` treats ``ssh_options``.

    Returns:
        Config lines, preceded by a ``# Entrypoints`` banner when
        *include_comments* is true.
    """
    lines = ["# Entrypoints", ""] if include_comments else []
    for host in (data.get("entrypoints") or {}).values():
        extra = {}
        if host.get("identity_file"):
            extra["IdentityFile"] = host["identity_file"]
        if "identities_only" in host:
            extra["IdentitiesOnly"] = fmt_bool(host["identities_only"])
        lines.extend(host_block(aliases_for_host(host), host["hostname"], host.get("user"), host.get("port"), extra))
    return lines
170

            
171

            
172
def emit_jumps(data, include_comments=True):
    """Render a ``Host`` stanza for every entry of the ``jumps`` section.

    Each jump entry is layered over ``defaults.jump`` (entry values win)
    before rendering. Missing or empty ``jumps``/``defaults`` sections are
    treated as empty instead of raising KeyError.

    Returns:
        Config lines, preceded by a ``# Jump hosts`` banner when
        *include_comments* is true.
    """
    lines = ["# Jump hosts", ""] if include_comments else []
    defaults = (data.get("defaults") or {}).get("jump") or {}
    for jump in (data.get("jumps") or {}).values():
        item = merged(defaults, {}, jump)
        lines.extend(host_block(aliases_for_host(item), item["hostname"], item.get("user"), item.get("port")))
    return lines
179

            
180

            
181
def emit_hosts_for_group(data, group, target, defaults):
    """Render ``Host`` stanzas for one group's hosts and host patterns.

    Args:
        data: full inventory (used for target filtering and company-managed rules).
        group: group mapping; its optional ``defaults``, ``hosts`` and
            ``patterns`` entries are used. Entries written with no value in
            YAML (parsed as None) are treated as empty mappings.
        target: output target name (None for the client config).
        defaults: inventory-wide host defaults, overridden by the group's
            defaults and then by per-host values.

    Returns:
        Config lines for the group (possibly empty).
    """
    # "hosts:" with nothing under it parses as None; without "or {}" the
    # .values()/.items() calls below would raise AttributeError.
    group_defaults = group.get("defaults") or {}
    hosts = group.get("hosts") or {}
    patterns = group.get("patterns") or {}

    lines = []
    for host in hosts.values():
        if not should_emit_host_on_target(data, target, group_defaults, host):
            continue
        item = merged(defaults, group_defaults, host)
        aliases = aliases_for_host(item)
        # "auth" is a pseudo-option that host_block expands into real directives.
        extra = {"auth": item["auth"]} if item.get("auth") else {}
        user = item.get("user")
        port = item.get("port")
        # When a company-managed match_defaults rule covers this
        # target/user/port/aliases, User and Port are left out of the stanza.
        # NOTE(review): presumably the target's own global config supplies
        # them in that case — confirm.
        if company_managed_rule(data, target, aliases, user, port):
            user = None
            port = None
        lines.extend(host_block(aliases, item["hostname"], user, port, extra))
    for pattern, options in patterns.items():
        lines.extend(pattern_block(pattern, options))
    return lines
201

            
202

            
203
def emit_groups(data, target=None, include_comments=True):
    """Render every group (and nested sub-group) of the ``groups`` section.

    Groups are visited breadth-first: all top-level groups come before any
    sub-group, and sub-groups are labelled ``parent.child``. A dict-valued
    group key that is not one of the reserved metadata keys and that itself
    contains ``hosts`` is treated as a sub-group. Groups that produce no
    lines get no banner.

    Args:
        data: full inventory; ``groups`` and ``defaults.final_host`` are read.
            Missing or empty sections are treated as empty.
        target: output target name (None for the client config).
        include_comments: add ``# Group:``/``# Description:`` banners.

    Returns:
        Config lines for all groups.
    """
    from collections import deque

    lines = []
    defaults = (data.get("defaults") or {}).get("final_host") or {}
    metadata_keys = {"description", "default_jump", "defaults", "patterns", "hosts"}
    # deque: popping from the left of a list is O(n) per pop.
    queue = deque((data.get("groups") or {}).items())

    while queue:
        group_name, group = queue.popleft()
        # A group written with no body ("work:") parses as None and has
        # nothing to emit and no sub-groups.
        if not group:
            continue
        group_lines = emit_hosts_for_group(data, group, target, defaults)
        if group_lines:
            if include_comments:
                lines.extend([f"# Group: {group_name}", f"# Description: {group.get('description', '')}", ""])
            lines.extend(group_lines)

        for child_name, child in group.items():
            if child_name in metadata_keys or not isinstance(child, dict):
                continue
            if "hosts" in child:
                queue.append((f"{group_name}.{child_name}", child))
    return lines
223

            
224

            
225
def write(path: Path, lines):
    """Write *lines* to *path* as UTF-8, creating parent directories as needed.

    The lines are joined with newlines, trailing whitespace (including blank
    lines) at the end is stripped, and exactly one final newline is added.
    """
    text = "\n".join(lines).rstrip()
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text + "\n", encoding="utf-8")
228

            
229

            
230
def generate(data, output_dir: Path):
    """Write every generated ssh_config file into *output_dir*.

    Produces:
      * ``client.conf``: header, global options, entrypoints, jump hosts and
        all groups;
      * ``is-jumper.conf``: header, global options and jump hosts;
      * ``j1.conf`` / ``j2.conf``: groups filtered for that target, without
        comments; global options are included only for targets that do not
        inherit them (see ``inherit_globals``).
    """
    client = generated_header("client")
    client.extend(emit_global_options(data))
    client.extend(emit_entrypoints(data))
    client.extend(emit_jumps(data))
    client.extend(emit_groups(data))
    write(output_dir / "client.conf", client)

    is_jumper = generated_header("is-jumper")
    is_jumper.extend(emit_global_options(data))
    is_jumper.extend(emit_jumps(data))
    write(output_dir / "is-jumper.conf", is_jumper)

    for target in ("j1", "j2"):
        lines = generated_header(target, include_comments=False)
        # Targets configured to inherit globals get no Host * option blocks.
        if not inherit_globals(data, target):
            lines.extend(emit_global_options(data, include_comments=False))
        lines.extend(emit_groups(data, target, include_comments=False))
        write(output_dir / f"{target}.conf", lines)
253

            
254

            
255
def main():
    """Command-line entry point: load the inventories and write the generated configs.

    Without ``--inventory`` the default ``inventory/hosts.yaml`` is used.
    ``inventory/hosts-local.yaml`` is appended as the highest-precedence file
    whenever it exists and is not already among the requested inventories.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--inventory", action="append", type=Path, help="Inventory file(s) to load (can be used multiple times)")
    parser.add_argument("--output-dir", default=ROOT / "generated", type=Path)
    args = parser.parse_args()

    # Copy so the parsed-arguments list is not mutated below.
    inventories = list(args.inventory) if args.inventory else [ROOT / "inventory" / "hosts.yaml"]
    local_inventory = ROOT / "inventory" / "hosts-local.yaml"
    # Compare resolved paths: a relative --inventory such as
    # inventory/hosts-local.yaml never equals the absolute ROOT-based path,
    # so a plain "in" check would load the local file twice.
    requested = {path.resolve() for path in inventories}
    if local_inventory.exists() and local_inventory.resolve() not in requested:
        inventories.append(local_inventory)

    data = merge_inventories(inventories) if len(inventories) > 1 else load_inventory(inventories[0])
    generate(data, args.output_dir)


if __name__ == "__main__":
    main()