Add experimental support for MSC4308: Thread Subscriptions extension to Sliding Sync when MSC4306 and MSC4186 are enabled. (#18695)
Some checks are pending
Build docker images / Build and push image for ${{ matrix.platform }} (linux/amd64, ubuntu-24.04, linux-amd64) (push) Waiting to run
Build docker images / Build and push image for ${{ matrix.platform }} (linux/arm64, ubuntu-24.04-arm, linux-arm64) (push) Waiting to run
Build docker images / Push merged images to ${{ matrix.repository }} (docker.io/matrixdotorg/synapse) (push) Blocked by required conditions
Build docker images / Push merged images to ${{ matrix.repository }} (ghcr.io/element-hq/synapse) (push) Blocked by required conditions
Deploy the documentation / Calculate variables for GitHub Pages deployment (push) Waiting to run
Deploy the documentation / GitHub Pages (push) Blocked by required conditions
Build release artifacts / Calculate list of debian distros (push) Waiting to run
Build release artifacts / Build .deb packages (push) Blocked by required conditions
Build release artifacts / Build wheels on ${{ matrix.os }} (${{ startsWith(github.ref, 'refs/pull/') }}, macos-13) (push) Waiting to run
Build release artifacts / Build wheels on ${{ matrix.os }} (${{ startsWith(github.ref, 'refs/pull/') }}, macos-14) (push) Waiting to run
Build release artifacts / Build wheels on ${{ matrix.os }} (${{ startsWith(github.ref, 'refs/pull/') }}, ubuntu-24.04) (push) Waiting to run
Build release artifacts / Build wheels on ${{ matrix.os }} (${{ startsWith(github.ref, 'refs/pull/') }}, ubuntu-24.04-arm) (push) Waiting to run
Build release artifacts / Build sdist (push) Waiting to run
Build release artifacts / Attach assets to release (push) Blocked by required conditions
Schema / Ensure Synapse config schema is valid (push) Waiting to run
Schema / Ensure generated documentation is up-to-date (push) Waiting to run
Tests / changes (push) Waiting to run
Tests / check-sampleconfig (push) Blocked by required conditions
Tests / check-schema-delta (push) Blocked by required conditions
Tests / check-lockfile (push) Waiting to run
Tests / lint (push) Blocked by required conditions
Tests / Typechecking (push) Blocked by required conditions
Tests / lint-crlf (push) Waiting to run
Tests / lint-newsfile (push) Waiting to run
Tests / lint-pydantic (push) Blocked by required conditions
Tests / lint-clippy (push) Blocked by required conditions
Tests / lint-clippy-nightly (push) Blocked by required conditions
Tests / lint-rust (push) Blocked by required conditions
Tests / lint-rustfmt (push) Blocked by required conditions
Tests / lint-readme (push) Blocked by required conditions
Tests / linting-done (push) Blocked by required conditions
Tests / calculate-test-jobs (push) Blocked by required conditions
Tests / trial (push) Blocked by required conditions
Tests / trial-olddeps (push) Blocked by required conditions
Tests / trial-pypy (all, pypy-3.9) (push) Blocked by required conditions
Tests / sytest (push) Blocked by required conditions
Tests / export-data (push) Blocked by required conditions
Tests / portdb (13, 3.9) (push) Blocked by required conditions
Tests / portdb (17, 3.13) (push) Blocked by required conditions
Tests / complement (monolith, Postgres) (push) Blocked by required conditions
Tests / complement (monolith, SQLite) (push) Blocked by required conditions
Tests / complement (workers, Postgres) (push) Blocked by required conditions
Tests / cargo-test (push) Blocked by required conditions
Tests / cargo-bench (push) Blocked by required conditions
Tests / tests-done (push) Blocked by required conditions

Closes: #18436

Implements:
https://github.com/matrix-org/matrix-spec-proposals/pull/4308

Follows: #18674

Adds an extension to Sliding Sync and a companion
endpoint needed for backpaginating missed thread subscription changes,
as described in MSC4308

---------

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
reivilibre 2025-09-11 14:45:04 +01:00 committed by GitHub
parent 9cc4001778
commit ada3a3b2b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 1019 additions and 63 deletions

View File

@ -0,0 +1 @@
Add experimental support for [MSC4308: Thread Subscriptions extension to Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4308) when [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/pull/4306) and [MSC4186: Simplified Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4186) are enabled.

View File

@ -590,5 +590,5 @@ class ExperimentalConfig(Config):
self.msc4293_enabled: bool = experimental.get("msc4293_enabled", False)
# MSC4306: Thread Subscriptions
# (and MSC4308: sliding sync extension for thread subscriptions)
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)

View File

@ -135,7 +135,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access:
raise FederationDeniedError(origin)
limit = parse_integer_from_args(query, "limit", 0)
limit: Optional[int] = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
query, "include_all_networks", default=False

