#!/usr/bin/env python
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2022-2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import argparse
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Iterable, Pattern

import yaml

from synapse.config.homeserver import HomeServerConfig
from synapse.federation.transport.server import (
    TransportLayerServer,
    register_servlets as register_federation_servlets,
)
from synapse.http.server import HttpServer, ServletCallback
from synapse.rest import ClientRestResource
from synapse.rest.key.v2 import RemoteKey
from synapse.server import HomeServer
from synapse.storage import DataStore

logger = logging.getLogger("generate_workers_map")


class MockHomeserver(HomeServer):
    """
    A homeserver that is only used to enumerate the servlets it would register.

    The worker application name is forced onto the supplied config so that the
    same base config can be reused for both the main process and a worker.
    """

    DATASTORE_CLASS = DataStore

    def __init__(self, config: HomeServerConfig, worker_app: str | None) -> None:
        super().__init__(config.server.server_name, config=config)
        # Override whatever the config file said: `None` means the main process.
        self.config.worker.worker_app = worker_app


# Matches a Python named capturing group, e.g. `(?P<room_id>[^/]*)`, capturing the
# group's inner pattern.
GROUP_PATTERN = re.compile(r"\(\?P<[^>]+?>(.+?)\)")


@dataclass
class EndpointDescription:
    """
    Describes an endpoint and how it should be routed.
    """

    # The servlet class that handles this endpoint
    servlet_class: object

    # The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
    # class.
    category: str | None

    # TODO:
    #  - does it need to be routed based on a stream writer config?
    #  - does it benefit from any optimised, but optional, routing?
    #  - what 'opinionated synapse worker class' (event_creator, synchrotron, etc) does
    #    it go in?


class EnumerationResource(HttpServer):
    """
    Accepts servlet registrations for the purposes of building up a description of
    all endpoints.
    """

    def __init__(self, is_worker: bool) -> None:
        # Maps (HTTP method, path regex source) to the endpoint's description.
        self.registrations: dict[tuple[str, str], EndpointDescription] = {}
        self._is_worker = is_worker

    def register_paths(
        self,
        method: str,
        path_patterns: Iterable[Pattern],
        callback: ServletCallback,
        servlet_classname: str,
    ) -> None:
        """
        Record the servlet behind `callback` against every pattern in `path_patterns`.

        Nothing is recorded when enumerating for a worker and `method` is listed in
        the servlet class's `WORKERS_DENIED_METHODS`.
        """
        # federation servlet callbacks are wrapped, so unwrap them.
        callback = getattr(callback, "__wrapped__", callback)

        # fish out the servlet class
        servlet_class = callback.__self__.__class__  # type: ignore

        if self._is_worker and method in getattr(
            servlet_class, "WORKERS_DENIED_METHODS", ()
        ):
            # This endpoint would cause an error if called on a worker, so pretend it
            # was never registered!
            return

        sd = EndpointDescription(
            servlet_class=servlet_class,
            category=getattr(servlet_class, "CATEGORY", None),
        )

        for pat in path_patterns:
            self.registrations[(method, pat.pattern)] = sd


def get_registered_paths_for_hs(
    hs: HomeServer,
) -> dict[tuple[str, str], EndpointDescription]:
    """
    Given a homeserver, get all registered endpoints and their descriptions.
    """

    enumerator = EnumerationResource(is_worker=hs.config.worker.worker_app is not None)
    ClientRestResource.register_servlets(enumerator, hs)
    federation_server = TransportLayerServer(hs)

    # we can't use `federation_server.register_servlets` but this line does the
    # same thing, only it uses this enumerator
    register_federation_servlets(
        federation_server.hs,
        resource=enumerator,
        ratelimiter=federation_server.ratelimiter,
        authenticator=federation_server.authenticator,
        servlet_groups=federation_server.servlet_groups,
    )

    # the key server endpoints are separate again
    RemoteKey(hs).register(enumerator)

    return enumerator.registrations


def get_registered_paths_for_default(
    worker_app: str | None, base_config: HomeServerConfig
) -> dict[tuple[str, str], EndpointDescription]:
    """
    Given the name of a worker application and a base homeserver configuration,
    returns:

        Dict from (method, path) to EndpointDescription

    TODO Don't require passing in a config
    """

    hs = MockHomeserver(base_config, worker_app)

    # TODO We only do this to avoid an error, but don't need the database etc
    hs.setup()
    registered_paths = get_registered_paths_for_hs(hs)

    # NOTE: a more robust implementation would properly shutdown/cleanup each server
    # to avoid resource buildup.
    # However, the call to `shutdown` is `async` so it would require additional
    # complexity here.
    # We are intentionally skipping this cleanup because this is a short-lived, one-off
    # utility script where the simpler approach is sufficient and we shouldn't run into
    # any resource buildup issues.

    return registered_paths


