diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 16d7e89e3..4531f3432 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -211,6 +211,7 @@ }, }, "platform": [], + "event_bus_dedup_ttl_seconds": 0.5, "platform_specific": { # 平台特异配置:按平台分类,平台下按功能分组 "lark": { @@ -291,6 +292,9 @@ class ChatProviderTemplate(TypedDict): "secret": "", "enable_group_c2c": True, "enable_guild_direct_message": True, + "dedup_message_id_ttl_seconds": 1800.0, + "dedup_content_key_ttl_seconds": 3.0, + "dedup_cleanup_interval_seconds": 1.0, }, "QQ 官方机器人(Webhook)": { "id": "default", @@ -722,6 +726,21 @@ class ChatProviderTemplate(TypedDict): "type": "bool", "hint": "启用后,机器人可以接收到频道的私聊消息。", }, + "dedup_message_id_ttl_seconds": { + "description": "消息 ID 去重窗口(秒)", + "type": "float", + "hint": "QQ 官方适配器中 message_id 去重窗口,默认 1800 秒。", + }, + "dedup_content_key_ttl_seconds": { + "description": "内容键去重窗口(秒)", + "type": "float", + "hint": "QQ 官方适配器中 sender+content hash 去重窗口,默认 3 秒。", + }, + "dedup_cleanup_interval_seconds": { + "description": "去重缓存清理间隔(秒)", + "type": "float", + "hint": "QQ 官方适配器去重缓存的增量清理间隔,默认 1 秒。", + }, "ws_reverse_host": { "description": "反向 Websocket 主机", "type": "string", diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 70b5f054e..0e3abe496 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -11,15 +11,114 @@ """ import asyncio +import hashlib from asyncio import Queue from astrbot.core import logger from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.message.utils import build_component_dedup_signature from astrbot.core.pipeline.scheduler import PipelineScheduler +from astrbot.core.utils.number_utils import safe_positive_float +from astrbot.core.utils.ttl_registry import TTLKeyRegistry from .platform import AstrMessageEvent +class EventDeduplicator: + """Event deduplicator using TTL-based registry. + + This class handles deduplication of events based on content fingerprint + and message ID, with configurable TTL window. + """ + + _MAX_RAW_TEXT_FINGERPRINT_LEN = 256 + + def __init__(self, ttl_seconds: float = 0.5) -> None: + self._registry = TTLKeyRegistry(ttl_seconds) + + def _build_attachment_signature(self, event: AstrMessageEvent) -> str: + """Build attachment signature for deduplication.""" + return build_component_dedup_signature(event.get_messages()) + + def _build_content_key(self, event: AstrMessageEvent) -> str: + """Build content-based deduplication key.""" + msg_text = (event.get_message_str() or "").strip() + if len(msg_text) <= self._MAX_RAW_TEXT_FINGERPRINT_LEN: + msg_sig = msg_text + else: + msg_hash = hashlib.sha1(msg_text.encode("utf-8")).hexdigest()[:16] + msg_sig = f"h:{len(msg_text)}:{msg_hash}" + + attach_sig = self._build_attachment_signature(event) + return "|".join([ + "content", + event.get_platform_id() or "", + event.unified_msg_origin or "", + event.get_sender_id() or "", + msg_sig, + attach_sig, + ]) + + def _build_message_id_key(self, event: AstrMessageEvent) -> str | None: + """Build message ID-based deduplication key. + + Falls back to message_obj.id if message_id is not available. + """ + # Try message_id first + message_id = str(getattr(event.message_obj, "message_id", "") or "") + # Fallback to id if message_id is not available + if not message_id: + message_id = str(getattr(event.message_obj, "id", "") or "") + if not message_id: + return None + return "|".join([ + "message_id", + event.get_platform_id() or "", + event.unified_msg_origin or "", + message_id, + ]) + + def is_duplicate(self, event: AstrMessageEvent) -> bool: + """Check if the event is a duplicate. + + Returns False immediately if TTL is 0 (deduplication disabled). + Short-circuits on message_id key to avoid expensive attachment signature computation. + """ + # TTL of 0 means deduplication is disabled + if self._registry.ttl_seconds == 0: + return False + + # Short-circuit: check message_id first (cheap) before computing full content key (expensive) + message_id_key = self._build_message_id_key(event) + if message_id_key is not None: + if self._registry.contains(message_id_key): + logger.debug( + "Skip duplicate event in event_bus (by message_id): umo=%s, sender=%s", + event.unified_msg_origin, + event.get_sender_id(), + ) + return True + # Register message_id key since we'll process the event + self._registry.add(message_id_key) + + # Only compute full content key if we get past message_id check + content_key = self._build_content_key(event) + if self._registry.contains(content_key): + logger.debug( + "Skip duplicate event in event_bus (by content): umo=%s, sender=%s", + event.unified_msg_origin, + event.get_sender_id(), + ) + # If content duplicate, also remove message_id to preserve existing behavior + if message_id_key is not None: + self._registry.discard(message_id_key) + return True + + # Register content key + self._registry.add(content_key) + return False + + class EventBus: """用于处理事件的分发和处理""" @@ -33,10 +132,27 @@ def __init__( # abconf uuid -> scheduler self.pipeline_scheduler_mapping = pipeline_scheduler_mapping self.astrbot_config_mgr = astrbot_config_mgr + dedup_ttl_seconds = safe_positive_float( + self.astrbot_config_mgr.g( + None, + "event_bus_dedup_ttl_seconds", + 0.5, + ), + default=0.5, + ) + self._deduplicator = EventDeduplicator(ttl_seconds=dedup_ttl_seconds) async def dispatch(self) -> None: + # event_queue 由单一消费者处理;去重结构不是线程安全的,按设计仅在此循环中使用。 while True: event: AstrMessageEvent = await self.event_queue.get() + if self._deduplicator.is_duplicate(event): + logger.debug( + "Skip duplicate event in event_bus, umo=%s, sender=%s", + event.unified_msg_origin, + event.get_sender_id(), + ) + continue conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) conf_id = conf_info["id"] conf_name = conf_info.get("name") or conf_id diff --git a/astrbot/core/message/utils.py b/astrbot/core/message/utils.py new file mode 100644 index 000000000..91353b41e --- /dev/null +++ b/astrbot/core/message/utils.py @@ -0,0 +1,42 @@ +"""Message utilities for deduplication and component handling.""" + +import hashlib +from typing import Iterable + +from astrbot.core.message.components import BaseMessageComponent, File, Image + + +def build_component_dedup_signature( + components: Iterable[BaseMessageComponent], +) -> str: + """Build a deduplication signature from message components. + + This function extracts unique identifiers from Image and File components + and creates a hash-based signature for deduplication purposes. + + Args: + components: An iterable of message components to analyze. + + Returns: + A SHA1 hash (16 hex characters) representing the component signatures, + or an empty string if no valid components are found. + """ + parts: list[str] = [] + for component in components: + if isinstance(component, Image): + # Image can have url, file, or file_unique + ref = component.url or component.file or component.file_unique or "" + if ref: + parts.append(f"img:{ref}") + elif isinstance(component, File): + # File can have url, file (via property), or name + ref = component.url or component.file or component.name or "" + if ref: + parts.append(f"file:{ref}") + # Future component types can be added here + + if not parts: + return "" + + payload = "|".join(parts) + return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16] diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 68737b2bc..fb6887130 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -122,6 +122,15 @@ async def load_platform(self, platform_config: dict) -> None: ) return + # 防御式处理:避免同一平台 ID 被重复加载导致消息重复消费。 + if platform_id in self._inst_map: + logger.warning( + "平台 %s(%s) 已存在实例,先终止旧实例再重载。", + platform_config["type"], + platform_id, + ) + await self.terminate_platform(platform_id) + logger.info( f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...", ) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 436be70db..94cbf1d10 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -1,11 +1,11 @@ from __future__ import annotations import asyncio +import hashlib import logging import os import random import time -from types import SimpleNamespace from typing import Any, cast import botpy @@ -24,24 +24,137 @@ ) from astrbot.core.message.components import BaseMessageComponent from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.utils.number_utils import safe_positive_float +from astrbot.core.utils.ttl_registry import TTLKeyRegistry from ...register import register_platform_adapter from .qqofficial_message_event import QQOfficialMessageEvent +# pyright: reportUnreachable=false + # remove logger handler for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) # QQ 机器人官方框架 +def _extract_sender_id(message) -> str: + """Extract sender ID from a QQ message object. + + This is the central location for sender ID extraction logic to avoid + precedence drift between different code paths. + + The precedence order is: + 1. author.user_openid + 2. author.member_openid + 3. author.id + + Args: + message: The message object with an author attribute. + + Returns: + The sender ID as a string, or empty string if not found. + """ + if hasattr(message, "author") and hasattr(message.author, "user_openid"): + return str(message.author.user_openid) + if hasattr(message, "author") and hasattr(message.author, "member_openid"): + return str(message.author.member_openid) + if hasattr(message, "author") and hasattr(message.author, "id"): + return str(message.author.id) + return "" + + +class MessageDeduplicator: + def __init__( + self, + message_id_ttl_seconds: float = 30 * 60, + content_key_ttl_seconds: float = 3.0, + cleanup_interval_seconds: float = 1.0, + ) -> None: + self._message_ids = TTLKeyRegistry( + ttl_seconds=message_id_ttl_seconds, + cleanup_interval_seconds=cleanup_interval_seconds, + ) + self._content_keys = TTLKeyRegistry( + ttl_seconds=content_key_ttl_seconds, + cleanup_interval_seconds=cleanup_interval_seconds, + ) + self._lock = asyncio.Lock() + + def _build_content_key(self, content: str, sender_id: str) -> str | None: + if not (content and sender_id): + return None + content_hash = hashlib.sha1(content.encode("utf-8")).hexdigest()[:16] + return f"{sender_id}:{content_hash}" + + async def is_duplicate( + self, + message_id: str, + content: str = "", + sender_id: str = "", + ) -> bool: + async with self._lock: + # Bypass deduplication if TTL is 0 (disabled) + if self._message_ids.ttl_seconds == 0: + return False + + # 1) ID-based dedup + if self._message_ids.contains(message_id): + logger.debug( + "[QQOfficial] Duplicate message detected (by ID): %s...", + message_id[:50], + ) + return True + + self._message_ids.add(message_id) + + # 2) Content-based dedup + content_key = self._build_content_key(content, sender_id) + if content_key is None: + logger.debug("[QQOfficial] New message registered: %s...", message_id[:50]) + return False + + if self._content_keys.contains(content_key): + logger.debug( + "[QQOfficial] Duplicate message detected (by content): %s", + content_key, + ) + # Preserve existing behavior: do not keep message_id on content duplicates + self._message_ids.discard(message_id) + return True + + self._content_keys.add(content_key) + logger.debug("[QQOfficial] New message registered: %s...", message_id[:50]) + return False + + class botClient(Client): def set_platform(self, platform: QQOfficialPlatformAdapter) -> None: self.platform = platform + def _get_sender_id(self, message) -> str: + """Extract sender ID from different message types. + + Delegates to the centralized _extract_sender_id function to avoid + precedence drift. + """ + return _extract_sender_id(message) + + def _extract_dedup_key(self, message) -> tuple[str, str]: + sender_id = self._get_sender_id(message) + content = getattr(message, "content", "") or "" + return sender_id, content + + async def _should_drop_message(self, message) -> bool: + sender_id, content = self._extract_dedup_key(message) + return await self.platform._is_duplicate_message(message.id, content, sender_id) + # 收到群消息 async def on_group_at_message_create( self, message: botpy.message.GroupMessage ) -> None: + if await self._should_drop_message(message): + return abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -53,6 +166,8 @@ async def on_group_at_message_create( # 收到频道消息 async def on_at_message_create(self, message: botpy.message.Message) -> None: + if await self._should_drop_message(message): + return abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -66,6 +181,8 @@ async def on_at_message_create(self, message: botpy.message.Message) -> None: async def on_direct_message_create( self, message: botpy.message.DirectMessage ) -> None: + if await self._should_drop_message(message): + return abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -76,6 +193,8 @@ async def on_direct_message_create( # 收到 C2C 消息 async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: + if await self._should_drop_message(message): + return abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -136,6 +255,30 @@ def __init__( self.test_mode = os.environ.get("TEST_MODE", "off") == "on" + message_id_ttl_seconds = safe_positive_float( + platform_config.get("dedup_message_id_ttl_seconds"), + 30 * 60, + ) + content_key_ttl_seconds = safe_positive_float( + platform_config.get("dedup_content_key_ttl_seconds"), + 3.0, + ) + cleanup_interval_seconds = safe_positive_float( + platform_config.get("dedup_cleanup_interval_seconds"), + 1.0, + ) + + self._deduplicator = MessageDeduplicator( + message_id_ttl_seconds=message_id_ttl_seconds, + content_key_ttl_seconds=content_key_ttl_seconds, + cleanup_interval_seconds=cleanup_interval_seconds, + ) + + async def _is_duplicate_message( + self, message_id: str, content: str = "", sender_id: str = "" + ) -> bool: + return await self._deduplicator.is_duplicate(message_id, content, sender_id) + async def send_by_session( self, session: MessageSesion, @@ -177,7 +320,7 @@ async def _send_by_session_common( payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id} ret: Any = None - send_helper = SimpleNamespace(bot=self.client) + send_helper = self.client if session.message_type == MessageType.GROUP_MESSAGE: scene = self._session_scene.get(session.session_id) @@ -406,6 +549,7 @@ def _parse_from_qqofficial( abm.message_id = message.id # abm.tag = "qq_official" msg: list[BaseMessageComponent] = [] + message = cast(Any, message) if isinstance(message, botpy.message.GroupMessage) or isinstance( message, @@ -440,8 +584,9 @@ def _parse_from_qqofficial( QQOfficialPlatformAdapter._append_attachments(msg, message.attachments) abm.message = msg abm.message_str = plain_content + sender_user_id = _extract_sender_id(message) abm.sender = MessageMember( - str(message.author.id), + sender_user_id, str(message.author.username), ) msg.append(At(qq="qq_official")) diff --git a/astrbot/core/utils/number_utils.py b/astrbot/core/utils/number_utils.py new file mode 100644 index 000000000..d5e2e18de --- /dev/null +++ b/astrbot/core/utils/number_utils.py @@ -0,0 +1,26 @@ +import math + + +def safe_positive_float(value: object, default: float) -> float: + """Parse a value to a positive float. + + Args: + value: The value to parse (int, float, str, or other). + default: Default value to return if parsing fails or value is not positive. + + Returns: + The parsed positive float, or the default value. + Note: 0 is considered a valid value to allow disabling via config (e.g., TTL=0 disables dedup). + """ + if not isinstance(value, (int, float, str)): + return default + + try: + parsed = float(value) + except (TypeError, ValueError): + return default + + # Allow 0 to pass through (for disabling via config), but reject negative values + if not math.isfinite(parsed) or parsed < 0: + return default + return parsed diff --git a/astrbot/core/utils/ttl_registry.py b/astrbot/core/utils/ttl_registry.py new file mode 100644 index 000000000..5edf478e3 --- /dev/null +++ b/astrbot/core/utils/ttl_registry.py @@ -0,0 +1,134 @@ +"""TTL-based key registry for deduplication. + +This module provides a reusable TTL (time-to-live) key registry that can be used +for message/event deduplication across different components. +""" + +import time +from typing import Hashable, Sequence + + +class TTLKeyRegistry: + """A TTL-based registry for tracking seen keys. + + This utility handles time-based expiration of keys, making it suitable for + deduplication scenarios where old entries should be automatically cleaned up. + Supports optional cleanup interval throttling to avoid per-access full scans. + + Example: + registry = TTLKeyRegistry(ttl_seconds=0.5) + if registry.seen("some_key"): + # Key was seen within TTL window + pass + else: + # New key, register it + pass + """ + + def __init__( + self, + ttl_seconds: float, + cleanup_interval_seconds: float = 0.0, + ) -> None: + """Initialize the registry. + + Args: + ttl_seconds: Time-to-live in seconds for each key. Keys older than + this will be considered expired and cleaned up on next access. + cleanup_interval_seconds: Minimum interval between cleanup operations. + If 0 (default), cleanup runs on every access. + If > 0, cleanup is throttled to this interval. + """ + self._ttl_seconds = ttl_seconds + self._cleanup_interval_seconds = cleanup_interval_seconds + self._last_cleanup_at: float = 0.0 + self._seen: dict[Hashable, float] = {} + + @property + def ttl_seconds(self) -> float: + """Return the TTL seconds value.""" + return self._ttl_seconds + + def _clean_expired(self) -> None: + """Remove expired entries from the registry, with interval throttling.""" + # Short-circuit: if TTL is disabled (<=0), skip all cleanup logic + if self._ttl_seconds <= 0: + return + + now = time.monotonic() + + # Apply cleanup interval throttling if configured + if self._cleanup_interval_seconds > 0: + if self._last_cleanup_at > 0: + if now - self._last_cleanup_at < self._cleanup_interval_seconds: + return + self._last_cleanup_at = now + + expire_before = now - self._ttl_seconds + for key, ts in list(self._seen.items()): + if ts < expire_before: + del self._seen[key] + + def contains(self, key: Hashable) -> bool: + """Check if a key exists in the registry (without registering). + + Args: + key: The key to check. + + Returns: + True if the key exists and is not expired, False otherwise. + """ + self._clean_expired() + return key in self._seen + + def add(self, key: Hashable) -> None: + """Register a key with current timestamp. + + Args: + key: The key to add. + """ + self._seen[key] = time.monotonic() + + def discard(self, key: Hashable) -> None: + """Remove a key from the registry. + + Args: + key: The key to remove. + """ + self._seen.pop(key, None) + + def seen(self, key: Hashable) -> bool: + """Check if a key has been seen within the TTL window. + + If not seen, registers the key with current timestamp. + + Args: + key: The key to check. + + Returns: + True if the key was already seen within TTL window, False otherwise. + """ + self._clean_expired() + if key in self._seen: + return True + self._seen[key] = time.monotonic() + return False + + def seen_many(self, keys: Sequence[Hashable]) -> bool: + """Check if any of the keys have been seen within the TTL window. + + If none are seen, registers all keys with current timestamp. + + Args: + keys: The sequence of keys to check. + + Returns: + True if any key was already seen within TTL window, False otherwise. + """ + self._clean_expired() + now = time.monotonic() + if any(k in self._seen for k in keys): + return True + for k in keys: + self._seen[k] = now + return False