View File

@ -211,7 +211,7 @@ class SlidingSyncHandler:
Args:
sync_config: Sync configuration
to_token: The point in the stream to sync up to.
to_token: The latest point in the stream to sync up to.
from_token: The point in the stream to sync from. Token of the end of the
previous batch. May be `None` if this is the initial sync request.
"""

View File

@ -27,7 +27,7 @@ from typing import (
cast,
)
from typing_extensions import assert_never
from typing_extensions import TypeAlias, assert_never
from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.handlers.receipts import ReceiptEventSource
@ -40,6 +40,7 @@ from synapse.types import (
SlidingSyncStreamToken,
StrCollection,
StreamToken,
ThreadSubscriptionsToken,
)
from synapse.types.handlers.sliding_sync import (
HaveSentRoomFlag,
@ -54,6 +55,13 @@ from synapse.util.async_helpers import (
gather_optional_coroutines,
)
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
)
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -68,6 +76,7 @@ class SlidingSyncExtensionHandler:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
@trace
async def get_extensions_response(
@ -93,7 +102,7 @@ class SlidingSyncExtensionHandler:
actual_room_ids: The actual room IDs in the the Sliding Sync response.
actual_room_response_map: A map of room ID to room results in the the
Sliding Sync response.
to_token: The point in the stream to sync up to.
to_token: The latest point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
@ -156,18 +165,32 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
thread_subs_coro = None
if (
sync_config.extensions.thread_subscriptions is not None
and self._enable_thread_subscriptions
):
thread_subs_coro = self.get_thread_subscriptions_extension_response(
sync_config=sync_config,
thread_subscriptions_request=sync_config.extensions.thread_subscriptions,
to_token=to_token,
from_token=from_token,
)
(
to_device_response,
e2ee_response,
account_data_response,
receipts_response,
typing_response,
thread_subs_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
account_data_coro,
receipts_coro,
typing_coro,
thread_subs_coro,
)
return SlidingSyncResult.Extensions(
@ -176,6 +199,7 @@ class SlidingSyncExtensionHandler:
account_data=account_data_response,
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
)
def find_relevant_room_ids_for_extension(
@ -877,3 +901,72 @@ class SlidingSyncExtensionHandler:
return SlidingSyncResult.Extensions.TypingExtension(
room_id_to_typing_map=room_id_to_typing_map,
)
async def get_thread_subscriptions_extension_response(
self,
sync_config: SlidingSyncConfig,
thread_subscriptions_request: SlidingSyncConfig.Extensions.ThreadSubscriptionsExtension,
to_token: StreamToken,
from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.ThreadSubscriptionsExtension]:
"""Handle Thread Subscriptions extension (MSC4308)
Args:
sync_config: Sync configuration
thread_subscriptions_request: The thread_subscriptions extension from the request
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
Returns:
the response (None if empty or thread subscriptions are disabled)
"""
if not thread_subscriptions_request.enabled:
return None
limit = thread_subscriptions_request.limit
if from_token:
from_stream_id = from_token.stream_token.thread_subscriptions_key
else:
from_stream_id = StreamToken.START.thread_subscriptions_key
to_stream_id = to_token.thread_subscriptions_key
updates = await self.store.get_latest_updated_thread_subscriptions_for_user(
user_id=sync_config.user.to_string(),
from_id=from_stream_id,
to_id=to_stream_id,
limit=limit,
)
if len(updates) == 0:
return None
subscribed_threads: Dict[str, Dict[str, _ThreadSubscription]] = {}
unsubscribed_threads: Dict[str, Dict[str, _ThreadUnsubscription]] = {}
for stream_id, room_id, thread_root_id, subscribed, automatic in updates:
if subscribed:
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
_ThreadSubscription(
automatic=automatic,
bump_stamp=stream_id,
)
)
else:
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
_ThreadUnsubscription(bump_stamp=stream_id)
)
prev_batch = None
if len(updates) == limit:
# Tell the client about a potential gap where there may be more
# thread subscriptions for it to backpaginate.
# We subtract one because the 'later in the stream' bound is inclusive,
# and we already saw the element at index 0.
prev_batch = ThreadSubscriptionsToken(updates[0][0] - 1)
return SlidingSyncResult.Extensions.ThreadSubscriptionsExtension(
subscribed=subscribed_threads,
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)

View File

@ -9,7 +9,7 @@ from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscription,
)
from synapse.types import EventOrderings, UserID
from synapse.types import EventOrderings, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -22,6 +22,7 @@ class ThreadSubscriptionsHandler:
self.store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._notifier = hs.get_notifier()
async def get_thread_subscription_settings(
self,
@ -132,6 +133,15 @@ class ThreadSubscriptionsHandler:
errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
)
if outcome is not None:
# wake up user streams (e.g. sliding sync) on the same worker
self._notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
# outcome is a stream_id
outcome,
users=[user_id.to_string()],
)
return outcome
async def unsubscribe_user_from_thread(
@ -162,8 +172,19 @@ class ThreadSubscriptionsHandler:
logger.info("rejecting thread subscriptions change (thread not accessible)")
raise NotFoundError("No such thread root")
return await self.store.unsubscribe_user_from_thread(
outcome = await self.store.unsubscribe_user_from_thread(
user_id.to_string(),
event.room_id,
thread_root_event_id,
)
if outcome is not None:
# wake up user streams (e.g. sliding sync) on the same worker
self._notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
# outcome is a stream_id
outcome,
users=[user_id.to_string()],
)
return outcome

View File

@ -130,6 +130,16 @@ def parse_integer(
return parse_integer_from_args(args, name, default, required, negative)
@overload
def parse_integer_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
default: int,
required: Literal[False] = False,
negative: bool = False,
) -> int: ...
@overload
def parse_integer_from_args(
args: Mapping[bytes, Sequence[bytes]],

View File

@ -532,6 +532,7 @@ class Notifier:
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
StreamKeyType.THREAD_SUBSCRIPTIONS,
],
new_token: int,
users: Optional[Collection[Union[str, UserID]]] = None,

View File

@ -44,6 +44,7 @@ from synapse.replication.tcp.streams import (
UnPartialStatedEventStream,
UnPartialStatedRoomStream,
)
from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
@ -255,6 +256,12 @@ class ReplicationDataHandler:
self._state_storage_controller.notify_event_un_partial_stated(
row.event_id
)
elif stream_name == ThreadSubscriptionsStream.NAME:
self.notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
token,
users=[row.user_id for row in rows],
)
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows

View File

@ -23,6 +23,8 @@ import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union
import attr
from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import FilterCollection
@ -632,12 +634,21 @@ class SyncRestServlet(RestServlet):
class SlidingSyncRestServlet(RestServlet):
"""
API endpoint for MSC3575 Sliding Sync `/sync`. Allows for clients to request a
API endpoint for MSC4186 Simplified Sliding Sync `/sync`, which was historically derived
from MSC3575 (Sliding Sync; now abandoned). Allows for clients to request a
subset (sliding window) of rooms, state, and timeline events (just what they need)
in order to bootstrap quickly and subscribe to only what the client cares about.
Because the client can specify what it cares about, we can respond quickly and skip
all of the work we would normally have to do with a sync v2 response.
Extensions of various features are defined in:
- to-device messaging (MSC3885)
- end-to-end encryption (MSC3884)
- typing notifications (MSC3961)
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
Request query parameters:
timeout: How long to wait for new events in milliseconds.
pos: Stream position token when asking for incremental deltas.
@ -1074,9 +1085,48 @@ class SlidingSyncRestServlet(RestServlet):
"rooms": extensions.typing.room_id_to_typing_map,
}
# excludes both None and falsy `thread_subscriptions`
if extensions.thread_subscriptions:
serialized_extensions["io.element.msc4308.thread_subscriptions"] = (
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)
return serialized_extensions
def _serialise_thread_subscriptions(
thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension,
) -> JsonDict:
out: JsonDict = {}
if thread_subscriptions.subscribed:
out["subscribed"] = {
room_id: {
thread_root_id: attr.asdict(
change, filter=lambda _attr, v: v is not None
)
for thread_root_id, change in room_threads.items()
}
for room_id, room_threads in thread_subscriptions.subscribed.items()
}
if thread_subscriptions.unsubscribed:
out["unsubscribed"] = {
room_id: {
thread_root_id: attr.asdict(
change, filter=lambda _attr, v: v is not None
)
for thread_root_id, change in room_threads.items()
}
for room_id, room_threads in thread_subscriptions.unsubscribed.items()
}
if thread_subscriptions.prev_batch:
out["prev_batch"] = thread_subscriptions.prev_batch.to_string()
return out
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

