import { Injectable, Inject } from '@nestjs/common';
import { EmbeddingService } from '../../infrastructure/embedding/embedding.service';
import {
  IKnowledgeRepository,
  KNOWLEDGE_REPOSITORY,
} from '../../domain/repositories/knowledge.repository.interface';
import {
  IUserMemoryRepository,
  ISystemExperienceRepository,
  USER_MEMORY_REPOSITORY,
  SYSTEM_EXPERIENCE_REPOSITORY,
} from '../../domain/repositories/memory.repository.interface';
import { KnowledgeArticleEntity } from '../../domain/entities/knowledge-article.entity';
import { KnowledgeChunkEntity } from '../../domain/entities/knowledge-chunk.entity';

/**
 * Result of a RAG retrieval.
 */
export interface RAGResult {
  /** Retrieved knowledge text, grouped under one numbered heading per article. */
  content: string;
  /** One entry per retrieved chunk, used for citations. */
  sources: Array<{
    articleId: string;
    title: string;
    similarity: number;
  }>;
  /** Contents of the user memories that matched; omitted when none matched. */
  userMemories?: string[];
  /** Contents of the system experiences that matched; omitted when none matched. */
  systemExperiences?: string[];
}

/**
 * RAG (retrieval-augmented generation) service.
 *
 * Embeds a query, searches the knowledge, user-memory and system-experience
 * repositories by vector, and assembles the hits into a result or a
 * prompt-ready context string.
 */
@Injectable()
export class RAGService {
  constructor(
    private readonly embeddingService: EmbeddingService,
    @Inject(KNOWLEDGE_REPOSITORY)
    private readonly knowledgeRepo: IKnowledgeRepository,
    @Inject(USER_MEMORY_REPOSITORY)
    private readonly memoryRepo: IUserMemoryRepository,
    @Inject(SYSTEM_EXPERIENCE_REPOSITORY)
    private readonly experienceRepo: ISystemExperienceRepository,
  ) {}

  /**
   * Retrieves knowledge chunks (and optionally user memories and system
   * experiences) relevant to a query.
   *
   * @param params.query Text to embed and search with.
   * @param params.userId User whose memories are searched; memories are
   *   skipped when this is missing.
   * @param params.category Passed through to the chunk search as a filter.
   * @param params.includeMemories Search user memories (default `true`).
   * @param params.includeExperiences Search system experiences (default `true`).
   * @param params.topK Maximum number of knowledge chunks (default `5`).
   * @returns The formatted content, per-chunk sources, and any matched
   *   memories / experiences.
   */
  async retrieve(params: {
    query: string;
    userId?: string;
    category?: string;
    includeMemories?: boolean;
    includeExperiences?: boolean;
    topK?: number;
  }): Promise<RAGResult> {
    const {
      query,
      userId,
      category,
      includeMemories = true,
      includeExperiences = true,
      topK = 5,
    } = params;

    // 1. Embed the query once; every search below reuses the same vector.
    const queryEmbedding = await this.embeddingService.getEmbedding(query);

    // 2. Run the three searches in parallel.
    const [chunkResults, memoryResults, experienceResults] = await Promise.all([
      // Knowledge chunks.
      this.knowledgeRepo.searchChunksByVector(queryEmbedding, {
        category,
        limit: topK,
        minSimilarity: 0.6,
      }),
      // User memories: only when requested and a user is known.
      includeMemories && userId
        ? this.memoryRepo.searchByVector(userId, queryEmbedding, {
            limit: 3,
            minSimilarity: 0.7,
          })
        : Promise.resolve([]),
      // System experiences: active ones only.
      includeExperiences
        ? this.experienceRepo.searchByVector(queryEmbedding, {
            activeOnly: true,
            limit: 3,
            minSimilarity: 0.75,
          })
        : Promise.resolve([]),
    ]);

    // 3. Load the articles behind the chunks (needed for titles).
    const articleIds = [...new Set(chunkResults.map(r => r.chunk.articleId))];
    const articles = await Promise.all(
      articleIds.map(id => this.knowledgeRepo.findArticleById(id)),
    );
    const articleMap = new Map<string, KnowledgeArticleEntity>();
    for (const article of articles) {
      // Lookups that found nothing are simply left out of the map.
      if (article) {
        articleMap.set(article.id, article);
      }
    }

    // 4. Assemble the result.
    const content = this.formatRetrievedContent(
      chunkResults.map(r => r.chunk),
      articleMap,
    );

    const sources = chunkResults.map(r => {
      const article = articleMap.get(r.chunk.articleId);
      return {
        articleId: r.chunk.articleId,
        title: article?.title || 'Unknown',
        similarity: r.similarity,
      };
    });

    const userMemories = memoryResults.map(r => r.memory.content);
    const systemExperiences = experienceResults.map(r => r.experience.content);

    // 5. Bump citation counters. Deliberately fire-and-forget: the update is
    //    scheduled in the background and must not delay or fail the retrieval.
    void this.updateCitationCounts(articleIds);

    return {
      content,
      sources,
      userMemories: userMemories.length > 0 ? userMemories : undefined,
      systemExperiences: systemExperiences.length > 0 ? systemExperiences : undefined,
    };
  }

