diff --git a/changelog.d/18695.feature b/changelog.d/18695.feature new file mode 100644 index 0000000000..1481a27f23 --- /dev/null +++ b/changelog.d/18695.feature @@ -0,0 +1 @@ +Add experimental support for [MSC4308: Thread Subscriptions extension to Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4308) when [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/pull/4306) and [MSC4186: Simplified Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4186) are enabled. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c1631f39e3..d086deab3f 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -590,5 +590,5 @@ class ExperimentalConfig(Config): self.msc4293_enabled: bool = experimental.get("msc4293_enabled", False) # MSC4306: Thread Subscriptions - # (and MSC4308: sliding sync extension for thread subscriptions) + # (and MSC4308: Thread Subscriptions extension to Sliding Sync) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 174d02ab6b..c4905e63dd 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -135,7 +135,7 @@ class PublicRoomList(BaseFederationServlet): if not self.allow_access: raise FederationDeniedError(origin) - limit = parse_integer_from_args(query, "limit", 0) + limit: Optional[int] = parse_integer_from_args(query, "limit", 0) since_token = parse_string_from_args(query, "since", None) include_all_networks = parse_boolean_from_args( query, "include_all_networks", default=False diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index 071a271ab7..255a041d0e 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -211,7 +211,7 @@ class SlidingSyncHandler: Args: sync_config: Sync configuration - to_token: The point in the stream to sync up to. + to_token: The latest point in the stream to sync up to. from_token: The point in the stream to sync from. Token of the end of the previous batch. May be `None` if this is the initial sync request. """ diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 077887ec32..25ee954b7f 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -27,7 +27,7 @@ from typing import ( cast, ) -from typing_extensions import assert_never +from typing_extensions import TypeAlias, assert_never from synapse.api.constants import AccountDataTypes, EduTypes from synapse.handlers.receipts import ReceiptEventSource @@ -40,6 +40,7 @@ from synapse.types import ( SlidingSyncStreamToken, StrCollection, StreamToken, + ThreadSubscriptionsToken, ) from synapse.types.handlers.sliding_sync import ( HaveSentRoomFlag, @@ -54,6 +55,13 @@ from synapse.util.async_helpers import ( gather_optional_coroutines, ) +_ThreadSubscription: TypeAlias = ( + SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription +) +_ThreadUnsubscription: TypeAlias = ( + SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription +) + if TYPE_CHECKING: from synapse.server import HomeServer @@ -68,6 +76,7 @@ class SlidingSyncExtensionHandler: self.event_sources = hs.get_event_sources() self.device_handler = hs.get_device_handler() self.push_rules_handler = hs.get_push_rules_handler() + self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled @trace async def get_extensions_response( @@ -93,7 +102,7 @@ class SlidingSyncExtensionHandler: actual_room_ids: The actual room IDs in the the Sliding Sync response. actual_room_response_map: A map of room ID to room results in the the Sliding Sync response. - to_token: The point in the stream to sync up to. + to_token: The latest point in the stream to sync up to. from_token: The point in the stream to sync from. """ @@ -156,18 +165,32 @@ class SlidingSyncExtensionHandler: from_token=from_token, ) + thread_subs_coro = None + if ( + sync_config.extensions.thread_subscriptions is not None + and self._enable_thread_subscriptions + ): + thread_subs_coro = self.get_thread_subscriptions_extension_response( + sync_config=sync_config, + thread_subscriptions_request=sync_config.extensions.thread_subscriptions, + to_token=to_token, + from_token=from_token, + ) + ( to_device_response, e2ee_response, account_data_response, receipts_response, typing_response, + thread_subs_response, ) = await gather_optional_coroutines( to_device_coro, e2ee_coro, account_data_coro, receipts_coro, typing_coro, + thread_subs_coro, ) return SlidingSyncResult.Extensions( @@ -176,6 +199,7 @@ class SlidingSyncExtensionHandler: account_data=account_data_response, receipts=receipts_response, typing=typing_response, + thread_subscriptions=thread_subs_response, ) def find_relevant_room_ids_for_extension( @@ -877,3 +901,72 @@ class SlidingSyncExtensionHandler: return SlidingSyncResult.Extensions.TypingExtension( room_id_to_typing_map=room_id_to_typing_map, ) + + async def get_thread_subscriptions_extension_response( + self, + sync_config: SlidingSyncConfig, + thread_subscriptions_request: SlidingSyncConfig.Extensions.ThreadSubscriptionsExtension, + to_token: StreamToken, + from_token: Optional[SlidingSyncStreamToken], + ) -> Optional[SlidingSyncResult.Extensions.ThreadSubscriptionsExtension]: + """Handle Thread Subscriptions extension (MSC4308) + + Args: + sync_config: Sync configuration + thread_subscriptions_request: The thread_subscriptions extension from the request + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from. + + Returns: + the response (None if empty or thread subscriptions are disabled) + """ + if not thread_subscriptions_request.enabled: + return None + + limit = thread_subscriptions_request.limit + + if from_token: + from_stream_id = from_token.stream_token.thread_subscriptions_key + else: + from_stream_id = StreamToken.START.thread_subscriptions_key + + to_stream_id = to_token.thread_subscriptions_key + + updates = await self.store.get_latest_updated_thread_subscriptions_for_user( + user_id=sync_config.user.to_string(), + from_id=from_stream_id, + to_id=to_stream_id, + limit=limit, + ) + + if len(updates) == 0: + return None + + subscribed_threads: Dict[str, Dict[str, _ThreadSubscription]] = {} + unsubscribed_threads: Dict[str, Dict[str, _ThreadUnsubscription]] = {} + for stream_id, room_id, thread_root_id, subscribed, automatic in updates: + if subscribed: + subscribed_threads.setdefault(room_id, {})[thread_root_id] = ( + _ThreadSubscription( + automatic=automatic, + bump_stamp=stream_id, + ) + ) + else: + unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = ( + _ThreadUnsubscription(bump_stamp=stream_id) + ) + + prev_batch = None + if len(updates) == limit: + # Tell the client about a potential gap where there may be more + # thread subscriptions for it to backpaginate. + # We subtract one because the 'later in the stream' bound is inclusive, + # and we already saw the element at index 0. + prev_batch = ThreadSubscriptionsToken(updates[0][0] - 1) + + return SlidingSyncResult.Extensions.ThreadSubscriptionsExtension( + subscribed=subscribed_threads, + unsubscribed=unsubscribed_threads, + prev_batch=prev_batch, + ) diff --git a/synapse/handlers/thread_subscriptions.py b/synapse/handlers/thread_subscriptions.py index bda4342949..d56c915e0a 100644 --- a/synapse/handlers/thread_subscriptions.py +++ b/synapse/handlers/thread_subscriptions.py @@ -9,7 +9,7 @@ from synapse.storage.databases.main.thread_subscriptions import ( AutomaticSubscriptionConflicted, ThreadSubscription, ) -from synapse.types import EventOrderings, UserID +from synapse.types import EventOrderings, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -22,6 +22,7 @@ class ThreadSubscriptionsHandler: self.store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() + self._notifier = hs.get_notifier() async def get_thread_subscription_settings( self, @@ -132,6 +133,15 @@ class ThreadSubscriptionsHandler: errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION, ) + if outcome is not None: + # wake up user streams (e.g. sliding sync) on the same worker + self._notifier.on_new_event( + StreamKeyType.THREAD_SUBSCRIPTIONS, + # outcome is a stream_id + outcome, + users=[user_id.to_string()], + ) + return outcome async def unsubscribe_user_from_thread( @@ -162,8 +172,19 @@ class ThreadSubscriptionsHandler: logger.info("rejecting thread subscriptions change (thread not accessible)") raise NotFoundError("No such thread root") - return await self.store.unsubscribe_user_from_thread( + outcome = await self.store.unsubscribe_user_from_thread( user_id.to_string(), event.room_id, thread_root_event_id, ) + + if outcome is not None: + # wake up user streams (e.g. sliding sync) on the same worker + self._notifier.on_new_event( + StreamKeyType.THREAD_SUBSCRIPTIONS, + # outcome is a stream_id + outcome, + users=[user_id.to_string()], + ) + + return outcome diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 47d8bd5eaf..69bdce2b83 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -130,6 +130,16 @@ def parse_integer( return parse_integer_from_args(args, name, default, required, negative) +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: int, + required: Literal[False] = False, + negative: bool = False, +) -> int: ... + + @overload def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], diff --git a/synapse/notifier.py b/synapse/notifier.py index 7782c9ca65..e684df4866 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -532,6 +532,7 @@ class Notifier: StreamKeyType.TO_DEVICE, StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, + StreamKeyType.THREAD_SUBSCRIPTIONS, ], new_token: int, users: Optional[Collection[Union[str, UserID]]] = None, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index ee9250cf7d..7a86b2e65e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -44,6 +44,7 @@ from synapse.replication.tcp.streams import ( UnPartialStatedEventStream, UnPartialStatedRoomStream, ) +from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -255,6 +256,12 @@ class ReplicationDataHandler: self._state_storage_controller.notify_event_un_partial_stated( row.event_id ) + elif stream_name == ThreadSubscriptionsStream.NAME: + self.notifier.on_new_event( + StreamKeyType.THREAD_SUBSCRIPTIONS, + token, + users=[row.user_id for row in rows], + ) await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 6f2f6642be..c424ca5325 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -23,6 +23,8 @@ import logging from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union +import attr + from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection @@ -632,12 +634,21 @@ class SyncRestServlet(RestServlet): class SlidingSyncRestServlet(RestServlet): """ - API endpoint for MSC3575 Sliding Sync `/sync`. Allows for clients to request a + API endpoint for MSC4186 Simplified Sliding Sync `/sync`, which was historically derived + from MSC3575 (Sliding Sync; now abandoned). Allows for clients to request a subset (sliding window) of rooms, state, and timeline events (just what they need) in order to bootstrap quickly and subscribe to only what the client cares about. Because the client can specify what it cares about, we can respond quickly and skip all of the work we would normally have to do with a sync v2 response. + Extensions of various features are defined in: + - to-device messaging (MSC3885) + - end-to-end encryption (MSC3884) + - typing notifications (MSC3961) + - receipts (MSC3960) + - account data (MSC3959) + - thread subscriptions (MSC4308) + Request query parameters: timeout: How long to wait for new events in milliseconds. pos: Stream position token when asking for incremental deltas. @@ -1074,9 +1085,48 @@ class SlidingSyncRestServlet(RestServlet): "rooms": extensions.typing.room_id_to_typing_map, } + # excludes both None and falsy `thread_subscriptions` + if extensions.thread_subscriptions: + serialized_extensions["io.element.msc4308.thread_subscriptions"] = ( + _serialise_thread_subscriptions(extensions.thread_subscriptions) + ) + return serialized_extensions +def _serialise_thread_subscriptions( + thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension, +) -> JsonDict: + out: JsonDict = {} + + if thread_subscriptions.subscribed: + out["subscribed"] = { + room_id: { + thread_root_id: attr.asdict( + change, filter=lambda _attr, v: v is not None + ) + for thread_root_id, change in room_threads.items() + } + for room_id, room_threads in thread_subscriptions.subscribed.items() + } + + if thread_subscriptions.unsubscribed: + out["unsubscribed"] = { + room_id: { + thread_root_id: attr.asdict( + change, filter=lambda _attr, v: v is not None + ) + for thread_root_id, change in room_threads.items() + } + for room_id, room_threads in thread_subscriptions.unsubscribed.items() + } + + if thread_subscriptions.prev_batch: + out["prev_batch"] = thread_subscriptions.prev_batch.to_string() + + return out + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/thread_subscriptions.py b/synapse/rest/client/thread_subscriptions.py index 4e7b5d06db..039aba1721 100644 --- a/synapse/rest/client/thread_subscriptions.py +++ b/synapse/rest/client/thread_subscriptions.py @@ -1,21 +1,39 @@ from http import HTTPStatus -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import attr +from typing_extensions import TypeAlias from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_and_validate_json_object_from_request, + parse_integer, + parse_string, ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, RoomID +from synapse.types import ( + JsonDict, + RoomID, + SlidingSyncStreamToken, + ThreadSubscriptionsToken, +) +from synapse.types.handlers.sliding_sync import SlidingSyncResult from synapse.types.rest import RequestBodyModel from synapse.util.pydantic_models import AnyEventId if TYPE_CHECKING: from synapse.server import HomeServer +_ThreadSubscription: TypeAlias = ( + SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription +) +_ThreadUnsubscription: TypeAlias = ( + SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription +) + class ThreadSubscriptionsRestServlet(RestServlet): PATTERNS = client_patterns( @@ -100,6 +118,130 @@ class ThreadSubscriptionsRestServlet(RestServlet): return HTTPStatus.OK, {} +class ThreadSubscriptionsPaginationRestServlet(RestServlet): + PATTERNS = client_patterns( + "/io.element.msc4308/thread_subscriptions$", + unstable=True, + releases=(), + ) + CATEGORY = "Thread Subscriptions requests (unstable)" + + # Maximum number of thread subscriptions to return in one request. + MAX_LIMIT = 512 + + def __init__(self, hs: "HomeServer"): + self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastores().main + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = min( + parse_integer(request, "limit", default=100, negative=False), + ThreadSubscriptionsPaginationRestServlet.MAX_LIMIT, + ) + from_end_opt = parse_string(request, "from", required=False) + to_start_opt = parse_string(request, "to", required=False) + _direction = parse_string(request, "dir", required=True, allowed_values=("b",)) + + if limit <= 0: + # condition needed because `negative=False` still allows 0 + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "limit must be greater than 0", + errcode=Codes.INVALID_PARAM, + ) + + if from_end_opt is not None: + try: + # because of backwards pagination, the `from` token is actually the + # bound closest to the end of the stream + end_stream_id = ThreadSubscriptionsToken.from_string( + from_end_opt + ).stream_id + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "`from` is not a valid token", + errcode=Codes.INVALID_PARAM, + ) + else: + end_stream_id = self.store.get_max_thread_subscriptions_stream_id() + + if to_start_opt is not None: + # because of backwards pagination, the `to` token is actually the + # bound closest to the start of the stream + try: + start_stream_id = ThreadSubscriptionsToken.from_string( + to_start_opt + ).stream_id + except ValueError: + # we also accept sliding sync `pos` tokens on this parameter + try: + sliding_sync_pos = await SlidingSyncStreamToken.from_string( + self.store, to_start_opt + ) + start_stream_id = ( + sliding_sync_pos.stream_token.thread_subscriptions_key + ) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "`to` is not a valid token", + errcode=Codes.INVALID_PARAM, + ) + else: + # the start of time is ID 1; the lower bound is exclusive though + start_stream_id = 0 + + subscriptions = ( + await self.store.get_latest_updated_thread_subscriptions_for_user( + requester.user.to_string(), + from_id=start_stream_id, + to_id=end_stream_id, + limit=limit, + ) + ) + + subscribed_threads: Dict[str, Dict[str, JsonDict]] = {} + unsubscribed_threads: Dict[str, Dict[str, JsonDict]] = {} + for stream_id, room_id, thread_root_id, subscribed, automatic in subscriptions: + if subscribed: + subscribed_threads.setdefault(room_id, {})[thread_root_id] = ( + attr.asdict( + _ThreadSubscription( + automatic=automatic, + bump_stamp=stream_id, + ) + ) + ) + else: + unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = ( + attr.asdict(_ThreadUnsubscription(bump_stamp=stream_id)) + ) + + result: JsonDict = {} + if subscribed_threads: + result["subscribed"] = subscribed_threads + if unsubscribed_threads: + result["unsubscribed"] = unsubscribed_threads + + if len(subscriptions) == limit: + # We hit the limit, so there might be more entries to return. + # Generate a new token that has moved backwards, ready for the next + # request. + min_returned_stream_id, _, _, _, _ = subscriptions[0] + result["end"] = ThreadSubscriptionsToken( + # We subtract one because the 'later in the stream' bound is inclusive, + # and we already saw the element at index 0. + stream_id=min_returned_stream_id - 1 + ).to_string() + + return HTTPStatus.OK, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc4306_enabled: ThreadSubscriptionsRestServlet(hs).register(http_server) + ThreadSubscriptionsPaginationRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 5edac56ec3..ea746e0511 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -53,7 +53,7 @@ from synapse.storage.databases.main.stream import ( generate_pagination_where_clause, ) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken +from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -316,17 +316,8 @@ class RelationsWorkerStore(SQLBaseStore): StreamKeyType.ROOM, next_key ) else: - next_token = StreamToken( - room_key=next_key, - presence_key=0, - typing_key=0, - receipt_key=MultiWriterStreamToken(stream=0), - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=MultiWriterStreamToken(stream=0), - groups_key=0, - un_partial_stated_rooms_key=0, + next_token = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, next_key ) return events[:limit], next_token diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py index 6a62b11d1e..72ec8e6b90 100644 --- a/synapse/storage/databases/main/sliding_sync.py +++ b/synapse/storage/databases/main/sliding_sync.py @@ -492,7 +492,7 @@ class PerConnectionStateDB: """An equivalent to `PerConnectionState` that holds data in a format stored in the DB. - The principle difference is that the tokens for the different streams are + The principal difference is that the tokens for the different streams are serialized to strings. When persisting this *only* contains updates to the state. diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py index 24a99cf449..50084887a4 100644 --- a/synapse/storage/databases/main/thread_subscriptions.py +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -505,6 +505,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): """ return self._thread_subscriptions_id_gen.get_current_token() + def get_thread_subscriptions_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._thread_subscriptions_id_gen + async def get_updated_thread_subscriptions( self, *, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str, str]]: @@ -538,34 +541,52 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): get_updated_thread_subscriptions_txn, ) - async def get_updated_thread_subscriptions_for_user( + async def get_latest_updated_thread_subscriptions_for_user( self, user_id: str, *, from_id: int, to_id: int, limit: int - ) -> List[Tuple[int, str, str]]: - """Get updates to thread subscriptions for a specific user. + ) -> List[Tuple[int, str, str, bool, Optional[bool]]]: + """Get the latest updates to thread subscriptions for a specific user. Args: user_id: The ID of the user from_id: The starting stream ID (exclusive) to_id: The ending stream ID (inclusive) limit: The maximum number of rows to return + If there are too many rows to return, rows from the start (closer to `from_id`) + will be omitted. Returns: - A list of (stream_id, room_id, thread_root_event_id) tuples. + A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples. + The row with lowest `stream_id` is the first row. """ def get_updated_thread_subscriptions_for_user_txn( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str]]: + ) -> List[Tuple[int, str, str, bool, Optional[bool]]]: sql = """ - SELECT stream_id, room_id, event_id - FROM thread_subscriptions - WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + WITH the_updates AS ( + SELECT stream_id, room_id, event_id, subscribed, automatic + FROM thread_subscriptions + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id DESC + LIMIT ? + ) + SELECT stream_id, room_id, event_id, subscribed, automatic + FROM the_updates ORDER BY stream_id ASC - LIMIT ? """ txn.execute(sql, (user_id, from_id, to_id, limit)) - return [(row[0], row[1], row[2]) for row in txn] + return [ + ( + stream_id, + room_id, + event_id, + # SQLite integer to boolean conversions + bool(subscribed), + bool(automatic) if subscribed else None, + ) + for (stream_id, room_id, event_id, subscribed, automatic) in txn + ] return await self.db_pool.runInteraction( "get_updated_thread_subscriptions_for_user", diff --git a/synapse/storage/schema/main/delta/92/08_thread_subscriptions_seq_fixup.sql.postgres b/synapse/storage/schema/main/delta/92/08_thread_subscriptions_seq_fixup.sql.postgres new file mode 100644 index 0000000000..d327d1e165 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/08_thread_subscriptions_seq_fixup.sql.postgres @@ -0,0 +1,19 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Work around https://github.com/element-hq/synapse/issues/18712 by advancing the +-- stream sequence. +-- This makes last_value of the sequence point to a position that will not get later +-- returned by nextval. +-- (For blank thread subscription streams, this means last_value = 2, nextval() = 3 after this line.) +SELECT nextval('thread_subscriptions_sequence'); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a15a161ce8..1b7c5dac7a 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -187,8 +187,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): Warning: Streams using this generator start at ID 2, because ID 1 is always assumed to have been 'seen as persisted'. Unclear if this extant behaviour is desirable for some reason. - When creating a new sequence for a new stream, - it will be necessary to use `START WITH 2`. + When creating a new sequence for a new stream, it will be necessary to advance it + so that position 1 is consumed. + DO NOT USE `START WITH 2` FOR THIS PURPOSE: + see https://github.com/element-hq/synapse/issues/18712 + Instead, use `SELECT nextval('sequence_name');` immediately after the + `CREATE SEQUENCE` statement. Args: db_conn diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4534068e7c..1e4bebe46d 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -33,7 +33,6 @@ from synapse.logging.opentracing import trace from synapse.streams import EventSource from synapse.types import ( AbstractMultiWriterStreamToken, - MultiWriterStreamToken, StreamKeyType, StreamToken, ) @@ -84,6 +83,7 @@ class EventSources: un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token( self._instance_name ) + thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -97,6 +97,7 @@ class EventSources: # Groups key is unused. groups_key=0, un_partial_stated_rooms_key=un_partial_stated_rooms_key, + thread_subscriptions_key=thread_subscriptions_key, ) return token @@ -123,6 +124,7 @@ class EventSources: StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(), StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(), StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), + StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(), } for _, key in StreamKeyType.__members__.items(): @@ -195,16 +197,7 @@ class EventSources: Returns: The current token for pagination. """ - token = StreamToken( - room_key=await self.sources.room.get_current_key_for_room(room_id), - presence_key=0, - typing_key=0, - receipt_key=MultiWriterStreamToken(stream=0), - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=MultiWriterStreamToken(stream=0), - groups_key=0, - un_partial_stated_rooms_key=0, + return StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, + await self.sources.room.get_current_key_for_room(room_id), ) - return token diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 943f211b11..2d5b07ab8f 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -996,6 +996,7 @@ class StreamKeyType(Enum): TO_DEVICE = "to_device_key" DEVICE_LIST = "device_list_key" UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" + THREAD_SUBSCRIPTIONS = "thread_subscriptions_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1003,7 +1004,7 @@ class StreamToken: """A collection of keys joined together by underscores in the following order and which represent the position in their respective streams. - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379` + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242` 1. `room_key`: `s2633508` which is a `RoomStreamToken` - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - See the docstring for `RoomStreamToken` for more details. @@ -1016,6 +1017,7 @@ class StreamToken: 8. `device_list_key`: `265584` 9. `groups_key`: `1` (note that this key is now unused) 10. `un_partial_stated_rooms_key`: `379` + 11. `thread_subscriptions_key`: 4242 You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -1074,6 +1076,7 @@ class StreamToken: # Note that the groups key is no longer used and may have bogus values. groups_key: int un_partial_stated_rooms_key: int + thread_subscriptions_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -1101,6 +1104,7 @@ class StreamToken: device_list_key, groups_key, un_partial_stated_rooms_key, + thread_subscriptions_key, ) = keys return cls( @@ -1116,6 +1120,7 @@ class StreamToken: ), groups_key=int(groups_key), un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), + thread_subscriptions_key=int(thread_subscriptions_key), ) except CancelledError: raise @@ -1138,6 +1143,7 @@ class StreamToken: # if additional tokens are added. str(self.groups_key), str(self.un_partial_stated_rooms_key), + str(self.thread_subscriptions_key), ] ) @@ -1202,6 +1208,7 @@ class StreamToken: StreamKeyType.TO_DEVICE, StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, + StreamKeyType.THREAD_SUBSCRIPTIONS, ], ) -> int: ... @@ -1257,7 +1264,8 @@ class StreamToken: f"typing: {self.typing_key}, receipt: {self.receipt_key}, " f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, " f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, " - f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key})" + f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}," + f"thread_subscriptions: {self.thread_subscriptions_key})" ) @@ -1272,6 +1280,7 @@ StreamToken.START = StreamToken( device_list_key=MultiWriterStreamToken(stream=0), groups_key=0, un_partial_stated_rooms_key=0, + thread_subscriptions_key=0, ) @@ -1318,6 +1327,27 @@ class SlidingSyncStreamToken: return f"{self.connection_position}/{stream_token_str}" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadSubscriptionsToken: + """ + Token for a position in the thread subscriptions stream. + + Format: `ts` + """ + + stream_id: int + + @staticmethod + def from_string(s: str) -> "ThreadSubscriptionsToken": + if not s.startswith("ts"): + raise ValueError("thread subscription token must start with `ts`") + + return ThreadSubscriptionsToken(stream_id=int(s[2:])) + + def to_string(self) -> str: + return f"ts{self.stream_id}" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class PersistedPosition: """Position of a newly persisted row with instance that persisted it.""" diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index 3ebd334a6d..b7bc565464 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -50,6 +50,7 @@ from synapse.types import ( SlidingSyncStreamToken, StrCollection, StreamToken, + ThreadSubscriptionsToken, UserID, ) from synapse.types.rest.client import SlidingSyncBody @@ -357,11 +358,50 @@ class SlidingSyncResult: def __bool__(self) -> bool: return bool(self.room_id_to_typing_map) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadSubscriptionsExtension: + """The Thread Subscriptions extension (MSC4308) + + Attributes: + subscribed: map (room_id -> thread_root_id -> info) of new or changed subscriptions + unsubscribed: map (room_id -> thread_root_id -> info) of new unsubscriptions + prev_batch: if present, there is a gap and the client can use this token to backpaginate + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadSubscription: + # always present when `subscribed` + automatic: Optional[bool] + + # the same as our stream_id; useful for clients to resolve + # race conditions locally + bump_stamp: int + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadUnsubscription: + # the same as our stream_id; useful for clients to resolve + # race conditions locally + bump_stamp: int + + # room_id -> event_id (of thread root) -> the subscription change + subscribed: Optional[Mapping[str, Mapping[str, ThreadSubscription]]] + # room_id -> event_id (of thread root) -> the unsubscription + unsubscribed: Optional[Mapping[str, Mapping[str, ThreadUnsubscription]]] + prev_batch: Optional[ThreadSubscriptionsToken] + + def __bool__(self) -> bool: + return ( + bool(self.subscribed) + or bool(self.unsubscribed) + or bool(self.prev_batch) + ) + to_device: Optional[ToDeviceExtension] = None e2ee: Optional[E2eeExtension] = None account_data: Optional[AccountDataExtension] = None receipts: Optional[ReceiptsExtension] = None typing: Optional[TypingExtension] = None + thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None def __bool__(self) -> bool: return bool( @@ -370,6 +410,7 @@ class SlidingSyncResult: or self.account_data or self.receipts or self.typing + or self.thread_subscriptions ) next_pos: SlidingSyncStreamToken diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index c739bd16b0..11d7e59b43 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from synapse._pydantic_compat import ( Extra, + Field, StrictBool, StrictInt, StrictStr, @@ -364,11 +365,25 @@ class SlidingSyncBody(RequestBodyModel): # Process all room subscriptions defined in the Room Subscription API. (This is the default.) rooms: Optional[List[StrictStr]] = ["*"] + class ThreadSubscriptionsExtension(RequestBodyModel): + """The Thread Subscriptions extension (MSC4308) + + Attributes: + enabled + limit: maximum number of subscription changes to return (default 100) + """ + + enabled: Optional[StrictBool] = False + limit: StrictInt = 100 + to_device: Optional[ToDeviceExtension] = None e2ee: Optional[E2eeExtension] = None account_data: Optional[AccountDataExtension] = None receipts: Optional[ReceiptsExtension] = None typing: Optional[TypingExtension] = None + thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field( + alias="io.element.msc4308.thread_subscriptions" + ) conn_id: Optional[StrictStr] diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index e596e1ed20..c21b7887f9 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -347,6 +347,7 @@ T2 = TypeVar("T2") T3 = TypeVar("T3") T4 = TypeVar("T4") T5 = TypeVar("T5") +T6 = TypeVar("T6") @overload @@ -461,6 +462,23 @@ async def gather_optional_coroutines( ) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ... +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + Optional[Coroutine[Any, Any, T5]], + Optional[Coroutine[Any, Any, T6]], + ] + ], +) -> Tuple[ + Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5], Optional[T6] +]: ... + + async def gather_optional_coroutines( *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]], ) -> Tuple[Optional[T1], ...]: diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index b98c53891c..ee5d0419ab 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2244,7 +2244,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): def test_topo_token_is_accepted(self) -> None: """Test Topo Token is accepted.""" - token = "t1-0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), @@ -2258,7 +2258,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: """Test that stream token is accepted for forward pagination.""" - token = "s0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), diff --git a/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py b/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py new file mode 100644 index 0000000000..775c4f96c9 --- /dev/null +++ b/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py @@ -0,0 +1,497 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +import logging +from http import HTTPStatus +from typing import List, Optional, Tuple, cast + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, room, sync, thread_subscriptions +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase + +logger = logging.getLogger(__name__) + + +# The name of the extension. Currently unstable-prefixed. +EXT_NAME = "io.element.msc4308.thread_subscriptions" + + +class SlidingSyncThreadSubscriptionsExtensionTestCase(SlidingSyncBase): + """ + Test the thread subscriptions extension in the Sliding Sync API. + """ + + maxDiff = None + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + thread_subscriptions.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4306_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() + super().prepare(reactor, clock, hs) + + def test_no_data_initial_sync(self) -> None: + """ + Test enabling thread subscriptions extension during initial sync with no data. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + self.assertNotIn(EXT_NAME, response_body["extensions"]) + + def test_no_data_incremental_sync(self) -> None: + """ + Test enabling thread subscriptions extension during incremental sync with no data. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + initial_sync_body: JsonDict = { + "lists": {}, + } + + # Initial sync + response_body, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok) + + # Incremental sync with extension enabled + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + + # Assert + self.assertNotIn( + EXT_NAME, + response_body["extensions"], + response_body, + ) + + def test_thread_subscription_initial_sync(self) -> None: + """ + Test thread subscriptions appear in initial sync response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + base = self.store.get_max_thread_subscriptions_stream_id() + + self._subscribe_to_thread(user1_id, room_id, thread_root_id) + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + self.assertEqual( + response_body["extensions"][EXT_NAME], + { + "subscribed": { + room_id: { + thread_root_id: { + "automatic": False, + "bump_stamp": base + 1, + } + } + } + }, + ) + + def test_thread_subscription_incremental_sync(self) -> None: + """ + Test new thread subscriptions appear in incremental sync response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + base = self.store.get_max_thread_subscriptions_stream_id() + + # Initial sync + _, sync_pos = self.do_sync(sync_body, tok=user1_tok) + logger.info("Synced to: %r, now subscribing to thread", sync_pos) + + # Subscribe + self._subscribe_to_thread(user1_id, room_id, thread_root_id) + + # Incremental sync + response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + logger.info("Synced to: %r", sync_pos) + + # Assert + self.assertEqual( + response_body["extensions"][EXT_NAME], + { + "subscribed": { + room_id: { + thread_root_id: { + "automatic": False, + "bump_stamp": base + 1, + } + } + } + }, + ) + + def test_unsubscribe_from_thread(self) -> None: + """ + Test unsubscribing from a thread. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + base = self.store.get_max_thread_subscriptions_stream_id() + + self._subscribe_to_thread(user1_id, room_id, thread_root_id) + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok) + + # Assert: Subscription present + self.assertIn(EXT_NAME, response_body["extensions"]) + self.assertEqual( + response_body["extensions"][EXT_NAME], + { + "subscribed": { + room_id: { + thread_root_id: {"automatic": False, "bump_stamp": base + 1} + } + } + }, + ) + + # Unsubscribe + self._unsubscribe_from_thread(user1_id, room_id, thread_root_id) + + # Incremental sync + response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + + # Assert: Unsubscription present + self.assertEqual( + response_body["extensions"][EXT_NAME], + {"unsubscribed": {room_id: {thread_root_id: {"bump_stamp": base + 2}}}}, + ) + + def test_multiple_thread_subscriptions(self) -> None: + """ + Test handling of multiple thread subscriptions. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread roots + thread_root_resp1 = self.helper.send( + room_id, body="Thread root 1", tok=user1_tok + ) + thread_root_id1 = thread_root_resp1["event_id"] + thread_root_resp2 = self.helper.send( + room_id, body="Thread root 2", tok=user1_tok + ) + thread_root_id2 = thread_root_resp2["event_id"] + thread_root_resp3 = self.helper.send( + room_id, body="Thread root 3", tok=user1_tok + ) + thread_root_id3 = thread_root_resp3["event_id"] + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + base = self.store.get_max_thread_subscriptions_stream_id() + + # Subscribe to threads + self._subscribe_to_thread(user1_id, room_id, thread_root_id1) + self._subscribe_to_thread(user1_id, room_id, thread_root_id2) + self._subscribe_to_thread(user1_id, room_id, thread_root_id3) + + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + self.assertEqual( + response_body["extensions"][EXT_NAME], + { + "subscribed": { + room_id: { + thread_root_id1: { + "automatic": False, + "bump_stamp": base + 1, + }, + thread_root_id2: { + "automatic": False, + "bump_stamp": base + 2, + }, + thread_root_id3: { + "automatic": False, + "bump_stamp": base + 3, + }, + } + } + }, + ) + + def test_limit_parameter(self) -> None: + """ + Test limit parameter in thread subscriptions extension. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create 5 thread roots and subscribe to each + thread_root_ids = [] + for i in range(5): + thread_root_resp = self.helper.send( + room_id, body=f"Thread root {i}", tok=user1_tok + ) + thread_root_ids.append(thread_root_resp["event_id"]) + self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1]) + + sync_body = { + "lists": {}, + "extensions": {EXT_NAME: {"enabled": True, "limit": 3}}, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + thread_subscriptions = response_body["extensions"][EXT_NAME] + self.assertEqual( + len(thread_subscriptions["subscribed"][room_id]), 3, thread_subscriptions + ) + + def test_limit_and_companion_backpagination(self) -> None: + """ + Create 1 thread subscription, do a sync, create 4 more, + then sync with a limit of 2 and fill in the gap + using the companion /thread_subscriptions endpoint. + """ + + thread_root_ids: List[str] = [] + + def make_subscription() -> None: + thread_root_resp = self.helper.send( + room_id, body="Some thread root", tok=user1_tok + ) + thread_root_ids.append(thread_root_resp["event_id"]) + self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1]) + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + base = self.store.get_max_thread_subscriptions_stream_id() + + # Make our first subscription + make_subscription() + + # Sync for the first time + sync_body = { + "lists": {}, + "extensions": {EXT_NAME: {"enabled": True, "limit": 2}}, + } + + sync_resp, first_sync_pos = self.do_sync(sync_body, tok=user1_tok) + + thread_subscriptions = sync_resp["extensions"][EXT_NAME] + self.assertEqual( + thread_subscriptions["subscribed"], + { + room_id: { + thread_root_ids[0]: {"automatic": False, "bump_stamp": base + 1}, + } + }, + ) + + # Get our pos for the next sync + first_sync_pos = sync_resp["pos"] + + # Create 5 more thread subscriptions and subscribe to each + for _ in range(5): + make_subscription() + + # Now sync again. Our limit is 2, + # so we should get the latest 2 subscriptions, + # with a gap of 3 more subscriptions in the middle + sync_resp, _pos = self.do_sync(sync_body, tok=user1_tok, since=first_sync_pos) + + thread_subscriptions = sync_resp["extensions"][EXT_NAME] + self.assertEqual( + thread_subscriptions["subscribed"], + { + room_id: { + thread_root_ids[4]: {"automatic": False, "bump_stamp": base + 5}, + thread_root_ids[5]: {"automatic": False, "bump_stamp": base + 6}, + } + }, + ) + # 1st backpagination: expecting a page with 2 subscriptions + page, end_tok = self._do_backpaginate( + from_tok=thread_subscriptions["prev_batch"], + to_tok=first_sync_pos, + limit=2, + access_token=user1_tok, + ) + self.assertIsNotNone(end_tok, "backpagination should continue") + self.assertEqual( + page["subscribed"], + { + room_id: { + thread_root_ids[2]: {"automatic": False, "bump_stamp": base + 3}, + thread_root_ids[3]: {"automatic": False, "bump_stamp": base + 4}, + } + }, + ) + + # 2nd backpagination: expecting a page with only 1 subscription + # and no other token for further backpagination + assert end_tok is not None + page, end_tok = self._do_backpaginate( + from_tok=end_tok, to_tok=first_sync_pos, limit=2, access_token=user1_tok + ) + self.assertIsNone(end_tok, "backpagination should have finished") + self.assertEqual( + page["subscribed"], + { + room_id: { + thread_root_ids[1]: {"automatic": False, "bump_stamp": base + 2}, + } + }, + ) + + def _do_backpaginate( + self, *, from_tok: str, to_tok: str, limit: int, access_token: str + ) -> Tuple[JsonDict, Optional[str]]: + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4308/thread_subscriptions" + f"?from={from_tok}&to={to_tok}&limit={limit}&dir=b", + access_token=access_token, + ) + + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + body = channel.json_body + return body, cast(Optional[str], body.get("end")) + + def _subscribe_to_thread( + self, user_id: str, room_id: str, thread_root_id: str + ) -> None: + """ + Helper method to subscribe a user to a thread. + """ + self.get_success( + self.store.subscribe_user_to_thread( + user_id=user_id, + room_id=room_id, + thread_root_event_id=thread_root_id, + automatic_event_orderings=None, + ) + ) + + def _unsubscribe_from_thread( + self, user_id: str, room_id: str, thread_root_id: str + ) -> None: + """ + Helper method to unsubscribe a user from a thread. + """ + self.get_success( + self.store.unsubscribe_user_from_thread( + user_id=user_id, + room_id=room_id, + thread_root_event_id=thread_root_id, + ) + ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 24a28fbdd2..d3b5e26132 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2245,7 +2245,7 @@ class RoomMessageListTestCase(RoomBase): self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self) -> None: - token = "t1-0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -2256,7 +2256,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("end" in channel.json_body) def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: - token = "s0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py index 2a5c440cf4..2ce369247f 100644 --- a/tests/storage/test_thread_subscriptions.py +++ b/tests/storage/test_thread_subscriptions.py @@ -189,19 +189,19 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): self._subscribe(self.other_thread_root_id, automatic_event_orderings=None) subscriptions = self.get_success( - self.store.get_updated_thread_subscriptions_for_user( + self.store.get_latest_updated_thread_subscriptions_for_user( self.user_id, from_id=0, to_id=50, limit=50, ) ) - min_id = min(id for (id, _, _) in subscriptions) + min_id = min(id for (id, _, _, _, _) in subscriptions) self.assertEqual( subscriptions, [ - (min_id, self.room_id, self.thread_root_id), - (min_id + 1, self.room_id, self.other_thread_root_id), + (min_id, self.room_id, self.thread_root_id, True, True), + (min_id + 1, self.room_id, self.other_thread_root_id, True, False), ], ) @@ -212,7 +212,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Check user has no subscriptions subscriptions = self.get_success( - self.store.get_updated_thread_subscriptions_for_user( + self.store.get_latest_updated_thread_subscriptions_for_user( self.user_id, from_id=0, to_id=50, @@ -280,20 +280,22 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Get updates for main user updates = self.get_success( - self.store.get_updated_thread_subscriptions_for_user( + self.store.get_latest_updated_thread_subscriptions_for_user( self.user_id, from_id=0, to_id=stream_id2, limit=10 ) ) - self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)]) + self.assertEqual( + updates, [(stream_id1, self.room_id, self.thread_root_id, True, True)] + ) # Get updates for other user updates = self.get_success( - self.store.get_updated_thread_subscriptions_for_user( + self.store.get_latest_updated_thread_subscriptions_for_user( other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10 ) ) self.assertEqual( - updates, [(stream_id2, self.room_id, self.other_thread_root_id)] + updates, [(stream_id2, self.room_id, self.other_thread_root_id, True, True)] ) def test_should_skip_autosubscription_after_unsubscription(self) -> None: