Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@
},
},
"platform": [],
"event_bus_dedup_ttl_seconds": 0.5,
"platform_specific": {
# 平台特异配置:按平台分类,平台下按功能分组
"lark": {
Expand Down Expand Up @@ -288,6 +289,9 @@ class ChatProviderTemplate(TypedDict):
"secret": "",
"enable_group_c2c": True,
"enable_guild_direct_message": True,
"dedup_message_id_ttl_seconds": 1800.0,
"dedup_content_key_ttl_seconds": 3.0,
"dedup_cleanup_interval_seconds": 1.0,
},
"QQ 官方机器人(Webhook)": {
"id": "default",
Expand Down Expand Up @@ -713,6 +717,21 @@ class ChatProviderTemplate(TypedDict):
"type": "bool",
"hint": "启用后,机器人可以接收到频道的私聊消息。",
},
"dedup_message_id_ttl_seconds": {
"description": "消息 ID 去重窗口(秒)",
"type": "float",
"hint": "QQ 官方适配器中 message_id 去重窗口,默认 1800 秒。",
},
"dedup_content_key_ttl_seconds": {
"description": "内容键去重窗口(秒)",
"type": "float",
"hint": "QQ 官方适配器中 sender+content hash 去重窗口,默认 3 秒。",
},
"dedup_cleanup_interval_seconds": {
"description": "去重缓存清理间隔(秒)",
"type": "float",
"hint": "QQ 官方适配器去重缓存的增量清理间隔,默认 1 秒。",
},
"ws_reverse_host": {
"description": "反向 Websocket 主机",
"type": "string",
Expand Down
108 changes: 108 additions & 0 deletions astrbot/core/event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,94 @@
"""

import asyncio
import hashlib
import math
import time
from asyncio import Queue
from collections import deque

from astrbot.core import logger
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.message.components import File, Image
from astrbot.core.pipeline.scheduler import PipelineScheduler

from .platform import AstrMessageEvent


class EventDeduplicator:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): 建议通过使用单一时间戳映射结构、复用浮点配置解析辅助函数,并将附件标识逻辑下沉到组件类中来简化去重设计和降低耦合度。

在不改变行为的前提下,当前的一些逻辑可以进一步简化:


1. TTL 结构:set + deque → 单一 dict

你目前同时维护 _seen 和 _queue 两个结构来实现 TTL 机制。对于较小 TTL 且单消费循环的场景,用一个 dict[fingerprint, timestamp] 加一个简单清理循环,更易于理解,也能避免在队列中处理元组下标:

class EventDeduplicator:
    """Sketch of a TTL deduplicator backed by one fingerprint -> timestamp map.

    The fingerprint builders (``_build_content_fingerprint`` and
    ``_build_message_id_fingerprint``) are not part of this snippet; they are
    expected to be provided by the full class.
    """

    def __init__(self, ttl_seconds: float = 0.5) -> None:
        """Create an empty deduplicator.

        Args:
            ttl_seconds: How long a recorded fingerprint stays in the map.
        """
        # Maps each fingerprint to the monotonic time it was recorded.
        self._seen: dict[tuple[str, ...], float] = {}
        self._ttl_seconds = ttl_seconds

    def _clean_expired(self) -> None:
        """Remove every fingerprint recorded before the current TTL cutoff."""
        cutoff = time.monotonic() - self._ttl_seconds
        # Collect the stale keys first so the dict is never mutated while
        # it is being iterated.
        stale = [key for key, seen_at in self._seen.items() if seen_at < cutoff]
        for key in stale:
            del self._seen[key]

    def is_duplicate(self, event: AstrMessageEvent) -> bool:
        """Report whether ``event`` matches a fingerprint that is still live.

        Returns:
            True when any fingerprint of the event is already recorded (in
            that case nothing new is recorded); otherwise records all of the
            event's fingerprints and returns False.
        """
        self._clean_expired()
        candidates = [self._build_content_fingerprint(event)]
        id_fingerprint = self._build_message_id_fingerprint(event)
        if id_fingerprint is not None:
            candidates.append(id_fingerprint)

        recorded_at = time.monotonic()
        if any(candidate in self._seen for candidate in candidates):
            return True

        self._seen.update(dict.fromkeys(candidates, recorded_at))
        return False

这样可以在保持现有行为(相同的指纹、相同的 TTL 语义)的同时,去掉自定义队列索引的复杂度。


2. 复用配置解析辅助函数,而不是 _safe_float

EventBus._safe_float 与你在 qqofficial_platform_adapter.py 中的 _safe_float_config 存在逻辑重复。为降低跨文件的理解成本,可以考虑导入并使用共享的辅助函数(或将其移动到一个共享模块中):

# e.g. in a shared module, or wherever _safe_float_config lives
def safe_positive_float(value: int | float | str | None, default: float) -> float:
    """Parse ``value`` as a strictly positive, finite float.

    Args:
        value: Raw configuration value; may be missing (``None``) or malformed.
        default: Value returned whenever ``value`` is unusable.

    Returns:
        ``float(value)`` when it is finite and greater than zero, otherwise
        ``default``.
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError covers None and other non-numeric objects, ValueError
        # covers unparsable strings, and OverflowError covers integers too
        # large to be represented as a float.
        return default
    if not math.isfinite(parsed) or parsed <= 0:
        return default
    return parsed

然后在 EventBus 中:

from astrbot.core.config_utils import safe_positive_float  # example location

class EventBus:
    def __init__(...):
        dedup_ttl_seconds = safe_positive_float(
            self.astrbot_config_mgr.g(
                None,
                "event_bus_dedup_ttl_seconds",
                0.5,
            ),
            default=0.5,
        )
        self._deduplicator = EventDeduplicator(ttl_seconds=dedup_ttl_seconds)

3. 将附件标识从字段级探测中解耦

_build_attachment_signature 当前对 Image/File 的内部结构做了假设。如果 Image/File 能暴露一个稳定标识符,可以把这些分支逻辑移动到这些类中,让去重器只关注标识符本身:

# In Image / File classes (or a base mixin)
@property
def stable_id(self) -> str:
    """Return the first non-empty reference among url, file and file_unique.

    Attributes are read lazily in that order, so a later attribute is only
    accessed when every earlier one is falsy. Returns "" when none is set.
    """
    for attr_name in ("url", "file", "file_unique"):
        candidate = getattr(self, attr_name)
        if candidate:
            return candidate
    return ""

# In EventDeduplicator
def _build_attachment_signature(self, event: AstrMessageEvent) -> str:
    ids: list[str] = []
    for component in event.get_messages():
        stable_id = getattr(component, "stable_id", None)
        if stable_id:
            ids.append(stable_id)

    if not ids:
        return ""

    payload = "|".join(ids)
    return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16]

这样既可以保留指纹的丰富度,又能把与组件具体结构相关的逻辑放到数据自身所在类中,使 EventDeduplicator 对数据结构不那么敏感,更易于维护。

Original comment in English

issue (complexity): Consider simplifying the deduplication design by using a single timestamp map, sharing the float parsing helper, and pushing attachment identity logic into the component classes to reduce coupling.

A few parts of the new logic can be simplified without changing behavior:


1. TTL structure: set + deque → single dict

You’re maintaining both _seen and _queue to enforce TTL. For a small TTL and single-consumer loop, a single dict[fingerprint, timestamp] with a simple cleanup loop is easier to reason about and avoids tuple juggling:

class EventDeduplicator:
    """Sketch of a TTL deduplicator backed by one fingerprint -> timestamp map.

    The fingerprint builders (``_build_content_fingerprint`` and
    ``_build_message_id_fingerprint``) are not part of this snippet; they are
    expected to be provided by the full class.
    """

    def __init__(self, ttl_seconds: float = 0.5) -> None:
        """Create an empty deduplicator.

        Args:
            ttl_seconds: How long a recorded fingerprint stays in the map.
        """
        # Maps each fingerprint to the monotonic time it was recorded.
        self._seen: dict[tuple[str, ...], float] = {}
        self._ttl_seconds = ttl_seconds

    def _clean_expired(self) -> None:
        """Remove every fingerprint recorded before the current TTL cutoff."""
        cutoff = time.monotonic() - self._ttl_seconds
        # Collect the stale keys first so the dict is never mutated while
        # it is being iterated.
        stale = [key for key, seen_at in self._seen.items() if seen_at < cutoff]
        for key in stale:
            del self._seen[key]

    def is_duplicate(self, event: AstrMessageEvent) -> bool:
        """Report whether ``event`` matches a fingerprint that is still live.

        Returns:
            True when any fingerprint of the event is already recorded (in
            that case nothing new is recorded); otherwise records all of the
            event's fingerprints and returns False.
        """
        self._clean_expired()
        candidates = [self._build_content_fingerprint(event)]
        id_fingerprint = self._build_message_id_fingerprint(event)
        if id_fingerprint is not None:
            candidates.append(id_fingerprint)

        recorded_at = time.monotonic()
        if any(candidate in self._seen for candidate in candidates):
            return True

        self._seen.update(dict.fromkeys(candidates, recorded_at))
        return False

This preserves the current behavior (same fingerprints, same TTL semantics) but removes the custom queue indexing.


2. Reuse config parsing helper instead of _safe_float

EventBus._safe_float duplicates logic that you already have as _safe_float_config in qqofficial_platform_adapter.py. To reduce cross-file cognitive load, consider importing and using the shared helper (or moving it to a shared module):

# e.g. in a shared module, or wherever _safe_float_config lives
def safe_positive_float(value: int | float | str | None, default: float) -> float:
    """Parse ``value`` as a strictly positive, finite float.

    Args:
        value: Raw configuration value; may be missing (``None``) or malformed.
        default: Value returned whenever ``value`` is unusable.

    Returns:
        ``float(value)`` when it is finite and greater than zero, otherwise
        ``default``.
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError covers None and other non-numeric objects, ValueError
        # covers unparsable strings, and OverflowError covers integers too
        # large to be represented as a float.
        return default
    if not math.isfinite(parsed) or parsed <= 0:
        return default
    return parsed

Then in EventBus:

from astrbot.core.config_utils import safe_positive_float  # example location

class EventBus:
    def __init__(...):
        dedup_ttl_seconds = safe_positive_float(
            self.astrbot_config_mgr.g(
                None,
                "event_bus_dedup_ttl_seconds",
                0.5,
            ),
            default=0.5,
        )
        self._deduplicator = EventDeduplicator(ttl_seconds=dedup_ttl_seconds)

3. Decouple attachment identity from field-level probing

_build_attachment_signature currently encodes assumptions about Image/File internals. If Image/File can expose a stable identifier, you can move this branching into those classes and keep the deduplicator focused on identifiers only:

# In Image / File classes (or a base mixin)
@property
def stable_id(self) -> str:
    """Return the first non-empty reference among url, file and file_unique.

    Attributes are read lazily in that order, so a later attribute is only
    accessed when every earlier one is falsy. Returns "" when none is set.
    """
    for attr_name in ("url", "file", "file_unique"):
        candidate = getattr(self, attr_name)
        if candidate:
            return candidate
    return ""

# In EventDeduplicator
def _build_attachment_signature(self, event: AstrMessageEvent) -> str:
    ids: list[str] = []
    for component in event.get_messages():
        stable_id = getattr(component, "stable_id", None)
        if stable_id:
            ids.append(stable_id)

    if not ids:
        return ""

    payload = "|".join(ids)
    return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16]

This keeps the richness of the fingerprint but moves the component-specific logic to where the data lives, making EventDeduplicator less “schema-aware” and easier to maintain.

def __init__(self, ttl_seconds: float = 0.5) -> None:
self._ttl_seconds = ttl_seconds
self._seen: set[tuple[str, ...]] = set()
self._queue: deque[tuple[float, tuple[str, ...]]] = deque()

def _clean_expired(self) -> None:
now = time.monotonic()
expire_before = now - self._ttl_seconds
while self._queue and self._queue[0][0] < expire_before:
_, fingerprint = self._queue.popleft()
self._seen.discard(fingerprint)

def _build_attachment_signature(self, event: AstrMessageEvent) -> str:
signatures: list[str] = []
for component in event.get_messages():
if isinstance(component, Image):
image_ref = component.url or component.file or component.file_unique or ""
if image_ref:
signatures.append(f"img:{image_ref}")
elif isinstance(component, File):
file_ref = component.url or component.file_ or component.name or ""
if file_ref:
signatures.append(f"file:{file_ref}")

if not signatures:
return ""

payload = "|".join(signatures)
return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:16]

def _build_content_fingerprint(
self,
event: AstrMessageEvent,
) -> tuple[str, ...]:
attachment_signature = self._build_attachment_signature(event)
return (
"content",
event.get_platform_id() or "",
event.unified_msg_origin or "",
event.get_sender_id() or "",
(event.get_message_str() or "").strip(),
attachment_signature,
)

def _build_message_id_fingerprint(self, event: AstrMessageEvent) -> tuple[str, ...] | None:
message_id = str(getattr(event.message_obj, "message_id", "") or "")
if not message_id:
return None
return (
"message_id",
event.get_platform_id() or "",
event.unified_msg_origin or "",
message_id,
)

def is_duplicate(self, event: AstrMessageEvent) -> bool:
    """Report whether ``event`` matches a fingerprint seen within the TTL.

    Expired entries are purged first. The event always yields a content
    fingerprint and, when available, a message-id fingerprint.

    Returns:
        True when any of the event's fingerprints is already recorded (in
        that case nothing new is recorded); otherwise records all of them
        with the current monotonic time and returns False.
    """
    self._clean_expired()
    candidates = [self._build_content_fingerprint(event)]
    id_fingerprint = self._build_message_id_fingerprint(event)
    if id_fingerprint is not None:
        candidates.append(id_fingerprint)

    if any(candidate in self._seen for candidate in candidates):
        return True

    stamp = time.monotonic()
    self._seen.update(candidates)
    self._queue.extend((stamp, candidate) for candidate in candidates)
    return False


class EventBus:
"""用于处理事件的分发和处理"""

Expand All @@ -33,10 +112,39 @@ def __init__(
# abconf uuid -> scheduler
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
self.astrbot_config_mgr = astrbot_config_mgr
dedup_ttl_seconds = self._safe_float(
self.astrbot_config_mgr.g(
None,
"event_bus_dedup_ttl_seconds",
0.5,
),
default=0.5,
)
self._deduplicator = EventDeduplicator(ttl_seconds=dedup_ttl_seconds)

@staticmethod
def _safe_float(value: int | float | str | None, default: float) -> float:
if value is None:
return default
try:
parsed = float(value)
except (TypeError, ValueError):
return default
if not math.isfinite(parsed) or parsed <= 0:
return default
return parsed

async def dispatch(self) -> None:
# event_queue 由单一消费者处理;去重结构不是线程安全的,按设计仅在此循环中使用。
while True:
event: AstrMessageEvent = await self.event_queue.get()
if self._deduplicator.is_duplicate(event):
logger.debug(
"Skip duplicate event in event_bus, umo=%s, sender=%s",
event.unified_msg_origin,
event.get_sender_id(),
)
continue
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
conf_id = conf_info["id"]
conf_name = conf_info.get("name") or conf_id
Expand Down
9 changes: 9 additions & 0 deletions astrbot/core/platform/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ async def load_platform(self, platform_config: dict) -> None:
)
return

# 防御式处理:避免同一平台 ID 被重复加载导致消息重复消费。
if platform_id in self._inst_map:
logger.warning(
"平台 %s(%s) 已存在实例,先终止旧实例再重载。",
platform_config["type"],
platform_id,
)
await self.terminate_platform(platform_id)
Comment on lines +126 to +132
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The load_platform function fails to correctly identify and terminate existing platform instances if the platform ID requires sanitization. When a platform ID contains illegal characters (like : or !), it is sanitized and the updated ID is stored in platform_config["id"]. However, the check for an existing instance on line 126 and the subsequent call to terminate_platform on line 132 use the original, unsanitized platform_id variable.

Since the platform instance was likely stored in self._inst_map using the sanitized ID during a previous load (see line 210), the check on line 126 will fail to find the existing instance. This results in multiple instances of the same platform running simultaneously, leading to duplicate message processing and resource leaks, which directly undermines the purpose of this PR.

            platform_id = platform_config.get("id")
            if not self._is_valid_platform_id(platform_id):
                sanitized_id, changed = self._sanitize_platform_id(platform_id)
                if sanitized_id and changed:
                    logger.warning(
                        "平台 ID %r 包含非法字符 ':' 或 '!',已替换为 %r。",
                        platform_id,
                        sanitized_id,
                    )
                    platform_config["id"] = sanitized_id
                    platform_id = sanitized_id # Update the local variable to the sanitized version
                    self.astrbot_config.save_config()
                else:
                    logger.error(
                        f"平台 ID {platform_id!r} 不能为空,跳过加载该平台适配器。",
                    )
                    return


logger.info(
f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...",
)
Expand Down
Loading