  /**
   * Retrieves context for a query and renders it as a Markdown-style string
   * with up to four sections: knowledge, user background, reference
   * experiences, and the source list.
   *
   * The source list contains each article once, in order of first appearance,
   * so that `[n]` there refers to the same article as the `[n]` heading in the
   * knowledge section.
   *
   * @returns The rendered context; an empty string when nothing was retrieved.
   */
  async retrieveForPrompt(params: {
    query: string;
    userId?: string;
    category?: string;
  }): Promise<string> {
    const result = await this.retrieve(params);

    let context = '';

    // Knowledge base content.
    if (result.content) {
      context += `## 相关知识\n${result.content}\n\n`;
    }

    // User memories.
    if (result.userMemories?.length) {
      context += `## 用户背景信息\n`;
      result.userMemories.forEach((m, i) => {
        context += `${i + 1}. ${m}\n`;
      });
      context += '\n';
    }

    // System experiences.
    if (result.systemExperiences?.length) {
      context += `## 参考经验\n`;
      result.systemExperiences.forEach((e, i) => {
        context += `${i + 1}. ${e}\n`;
      });
      context += '\n';
    }

    // Source list. `result.sources` has one entry per chunk; collapse it to
    // one entry per article so the numbering matches the content headings.
    const seenArticleIds = new Set<string>();
    const uniqueSources = result.sources.filter(s => {
      if (seenArticleIds.has(s.articleId)) {
        return false;
      }
      seenArticleIds.add(s.articleId);
      return true;
    });

    if (uniqueSources.length > 0) {
      context += `## 来源\n`;
      uniqueSources.forEach((s, i) => {
        context += `[${i + 1}] ${s.title}\n`;
      });
    }

    return context;
  }

  /**
   * Renders retrieved chunks as text grouped by article.
   *
   * Each article gets a `### [n] title` heading followed by its chunks in
   * ascending `chunkIndex` order. `n` is the 1-based position of the article
   * among the distinct articles in `chunks` (first-appearance order). Articles
   * missing from `articleMap` are not rendered but still consume their number,
   * which keeps the numbering aligned with a first-appearance source list.
   *
   * @param chunks Retrieved chunks, in retrieval order.
   * @param articleMap Loaded articles keyed by id.
   * @returns The rendered text, or an empty string when `chunks` is empty.
   */
  private formatRetrievedContent(
    chunks: KnowledgeChunkEntity[],
    articleMap: Map<string, KnowledgeArticleEntity>,
  ): string {
    if (chunks.length === 0) {
      return '';
    }

    // Group by article; Map iteration preserves first-appearance order.
    const groupedByArticle = new Map<string, KnowledgeChunkEntity[]>();
    for (const chunk of chunks) {
      const group = groupedByArticle.get(chunk.articleId);
      if (group) {
        group.push(chunk);
      } else {
        groupedByArticle.set(chunk.articleId, [chunk]);
      }
    }

    let content = '';
    let articleIndex = 0;

    for (const [articleId, articleChunks] of groupedByArticle) {
      // Advance before the lookup so a missing article keeps its slot.
      articleIndex++;

      const article = articleMap.get(articleId);
      if (!article) {
        continue;
      }

      content += `### [${articleIndex}] ${article.title}\n`;

      // The group array is local to this call, so sorting in place is safe.
      articleChunks.sort((a, b) => a.chunkIndex - b.chunkIndex);
      for (const chunk of articleChunks) {
        content += `${chunk.content}\n\n`;
      }
    }

    return content;
  }

  /**
   * Schedules a background increment of the citation counter of each article.
   *
   * The work is deferred with `setImmediate`, so the returned promise resolves
   * right away and does not track the updates. Each article is re-read,
   * incremented and saved in turn; a failure is logged and does not stop the
   * remaining articles.
   *
   * @param articleIds Ids of the articles to update.
   */
  private async updateCitationCounts(articleIds: string[]): Promise<void> {
    // Deferred so the retrieval path is never blocked by these writes.
    setImmediate(async () => {
      for (const id of articleIds) {
        try {
          const article = await this.knowledgeRepo.findArticleById(id);
          if (article) {
            article.incrementCitation();
            await this.knowledgeRepo.updateArticle(article);
          }
        } catch (error) {
          console.error(`Failed to update citation count for article ${id}:`, error);
        }
      }
    });
  }

  /**
   * Estimates whether a query is off-topic by looking at its single best
   * match among the knowledge chunks.
   *
   * - No chunk at similarity >= 0.3: off-topic, confidence 0.8.
   * - Best similarity below 0.5: off-topic, confidence `0.9 - similarity`.
   * - Otherwise: on-topic, confidence equal to the similarity.
   *
   * @param query Text to check.
   * @returns The verdict, a confidence score, and a reason when off-topic.
   */
  async checkOffTopic(query: string): Promise<{
    isOffTopic: boolean;
    confidence: number;
    reason?: string;
  }> {
    // Vector similarity against the knowledge base is the relevance signal.
    const queryEmbedding = await this.embeddingService.getEmbedding(query);

    const results = await this.knowledgeRepo.searchChunksByVector(queryEmbedding, {
      limit: 1,
      minSimilarity: 0.3, // Deliberately low threshold: weak matches still count.
    });

    const [bestMatch] = results;
    if (!bestMatch) {
      return {
        isOffTopic: true,
        confidence: 0.8,
        reason: '问题与香港移民主题无关',
      };
    }

    const maxSimilarity = bestMatch.similarity;

    if (maxSimilarity < 0.5) {
      return {
        isOffTopic: true,
        confidence: 0.9 - maxSimilarity,
        reason: '问题与香港移民主题相关性较低',
      };
    }

    return {
      isOffTopic: false,
      confidence: maxSimilarity,
    };
  }
}