View File

@ -1,21 +1,39 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import attr
from typing_extensions import TypeAlias
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
parse_integer,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict, RoomID
from synapse.types import (
JsonDict,
RoomID,
SlidingSyncStreamToken,
ThreadSubscriptionsToken,
)
from synapse.types.handlers.sliding_sync import SlidingSyncResult
from synapse.types.rest import RequestBodyModel
from synapse.util.pydantic_models import AnyEventId
if TYPE_CHECKING:
from synapse.server import HomeServer
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
)
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
class ThreadSubscriptionsRestServlet(RestServlet):
PATTERNS = client_patterns(
@ -100,6 +118,130 @@ class ThreadSubscriptionsRestServlet(RestServlet):
return HTTPStatus.OK, {}
class ThreadSubscriptionsPaginationRestServlet(RestServlet):
PATTERNS = client_patterns(
"/io.element.msc4308/thread_subscriptions$",
unstable=True,
releases=(),
)
CATEGORY = "Thread Subscriptions requests (unstable)"
# Maximum number of thread subscriptions to return in one request.
MAX_LIMIT = 512
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.store = hs.get_datastores().main
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
limit = min(
parse_integer(request, "limit", default=100, negative=False),
ThreadSubscriptionsPaginationRestServlet.MAX_LIMIT,
)
from_end_opt = parse_string(request, "from", required=False)
to_start_opt = parse_string(request, "to", required=False)
_direction = parse_string(request, "dir", required=True, allowed_values=("b",))
if limit <= 0:
# condition needed because `negative=False` still allows 0
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"limit must be greater than 0",
errcode=Codes.INVALID_PARAM,
)
if from_end_opt is not None:
try:
# because of backwards pagination, the `from` token is actually the
# bound closest to the end of the stream
end_stream_id = ThreadSubscriptionsToken.from_string(
from_end_opt
).stream_id
except ValueError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"`from` is not a valid token",
errcode=Codes.INVALID_PARAM,
)
else:
end_stream_id = self.store.get_max_thread_subscriptions_stream_id()
if to_start_opt is not None:
# because of backwards pagination, the `to` token is actually the
# bound closest to the start of the stream
try:
start_stream_id = ThreadSubscriptionsToken.from_string(
to_start_opt
).stream_id
except ValueError:
# we also accept sliding sync `pos` tokens on this parameter
try:
sliding_sync_pos = await SlidingSyncStreamToken.from_string(
self.store, to_start_opt
)
start_stream_id = (
sliding_sync_pos.stream_token.thread_subscriptions_key
)
except ValueError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"`to` is not a valid token",
errcode=Codes.INVALID_PARAM,
)
else:
# the start of time is ID 1; the lower bound is exclusive though
start_stream_id = 0
subscriptions = (
await self.store.get_latest_updated_thread_subscriptions_for_user(
requester.user.to_string(),
from_id=start_stream_id,
to_id=end_stream_id,
limit=limit,
)
)
subscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
unsubscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
for stream_id, room_id, thread_root_id, subscribed, automatic in subscriptions:
if subscribed:
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
attr.asdict(
_ThreadSubscription(
automatic=automatic,
bump_stamp=stream_id,
)
)
)
else:
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
attr.asdict(_ThreadUnsubscription(bump_stamp=stream_id))
)
result: JsonDict = {}
if subscribed_threads:
result["subscribed"] = subscribed_threads
if unsubscribed_threads:
result["unsubscribed"] = unsubscribed_threads
if len(subscriptions) == limit:
# We hit the limit, so there might be more entries to return.
# Generate a new token that has moved backwards, ready for the next
# request.
min_returned_stream_id, _, _, _, _ = subscriptions[0]
result["end"] = ThreadSubscriptionsToken(
# We subtract one because the 'later in the stream' bound is inclusive,
# and we already saw the element at index 0.
stream_id=min_returned_stream_id - 1
).to_string()
return HTTPStatus.OK, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc4306_enabled:
ThreadSubscriptionsRestServlet(hs).register(http_server)
ThreadSubscriptionsPaginationRestServlet(hs).register(http_server)