def elide_http_methods_if_unconflicting(
    registrations: dict[tuple[str, str], EndpointDescription],
    all_possible_registrations: dict[tuple[str, str], EndpointDescription],
) -> dict[tuple[str, str], EndpointDescription]:
    """
    Elides HTTP methods (by replacing them with `*`) if all possible registered methods
    can be handled by the worker whose registration map is `registrations`.

    i.e. the only endpoints left with methods (other than `*`) should be the ones where
    the worker can't handle all possible methods for that path.

    Every path in `registrations` must also appear in `all_possible_registrations`;
    otherwise a `KeyError` is raised.
    """

    def paths_to_methods_dict(
        methods_and_paths: Iterable[tuple[str, str]],
    ) -> dict[str, set[str]]:
        """
        Given (method, path) pairs, produces a dict from path to set of methods
        available at that path.
        """
        result: dict[str, set[str]] = {}
        for method, path in methods_and_paths:
            result.setdefault(path, set()).add(method)
        return result

    all_possible_reg_methods = paths_to_methods_dict(all_possible_registrations)
    reg_methods = paths_to_methods_dict(registrations)

    output = {}

    for path, handleable_methods in reg_methods.items():
        if handleable_methods == all_possible_reg_methods[path]:
            # Pick the representative method deterministically: iteration order of a
            # set of strings depends on hash randomisation and so can differ between
            # runs, which would make the output unstable whenever the methods map to
            # different servlets.
            any_method = min(handleable_methods)
            # TODO This assumes that all methods have the same servlet.
            #      I suppose that's possibly dubious?
            output[("*", path)] = registrations[(any_method, path)]
        else:
            # Sorted for the same run-to-run stability as above.
            for method in sorted(handleable_methods):
                output[(method, path)] = registrations[(method, path)]

    return output


def simplify_path_regexes(
    registrations: dict[tuple[str, str], EndpointDescription],
) -> dict[tuple[str, str], EndpointDescription]:
    """
    Simplify all the path regexes for the dict of endpoint descriptions,
    so that we don't use the Python-specific regex extensions
    (and also to remove needlessly specific detail).
    """

    def simplify_path_regex(path: str) -> str:
        """
        Given a regex pattern, replaces all named capturing groups (e.g. `(?P<name>xyz)`)
        with a simpler version available in more common regex dialects (e.g. `.*`).
        """

        # TODO it's hard to choose between these two;
        #      `.*` is a vague simplification
        # return GROUP_PATTERN.sub(r"\1", path)
        return GROUP_PATTERN.sub(r".*", path)

    return {(m, simplify_path_regex(p)): v for (m, p), v in registrations.items()}


def main() -> None:
    """
    Print, grouped by category, the endpoints that a generic worker can handle.

    The endpoints registered by the main process and by a
    `synapse.app.generic_worker` are enumerated from the config given via
    `--config-path`; HTTP methods are elided to `*` wherever the worker can serve
    every method registered for a path.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Enumerates the HTTP endpoints registered by Synapse and prints those"
            " that can be handled by a generic worker, grouped by category."
        )
    )
    parser.add_argument("-v", action="store_true")
    parser.add_argument(
        "--config-path",
        type=argparse.FileType("r"),
        required=True,
        help="Synapse configuration file",
    )

    args = parser.parse_args()

    # TODO
    # logging.basicConfig(**logging_config)

    # Load, process and sanity-check the config.
    # argparse opened the file for us; make sure it gets closed again.
    with args.config_path as config_file:
        hs_config = yaml.safe_load(config_file)

    config = HomeServerConfig()
    config.parse_config_dict(hs_config, "", "")

    master_paths = get_registered_paths_for_default(None, config)
    worker_paths = get_registered_paths_for_default(
        "synapse.app.generic_worker", config
    )

    all_paths = {**master_paths, **worker_paths}

    elided_worker_paths = elide_http_methods_if_unconflicting(worker_paths, all_paths)

    # TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT

    categories_to_methods_and_paths: dict[
        str | None, dict[tuple[str, str], EndpointDescription]
    ] = defaultdict(dict)

    for (method, path), desc in elided_worker_paths.items():
        categories_to_methods_and_paths[desc.category][method, path] = desc

    for category, contents in categories_to_methods_and_paths.items():
        print_category(category, contents)


def print_category(
    category_name: str | None,
    elided_worker_paths: dict[tuple[str, str], EndpointDescription],
) -> None:
    """
    Prints out a category, in documentation page style.

    Paths whose method was elided to `*` are listed first (path only), followed by
    the remaining endpoints as `METHOD path`; each list is sorted and followed by
    a blank line.

    Example:
    ```
    # Category name
    /path/xyz

    GET    /path/abc
    ```
    """

    if category_name:
        print(f"# {category_name}")
    else:
        print("# (Uncategorised requests)")

    # Simplify once and reuse the result for both listings.
    simplified = simplify_path_regexes(elided_worker_paths)

    for ln in sorted(p for m, p in simplified if m == "*"):
        print(ln)
    print()
    for ln in sorted(f"{m:6} {p}" for m, p in simplified if m != "*"):
        print(ln)
    print()


if __name__ == "__main__":
    main()