async def get_documents_with_metadata_batch(
    self, doc_ids: set[str], *, chunk_size: int = 900
) -> dict[str, dict]:
    """Fetch several documents together with their knowledge-base metadata.

    The lookup is split into batches so that a single ``IN (...)`` clause
    never binds more than ``chunk_size`` parameters.

    Args:
        doc_ids: Set of document IDs to look up.
        chunk_size: Maximum number of IDs bound into one query. The default
            of 900 stays below the 999 bound-parameter limit that the
            original implementation was written against for SQLite.

    Returns:
        dict: doc_id -> {"document": KBDocument, "knowledge_base": KnowledgeBase}.
        IDs for which the joined query yields no row are simply absent
        from the mapping.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.

    """
    # A zero step makes range() raise an unrelated error and a negative step
    # makes it silently empty, so reject bad sizes before doing any work.
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    if not doc_ids:
        return {}

    # Sets cannot be sliced; materialize once so batches can be cut by index.
    pending_ids = list(doc_ids)
    metadata_map: dict[str, dict] = {}

    # One session is reused for every batch.
    async with self.get_db() as session:
        for start in range(0, len(pending_ids), chunk_size):
            batch = pending_ids[start : start + chunk_size]
            stmt = (
                select(KBDocument, KnowledgeBase)
                .join(
                    KnowledgeBase,
                    col(KBDocument.kb_id) == col(KnowledgeBase.kb_id),
                )
                .where(col(KBDocument.doc_id).in_(batch))
            )
            result = await session.execute(stmt)
            # Each row is a (KBDocument, KnowledgeBase) pair.
            for document, knowledge_base in result.all():
                metadata_map[document.doc_id] = {
                    "document": document,
                    "knowledge_base": knowledge_base,
                }

    return metadata_map
转换为 RetrievalResult (批量获取元数据) + doc_ids = {fr.doc_id for fr in fused_results} + metadata_map = await self.kb_db.get_documents_with_metadata_batch(doc_ids) + retrieval_results = [] for fr in fused_results: - metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id) + metadata_dict = metadata_map.get(fr.doc_id) if metadata_dict: retrieval_results.append( RetrievalResult(