View File

@ -53,7 +53,7 @@ from synapse.storage.databases.main.stream import (
generate_pagination_where_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@ -316,17 +316,8 @@ class RelationsWorkerStore(SQLBaseStore):
StreamKeyType.ROOM, next_key
)
else:
next_token = StreamToken(
room_key=next_key,
presence_key=0,
typing_key=0,
receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
next_token = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM, next_key
)
return events[:limit], next_token

View File

@ -492,7 +492,7 @@ class PerConnectionStateDB:
"""An equivalent to `PerConnectionState` that holds data in a format stored
in the DB.
The principle difference is that the tokens for the different streams are
The principal difference is that the tokens for the different streams are
serialized to strings.
When persisting this *only* contains updates to the state.

View File

@ -505,6 +505,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"""
return self._thread_subscriptions_id_gen.get_current_token()
def get_thread_subscriptions_stream_id_generator(self) -> MultiWriterIdGenerator:
return self._thread_subscriptions_id_gen
async def get_updated_thread_subscriptions(
self, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
@ -538,34 +541,52 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
get_updated_thread_subscriptions_txn,
)
async def get_updated_thread_subscriptions_for_user(
async def get_latest_updated_thread_subscriptions_for_user(
self, user_id: str, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get updates to thread subscriptions for a specific user.
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
"""Get the latest updates to thread subscriptions for a specific user.
Args:
user_id: The ID of the user
from_id: The starting stream ID (exclusive)
to_id: The ending stream ID (inclusive)
limit: The maximum number of rows to return
If there are too many rows to return, rows from the start (closer to `from_id`)
will be omitted.
Returns:
A list of (stream_id, room_id, thread_root_event_id) tuples.
A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples.
The row with lowest `stream_id` is the first row.
"""
def get_updated_thread_subscriptions_for_user_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
sql = """
SELECT stream_id, room_id, event_id
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
WITH the_updates AS (
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT ?
)
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM the_updates
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (user_id, from_id, to_id, limit))
return [(row[0], row[1], row[2]) for row in txn]
return [
(
stream_id,
room_id,
event_id,
# SQLite integer to boolean conversions
bool(subscribed),
bool(automatic) if subscribed else None,
)
for (stream_id, room_id, event_id, subscribed, automatic) in txn
]
return await self.db_pool.runInteraction(
"get_updated_thread_subscriptions_for_user",

View File

@ -0,0 +1,19 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Work around https://github.com/element-hq/synapse/issues/18712 by advancing the
-- stream sequence.
-- This makes last_value of the sequence point to a position that will not get later
-- returned by nextval.
-- (For blank thread subscription streams, this means last_value = 2, nextval() = 3 after this line.)
SELECT nextval('thread_subscriptions_sequence');

View File

@ -187,8 +187,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
Warning: Streams using this generator start at ID 2, because ID 1 is always assumed
to have been 'seen as persisted'.
Unclear if this extant behaviour is desirable for some reason.
When creating a new sequence for a new stream,
it will be necessary to use `START WITH 2`.
When creating a new sequence for a new stream, it will be necessary to advance it
so that position 1 is consumed.
DO NOT USE `START WITH 2` FOR THIS PURPOSE:
see https://github.com/element-hq/synapse/issues/18712
Instead, use `SELECT nextval('sequence_name');` immediately after the
`CREATE SEQUENCE` statement.
Args:
db_conn

View File

@ -33,7 +33,6 @@ from synapse.logging.opentracing import trace
from synapse.streams import EventSource
from synapse.types import (
AbstractMultiWriterStreamToken,
MultiWriterStreamToken,
StreamKeyType,
StreamToken,
)
@ -84,6 +83,7 @@ class EventSources:
un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token(
self._instance_name
)
thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id()
token = StreamToken(
room_key=self.sources.room.get_current_key(),
@ -97,6 +97,7 @@ class EventSources:
# Groups key is unused.
groups_key=0,
un_partial_stated_rooms_key=un_partial_stated_rooms_key,
thread_subscriptions_key=thread_subscriptions_key,
)
return token
@ -123,6 +124,7 @@ class EventSources:
StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(),
}
for _, key in StreamKeyType.__members__.items():
@ -195,16 +197,7 @@ class EventSources:
Returns:
The current token for pagination.
"""
token = StreamToken(
room_key=await self.sources.room.get_current_key_for_room(room_id),
presence_key=0,
typing_key=0,
receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
return StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
await self.sources.room.get_current_key_for_room(room_id),
)
return token

View File

@ -996,6 +996,7 @@ class StreamKeyType(Enum):
TO_DEVICE = "to_device_key"
DEVICE_LIST = "device_list_key"
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
THREAD_SUBSCRIPTIONS = "thread_subscriptions_key"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@ -1003,7 +1004,7 @@ class StreamToken:
"""A collection of keys joined together by underscores in the following
order and which represent the position in their respective streams.
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379`
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242`
1. `room_key`: `s2633508` which is a `RoomStreamToken`
- `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
- See the docstring for `RoomStreamToken` for more details.
@ -1016,6 +1017,7 @@ class StreamToken:
8. `device_list_key`: `265584`
9. `groups_key`: `1` (note that this key is now unused)
10. `un_partial_stated_rooms_key`: `379`
11. `thread_subscriptions_key`: 4242
You can see how many of these keys correspond to the various
fields in a "/sync" response:
@ -1074,6 +1076,7 @@ class StreamToken:
# Note that the groups key is no longer used and may have bogus values.
groups_key: int
un_partial_stated_rooms_key: int
thread_subscriptions_key: int
_SEPARATOR = "_"
START: ClassVar["StreamToken"]
@ -1101,6 +1104,7 @@ class StreamToken:
device_list_key,
groups_key,
un_partial_stated_rooms_key,
thread_subscriptions_key,
) = keys
return cls(
@ -1116,6 +1120,7 @@ class StreamToken:
),
groups_key=int(groups_key),
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
thread_subscriptions_key=int(thread_subscriptions_key),
)
except CancelledError:
raise
@ -1138,6 +1143,7 @@ class StreamToken:
# if additional tokens are added.
str(self.groups_key),
str(self.un_partial_stated_rooms_key),
str(self.thread_subscriptions_key),
]
)
@ -1202,6 +1208,7 @@ class StreamToken:
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
StreamKeyType.THREAD_SUBSCRIPTIONS,
],
) -> int: ...
@ -1257,7 +1264,8 @@ class StreamToken:
f"typing: {self.typing_key}, receipt: {self.receipt_key}, "
f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key})"
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key},"
f"thread_subscriptions: {self.thread_subscriptions_key})"
)
@ -1272,6 +1280,7 @@ StreamToken.START = StreamToken(
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
thread_subscriptions_key=0,
)
@ -1318,6 +1327,27 @@ class SlidingSyncStreamToken:
return f"{self.connection_position}/{stream_token_str}"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscriptionsToken:
"""
Token for a position in the thread subscriptions stream.
Format: `ts<stream_id>`
"""
stream_id: int
@staticmethod
def from_string(s: str) -> "ThreadSubscriptionsToken":
if not s.startswith("ts"):
raise ValueError("thread subscription token must start with `ts`")
return ThreadSubscriptionsToken(stream_id=int(s[2:]))
def to_string(self) -> str:
return f"ts{self.stream_id}"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PersistedPosition:
"""Position of a newly persisted row with instance that persisted it."""

View File

@ -50,6 +50,7 @@ from synapse.types import (
SlidingSyncStreamToken,
StrCollection,
StreamToken,
ThreadSubscriptionsToken,
UserID,
)
from synapse.types.rest.client import SlidingSyncBody
@ -357,11 +358,50 @@ class SlidingSyncResult:
def __bool__(self) -> bool:
return bool(self.room_id_to_typing_map)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscriptionsExtension:
"""The Thread Subscriptions extension (MSC4308)
Attributes:
subscribed: map (room_id -> thread_root_id -> info) of new or changed subscriptions
unsubscribed: map (room_id -> thread_root_id -> info) of new unsubscriptions
prev_batch: if present, there is a gap and the client can use this token to backpaginate
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscription:
# always present when `subscribed`
automatic: Optional[bool]
# the same as our stream_id; useful for clients to resolve
# race conditions locally
bump_stamp: int
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUnsubscription:
# the same as our stream_id; useful for clients to resolve
# race conditions locally
bump_stamp: int
# room_id -> event_id (of thread root) -> the subscription change
subscribed: Optional[Mapping[str, Mapping[str, ThreadSubscription]]]
# room_id -> event_id (of thread root) -> the unsubscription
unsubscribed: Optional[Mapping[str, Mapping[str, ThreadUnsubscription]]]
prev_batch: Optional[ThreadSubscriptionsToken]
def __bool__(self) -> bool:
return (
bool(self.subscribed)
or bool(self.unsubscribed)
or bool(self.prev_batch)
)
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
def __bool__(self) -> bool:
return bool(
@ -370,6 +410,7 @@ class SlidingSyncResult:
or self.account_data
or self.receipts
or self.typing
or self.thread_subscriptions
)
next_pos: SlidingSyncStreamToken

View File

@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from synapse._pydantic_compat import (
Extra,
Field,
StrictBool,
StrictInt,
StrictStr,
@ -364,11 +365,25 @@ class SlidingSyncBody(RequestBodyModel):
# Process all room subscriptions defined in the Room Subscription API. (This is the default.)
rooms: Optional[List[StrictStr]] = ["*"]
class ThreadSubscriptionsExtension(RequestBodyModel):
"""The Thread Subscriptions extension (MSC4308)
Attributes:
enabled
limit: maximum number of subscription changes to return (default 100)
"""
enabled: Optional[StrictBool] = False
limit: StrictInt = 100
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
alias="io.element.msc4308.thread_subscriptions"
)
conn_id: Optional[StrictStr]

View File

@ -347,6 +347,7 @@ T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
@overload
@ -461,6 +462,23 @@ async def gather_optional_coroutines(
) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
Tuple[
Optional[Coroutine[Any, Any, T1]],
Optional[Coroutine[Any, Any, T2]],
Optional[Coroutine[Any, Any, T3]],
Optional[Coroutine[Any, Any, T4]],
Optional[Coroutine[Any, Any, T5]],
Optional[Coroutine[Any, Any, T6]],
]
],
) -> Tuple[
Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5], Optional[T6]
]: ...
async def gather_optional_coroutines(
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
) -> Tuple[Optional[T1], ...]:

View File

@ -2244,7 +2244,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_topo_token_is_accepted(self) -> None:
"""Test Topo Token is accepted."""
token = "t1-0_0_0_0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
@ -2258,7 +2258,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
"""Test that stream token is accepted for forward pagination."""
token = "s0_0_0_0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),

View File

@ -0,0 +1,497 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from http import HTTPStatus
from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.rest.client import login, room, sync, thread_subscriptions
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
# The name of the extension. Currently unstable-prefixed.
EXT_NAME = "io.element.msc4308.thread_subscriptions"
class SlidingSyncThreadSubscriptionsExtensionTestCase(SlidingSyncBase):
"""
Test the thread subscriptions extension in the Sliding Sync API.
"""
maxDiff = None
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
thread_subscriptions.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc4306_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
super().prepare(reactor, clock, hs)
def test_no_data_initial_sync(self) -> None:
"""
Test enabling thread subscriptions extension during initial sync with no data.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
self.assertNotIn(EXT_NAME, response_body["extensions"])
def test_no_data_incremental_sync(self) -> None:
"""
Test enabling thread subscriptions extension during incremental sync with no data.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
initial_sync_body: JsonDict = {
"lists": {},
}
# Initial sync
response_body, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok)
# Incremental sync with extension enabled
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
# Assert
self.assertNotIn(
EXT_NAME,
response_body["extensions"],
response_body,
)
def test_thread_subscription_initial_sync(self) -> None:
"""
Test thread subscriptions appear in initial sync response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
self.assertEqual(
response_body["extensions"][EXT_NAME],
{
"subscribed": {
room_id: {
thread_root_id: {
"automatic": False,
"bump_stamp": base + 1,
}
}
}
},
)
def test_thread_subscription_incremental_sync(self) -> None:
"""
Test new thread subscriptions appear in incremental sync response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
# Initial sync
_, sync_pos = self.do_sync(sync_body, tok=user1_tok)
logger.info("Synced to: %r, now subscribing to thread", sync_pos)
# Subscribe
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
# Incremental sync
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
logger.info("Synced to: %r", sync_pos)
# Assert
self.assertEqual(
response_body["extensions"][EXT_NAME],
{
"subscribed": {
room_id: {
thread_root_id: {
"automatic": False,
"bump_stamp": base + 1,
}
}
}
},
)
def test_unsubscribe_from_thread(self) -> None:
"""
Test unsubscribing from a thread.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok)
# Assert: Subscription present
self.assertIn(EXT_NAME, response_body["extensions"])
self.assertEqual(
response_body["extensions"][EXT_NAME],
{
"subscribed": {
room_id: {
thread_root_id: {"automatic": False, "bump_stamp": base + 1}
}
}
},
)
# Unsubscribe
self._unsubscribe_from_thread(user1_id, room_id, thread_root_id)
# Incremental sync
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
# Assert: Unsubscription present
self.assertEqual(
response_body["extensions"][EXT_NAME],
{"unsubscribed": {room_id: {thread_root_id: {"bump_stamp": base + 2}}}},
)
def test_multiple_thread_subscriptions(self) -> None:
"""
Test handling of multiple thread subscriptions.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create thread roots
thread_root_resp1 = self.helper.send(
room_id, body="Thread root 1", tok=user1_tok
)
thread_root_id1 = thread_root_resp1["event_id"]
thread_root_resp2 = self.helper.send(
room_id, body="Thread root 2", tok=user1_tok
)
thread_root_id2 = thread_root_resp2["event_id"]
thread_root_resp3 = self.helper.send(
room_id, body="Thread root 3", tok=user1_tok
)
thread_root_id3 = thread_root_resp3["event_id"]
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
# Subscribe to threads
self._subscribe_to_thread(user1_id, room_id, thread_root_id1)
self._subscribe_to_thread(user1_id, room_id, thread_root_id2)
self._subscribe_to_thread(user1_id, room_id, thread_root_id3)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
self.assertEqual(
response_body["extensions"][EXT_NAME],
{
"subscribed": {
room_id: {
thread_root_id1: {
"automatic": False,
"bump_stamp": base + 1,
},
thread_root_id2: {
"automatic": False,
"bump_stamp": base + 2,
},
thread_root_id3: {
"automatic": False,
"bump_stamp": base + 3,
},
}
}
},
)
def test_limit_parameter(self) -> None:
"""
Test limit parameter in thread subscriptions extension.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create 5 thread roots and subscribe to each
thread_root_ids = []
for i in range(5):
thread_root_resp = self.helper.send(
room_id, body=f"Thread root {i}", tok=user1_tok
)
thread_root_ids.append(thread_root_resp["event_id"])
self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])
sync_body = {
"lists": {},
"extensions": {EXT_NAME: {"enabled": True, "limit": 3}},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
thread_subscriptions = response_body["extensions"][EXT_NAME]
self.assertEqual(
len(thread_subscriptions["subscribed"][room_id]), 3, thread_subscriptions
)
def test_limit_and_companion_backpagination(self) -> None:
"""
Create 1 thread subscription, do a sync, create 4 more,
then sync with a limit of 2 and fill in the gap
using the companion /thread_subscriptions endpoint.
"""
thread_root_ids: List[str] = []
def make_subscription() -> None:
thread_root_resp = self.helper.send(
room_id, body="Some thread root", tok=user1_tok
)
thread_root_ids.append(thread_root_resp["event_id"])
self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
# Make our first subscription
make_subscription()
# Sync for the first time
sync_body = {
"lists": {},
"extensions": {EXT_NAME: {"enabled": True, "limit": 2}},
}
sync_resp, first_sync_pos = self.do_sync(sync_body, tok=user1_tok)
thread_subscriptions = sync_resp["extensions"][EXT_NAME]
self.assertEqual(
thread_subscriptions["subscribed"],
{
room_id: {
thread_root_ids[0]: {"automatic": False, "bump_stamp": base + 1},
}
},
)
# Get our pos for the next sync
first_sync_pos = sync_resp["pos"]
# Create 5 more thread subscriptions and subscribe to each
for _ in range(5):
make_subscription()
# Now sync again. Our limit is 2,
# so we should get the latest 2 subscriptions,
# with a gap of 3 more subscriptions in the middle
sync_resp, _pos = self.do_sync(sync_body, tok=user1_tok, since=first_sync_pos)
thread_subscriptions = sync_resp["extensions"][EXT_NAME]
self.assertEqual(
thread_subscriptions["subscribed"],
{
room_id: {
thread_root_ids[4]: {"automatic": False, "bump_stamp": base + 5},
thread_root_ids[5]: {"automatic": False, "bump_stamp": base + 6},
}
},
)
# 1st backpagination: expecting a page with 2 subscriptions
page, end_tok = self._do_backpaginate(
from_tok=thread_subscriptions["prev_batch"],
to_tok=first_sync_pos,
limit=2,
access_token=user1_tok,
)
self.assertIsNotNone(end_tok, "backpagination should continue")
self.assertEqual(
page["subscribed"],
{
room_id: {
thread_root_ids[2]: {"automatic": False, "bump_stamp": base + 3},
thread_root_ids[3]: {"automatic": False, "bump_stamp": base + 4},
}
},
)
# 2nd backpagination: expecting a page with only 1 subscription
# and no other token for further backpagination
assert end_tok is not None
page, end_tok = self._do_backpaginate(
from_tok=end_tok, to_tok=first_sync_pos, limit=2, access_token=user1_tok
)
self.assertIsNone(end_tok, "backpagination should have finished")
self.assertEqual(
page["subscribed"],
{
room_id: {
thread_root_ids[1]: {"automatic": False, "bump_stamp": base + 2},
}
},
)
def _do_backpaginate(
self, *, from_tok: str, to_tok: str, limit: int, access_token: str
) -> Tuple[JsonDict, Optional[str]]:
channel = self.make_request(
"GET",
"/_matrix/client/unstable/io.element.msc4308/thread_subscriptions"
f"?from={from_tok}&to={to_tok}&limit={limit}&dir=b",
access_token=access_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
body = channel.json_body
return body, cast(Optional[str], body.get("end"))
def _subscribe_to_thread(
self, user_id: str, room_id: str, thread_root_id: str
) -> None:
"""
Helper method to subscribe a user to a thread.
"""
self.get_success(
self.store.subscribe_user_to_thread(
user_id=user_id,
room_id=room_id,
thread_root_event_id=thread_root_id,
automatic_event_orderings=None,
)
)
def _unsubscribe_from_thread(
self, user_id: str, room_id: str, thread_root_id: str
) -> None:
"""
Helper method to unsubscribe a user from a thread.
"""
self.get_success(
self.store.unsubscribe_user_from_thread(
user_id=user_id,
room_id=room_id,
thread_root_event_id=thread_root_id,
)
)

View File

@ -2245,7 +2245,7 @@ class RoomMessageListTestCase(RoomBase):
self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self) -> None:
token = "t1-0_0_0_0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
@ -2256,7 +2256,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("end" in channel.json_body)
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
token = "s0_0_0_0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)

View File

@ -189,19 +189,19 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id,
from_id=0,
to_id=50,
limit=50,
)
)
min_id = min(id for (id, _, _) in subscriptions)
min_id = min(id for (id, _, _, _, _) in subscriptions)
self.assertEqual(
subscriptions,
[
(min_id, self.room_id, self.thread_root_id),
(min_id + 1, self.room_id, self.other_thread_root_id),
(min_id, self.room_id, self.thread_root_id, True, True),
(min_id + 1, self.room_id, self.other_thread_root_id, True, False),
],
)
@ -212,7 +212,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Check user has no subscriptions
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id,
from_id=0,
to_id=50,
@ -280,20 +280,22 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Get updates for main user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id, from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
self.assertEqual(
updates, [(stream_id1, self.room_id, self.thread_root_id, True, True)]
)
# Get updates for other user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
)
)
self.assertEqual(
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
updates, [(stream_id2, self.room_id, self.other_thread_root_id, True, True)]
)
def test_should_skip_autosubscription_after_unsubscription(self) -> None: