SSH-Infrastructure / tools / generate-configs.py
Newer Older
246 lines | 8.118kb
Bogdan Timofte authored 2 weeks ago
1
#!/usr/bin/env python3
2
import argparse
3
import fnmatch
4
from pathlib import Path
5

            
6
import yaml
7

            
8

            
9
# Repository root: the directory one level above the one containing this script.
ROOT = Path(__file__).resolve().parents[1]
10

            
11

            
12
def load_inventory(path: Path) -> dict:
    """Load the YAML inventory at *path* and verify its schema version.

    Args:
        path: Location of the inventory file; it is read as UTF-8.

    Returns:
        The parsed inventory mapping.

    Raises:
        SystemExit: If the document is not a mapping (for example an empty
            file, which ``yaml.safe_load`` parses to ``None``) or if its
            ``version`` key is not ``1``.
    """
    with path.open("r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)
    # Check the shape before calling .get(): an empty or scalar document would
    # otherwise fail with an AttributeError instead of a readable message.
    if not isinstance(data, dict):
        raise SystemExit("inventory must be a YAML mapping")
    if data.get("version") != 1:
        raise SystemExit("unsupported inventory version")
    return data
18

            
19

            
20
def fmt_bool(value: bool) -> str:
    """Return ``"yes"`` for a truthy *value* and ``"no"`` otherwise."""
    if value:
        return "yes"
    return "no"
22

            
23

            
24
def fmt_option(value) -> str:
    """Format an option value as text.

    Booleans are rendered through ``fmt_bool``; every other value is
    converted with ``str``.
    """
    return fmt_bool(value) if isinstance(value, bool) else str(value)
28

            
29

            
30
def aliases_match_rule(aliases, rule):
    """Tell whether every alias matches at least one of the rule's patterns.

    Patterns are read from ``rule["patterns"]`` and compared with
    ``fnmatch.fnmatch`` against the string form of each alias. A rule with
    no patterns never matches.
    """
    globs = rule.get("patterns", [])
    if not globs:
        return False
    for alias in aliases:
        name = str(alias)
        if not any(fnmatch.fnmatch(name, glob) for glob in globs):
            # One alias outside every pattern is enough to reject the rule.
            return False
    return True
35

            
36

            
37
def company_managed_rule(data, target, aliases, user, port):
    """Find the first company-managed default rule that covers a host.

    Rules are only considered when *target* is listed under
    ``company_managed.jump_hosts.inherit_globals_on_targets``. A rule applies
    when its ``user`` and ``port`` equal the given ones and all *aliases*
    match its patterns.

    Returns:
        The matching rule mapping, or ``None`` when nothing applies.
    """
    jump_hosts = data.get("company_managed", {}).get("jump_hosts", {})
    if target in jump_hosts.get("inherit_globals_on_targets", []):
        for candidate in jump_hosts.get("match_defaults", []):
            same_login = candidate.get("user") == user and candidate.get("port") == port
            if same_login and aliases_match_rule(aliases, candidate):
                return candidate
    return None
48

            
49

            
50
def aliases_for_host(host):
    """Return the host's aliases as strings, with its hostname included.

    The hostname is appended only when it is not already among the aliases.
    """
    names = list(map(str, host["aliases"]))
    hostname = host["hostname"]
    if hostname in names:
        return names
    names.append(hostname)
    return names
55

            
56

            
57
def host_block(aliases, hostname, user=None, port=None, extra=None):
    """Render one ``Host`` stanza as a list of ssh_config lines.

    Args:
        aliases: Names placed on the ``Host`` line, joined by single spaces.
        hostname: Value of the ``HostName`` directive.
        user: Value of the ``User`` directive; the line is omitted when falsy.
        port: Value of the ``Port`` directive; the line is omitted when falsy.
        extra: Optional mapping of additional directives, emitted as
            ``key value`` lines in mapping order. The key ``auth`` is never
            emitted verbatim; when its value is ``"password_interactive"`` a
            fixed group of password-authentication directives is written
            instead. The mapping itself is left unmodified.

    Returns:
        The stanza's lines followed by one empty string as a separator.
    """
    # Work on a copy so that removing "auth" does not alter the caller's dict.
    options = dict(extra or {})
    auth = options.pop("auth", None)

    lines = [f"Host {' '.join(str(alias) for alias in aliases)}", f"    HostName {hostname}"]
    if user:
        lines.append(f"    User {user}")
    if port:
        lines.append(f"    Port {port}")
    if auth == "password_interactive":
        lines.extend(
            [
                "    SetEnv NG_SSH_AUTH=password-interactive",
                "    BatchMode no",
                "    PreferredAuthentications keyboard-interactive,password",
                "    PubkeyAuthentication no",
            ]
        )
    lines.extend(f"    {key} {value}" for key, value in options.items())
    lines.append("")
    return lines
73

            
74

            
75
def pattern_block(pattern, options):
    """Render a ``Host <pattern>`` stanza with its connection options.

    Only the ``connect_timeout`` and ``connection_attempts`` keys of
    *options* are emitted, in that order. The result ends with an empty
    string as a separator.
    """
    directives = (
        ("connect_timeout", "ConnectTimeout"),
        ("connection_attempts", "ConnectionAttempts"),
    )
    stanza = [f"Host {pattern}"]
    for option_key, directive in directives:
        if option_key in options:
            stanza.append(f"    {directive} {options[option_key]}")
    stanza.append("")
    return stanza
83

            
84

            
85
def generated_header(target, include_comments=True):
    """Return the "generated file" comment banner for *target*.

    With ``include_comments`` disabled the banner is suppressed and an empty
    list is returned.
    """
    if include_comments:
        banner = [
            "# Generated by tools/generate-configs.py.",
            "# Do not edit this file directly; edit inventory/hosts.yaml.",
        ]
        banner.append(f"# Target: {target}")
        banner.append("")
        return banner
    return []
94

            
95

            
96
def emit_global_options(data, include_comments=True):
    """Render every ``ssh_options`` entry as its own ``Host *`` stanza.

    Returns an empty list when the inventory defines no ``ssh_options``.
    Comment lines (the section title and one description line per entry) are
    written only when ``include_comments`` is true.
    """
    option_sets = data.get("ssh_options", {})
    if not option_sets:
        return []

    output = ["# Global SSH compatibility options", ""] if include_comments else []
    for label, option_set in option_sets.items():
        if include_comments:
            output.append(f"# {label}: {option_set.get('description', '')}")
        output.append("Host *")
        output.extend(
            f"    {key} {fmt_option(value)}"
            for key, value in option_set.get("options", {}).items()
        )
        output.append("")
    return output
112

            
113

            
114
def inherit_globals(data, target):
    """Tell whether *target* is listed in ``inherit_globals_on_targets``.

    The list is looked up under ``company_managed.jump_hosts``; missing
    levels are treated as empty.
    """
    company = data.get("company_managed", {})
    jump_hosts = company.get("jump_hosts", {})
    inheriting_targets = jump_hosts.get("inherit_globals_on_targets", [])
    return target in inheriting_targets
117

            
118

            
119
def merged(defaults, group_defaults, host):
    """Layer three option sets into a new dict.

    Later layers win: *host* overrides *group_defaults*, which overrides
    *defaults*. A falsy *group_defaults* is treated as empty. None of the
    inputs is modified.
    """
    combined = {}
    for layer in (defaults, group_defaults or {}, host):
        combined.update(layer)
    return combined
124

            
125

            
126
def host_differs_from_defaults(host, defaults):
    """Tell whether *host* overrides ``user``, ``port`` or ``auth``.

    A key counts as an override only when it is present on the host and its
    value is not equal to ``defaults.get(key)``.
    """
    tracked_keys = ("user", "port", "auth")
    return any(
        key in host and host[key] != defaults.get(key)
        for key in tracked_keys
    )
131

            
132

            
133
def should_emit_host_on_target(data, target, group_defaults, host):
    """Decide whether *host* gets its own stanza in the file for *target*.

    For the ``j1`` and ``j2`` targets a host is emitted only when its
    ``user``, ``port`` or ``auth`` differs from the final-host defaults
    combined with the group defaults. Every other target emits all hosts.
    """
    if target in ("j1", "j2"):
        baseline = merged(data["defaults"]["final_host"], group_defaults, {})
        return host_differs_from_defaults(host, baseline)
    return True
139

            
140

            
141
def emit_entrypoints(data, include_comments=True):
    """Render one ``Host`` stanza per entry in ``data["entrypoints"]``.

    ``identity_file`` (when truthy) becomes ``IdentityFile`` and
    ``identities_only`` (when present) becomes a yes/no ``IdentitiesOnly``.
    A section comment precedes the stanzas when ``include_comments`` is true.
    """
    output = []
    if include_comments:
        output += ["# Entrypoints", ""]

    for entry in data["entrypoints"].values():
        options = {}
        identity_file = entry.get("identity_file")
        if identity_file:
            options["IdentityFile"] = identity_file
        if "identities_only" in entry:
            options["IdentitiesOnly"] = fmt_bool(entry["identities_only"])
        stanza = host_block(
            aliases_for_host(entry),
            entry["hostname"],
            entry.get("user"),
            entry.get("port"),
            options,
        )
        output.extend(stanza)
    return output
151

            
152

            
153
def emit_jumps(data, include_comments=True):
    """Render one ``Host`` stanza per entry in ``data["jumps"]``.

    Each jump host is layered over ``data["defaults"]["jump"]`` before it is
    rendered. A section comment precedes the stanzas when
    ``include_comments`` is true.
    """
    output = []
    if include_comments:
        output += ["# Jump hosts", ""]

    jump_defaults = data["defaults"]["jump"]
    for raw_jump in data["jumps"].values():
        resolved = merged(jump_defaults, {}, raw_jump)
        names = aliases_for_host(resolved)
        stanza = host_block(names, resolved["hostname"], resolved.get("user"), resolved.get("port"))
        output.extend(stanza)
    return output
160

            
161

            
162
def emit_hosts_for_group(data, group, target, defaults):
    """Render the host and pattern stanzas of a single group.

    Hosts filtered out by ``should_emit_host_on_target`` are skipped. The
    remaining hosts are layered over *defaults* and the group's own defaults.
    When a company-managed rule covers a host on this target, its ``User``
    and ``Port`` lines are left out. Pattern stanzas follow the host stanzas.
    """
    group_defaults = group.get("defaults", {})
    output = []

    for raw_host in group.get("hosts", {}).values():
        if should_emit_host_on_target(data, target, group_defaults, raw_host):
            resolved = merged(defaults, group_defaults, raw_host)
            names = aliases_for_host(resolved)
            auth = resolved.get("auth")
            extra = {"auth": auth} if auth else {}
            user, port = resolved.get("user"), resolved.get("port")
            if company_managed_rule(data, target, names, user, port):
                # A matching rule exists, so user and port are not written here.
                user = port = None
            output.extend(host_block(names, resolved["hostname"], user, port, extra))

    for glob, options in group.get("patterns", {}).items():
        output.extend(pattern_block(glob, options))
    return output
182

            
183

            
184
def emit_groups(data, target=None, include_comments=True):
    """Render all groups of the inventory, walking nested groups breadth-first.

    Args:
        data: Inventory mapping; ``data["defaults"]["final_host"]`` and
            ``data["groups"]`` are required.
        target: Name of the output target, forwarded to the per-group
            renderer; ``None`` applies no target-specific filtering there.
        include_comments: When true, each group that produced output is
            preceded by ``# Group`` / ``# Description`` comment lines.

    Returns:
        The rendered lines of every group, top-level groups first. A nested
        group is any dict-valued key that is not a metadata key and that
        itself has a ``hosts`` entry; it is labelled ``parent.child``.
    """
    # Function-scope import: this block is the only user of deque.
    from collections import deque

    lines = []
    defaults = data["defaults"]["final_host"]
    metadata_keys = {"description", "default_jump", "defaults", "patterns", "hosts"}
    # deque gives O(1) pops from the left; list.pop(0) shifts every element.
    queue = deque(data["groups"].items())

    while queue:
        group_name, group = queue.popleft()
        group_lines = emit_hosts_for_group(data, group, target, defaults)
        if group_lines:
            if include_comments:
                lines.extend([f"# Group: {group_name}", f"# Description: {group.get('description', '')}", ""])
            lines.extend(group_lines)

        for child_name, child in group.items():
            if child_name in metadata_keys or not isinstance(child, dict):
                continue
            # Only children that declare hosts are treated as nested groups.
            if "hosts" in child:
                queue.append((f"{group_name}.{child_name}", child))
    return lines
204

            
205

            
206
def write(path: Path, lines):
    """Write *lines* to *path* as UTF-8, creating parent directories.

    The lines are joined with newlines, trailing whitespace is stripped, and
    the file ends with exactly one newline.
    """
    directory = path.parent
    directory.mkdir(parents=True, exist_ok=True)
    body = "\n".join(lines).rstrip()
    path.write_text(f"{body}\n", encoding="utf-8")
209

            
210

            
211
def generate(data, output_dir: Path):
    """Write all generated configuration files into *output_dir*.

    Files produced:
        * ``client.conf`` - header, global options, entrypoints, jump hosts
          and all groups.
        * ``is-jumper.conf`` - header, global options and jump hosts.
        * ``j1.conf`` / ``j2.conf`` - comment-free output with the groups
          filtered for that target; global options are included only when
          the target is not listed as inheriting them.
    """
    final_groups = emit_groups(data)

    client_lines = [
        *generated_header("client"),
        *emit_global_options(data),
        *emit_entrypoints(data),
        *emit_jumps(data),
        *final_groups,
    ]
    write(output_dir / "client.conf", client_lines)

    jumper_lines = [
        *generated_header("is-jumper"),
        *emit_global_options(data),
        *emit_jumps(data),
    ]
    write(output_dir / "is-jumper.conf", jumper_lines)

    for target in ("j1", "j2"):
        target_lines = generated_header(target, include_comments=False)
        if not inherit_globals(data, target):
            target_lines.extend(emit_global_options(data, include_comments=False))
        target_lines.extend(emit_groups(data, target, include_comments=False))
        write(output_dir / f"{target}.conf", target_lines)
234

            
235

            
236
def main():
    """Parse the command-line options and generate the configuration files.

    ``--inventory`` defaults to ``inventory/hosts.yaml`` and ``--output-dir``
    to ``generated``, both relative to ``ROOT``.
    """
    default_inventory = ROOT / "inventory" / "hosts.yaml"
    default_output_dir = ROOT / "generated"

    parser = argparse.ArgumentParser()
    parser.add_argument("--inventory", type=Path, default=default_inventory)
    parser.add_argument("--output-dir", type=Path, default=default_output_dir)
    options = parser.parse_args()

    inventory = load_inventory(options.inventory)
    generate(inventory, options.output_dir)
243

            
244

            
245
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
246
    main()