280 lines
7.3 KiB
TypeScript
280 lines
7.3 KiB
TypeScript
import { Injectable, Inject } from '@nestjs/common';
|
||
import { EmbeddingService } from '../../infrastructure/embedding/embedding.service';
|
||
import {
|
||
IKnowledgeRepository,
|
||
KNOWLEDGE_REPOSITORY,
|
||
} from '../../domain/repositories/knowledge.repository.interface';
|
||
import {
|
||
IUserMemoryRepository,
|
||
ISystemExperienceRepository,
|
||
USER_MEMORY_REPOSITORY,
|
||
SYSTEM_EXPERIENCE_REPOSITORY,
|
||
} from '../../domain/repositories/memory.repository.interface';
|
||
import { KnowledgeArticleEntity } from '../../domain/entities/knowledge-article.entity';
|
||
import { KnowledgeChunkEntity } from '../../domain/entities/knowledge-chunk.entity';
|
||
|
||
/**
 * Result of a RAG (retrieval-augmented generation) lookup.
 */
export interface RAGResult {
  /** Retrieved knowledge text, ready to be placed into a prompt. */
  content: string;

  /**
   * Sources the content was drawn from, used for citations.
   * `RAGService.retrieve` emits one entry per matched knowledge chunk.
   */
  sources: {
    articleId: string;
    title: string;
    similarity: number;
  }[];

  /** Memories about the current user that are relevant to the query, if any. */
  userMemories?: string[];

  /** System experiences relevant to the query, if any. */
  systemExperiences?: string[];
}
|
||
|
||
/**
 * RAG service (retrieval-augmented generation).
 *
 * Retrieves relevant content from the knowledge base — plus, optionally, the
 * user's memories and system experiences — to ground AI answers.
 */
@Injectable()
export class RAGService {
  constructor(
    private readonly embeddingService: EmbeddingService,
    @Inject(KNOWLEDGE_REPOSITORY)
    private readonly knowledgeRepo: IKnowledgeRepository,
    @Inject(USER_MEMORY_REPOSITORY)
    private readonly memoryRepo: IUserMemoryRepository,
    @Inject(SYSTEM_EXPERIENCE_REPOSITORY)
    private readonly experienceRepo: ISystemExperienceRepository,
  ) {}

  /**
   * Retrieve knowledge relevant to a query.
   *
   * Embeds the query once, searches knowledge chunks, user memories and
   * system experiences in parallel, then loads the parent articles of the
   * matched chunks so they can be cited.
   *
   * @param params.query - Text to embed and search with.
   * @param params.userId - Whose memories to search; memories are skipped when absent.
   * @param params.category - Optional category filter for the chunk search.
   * @param params.includeMemories - Search user memories (default `true`; also needs `userId`).
   * @param params.includeExperiences - Search active system experiences (default `true`).
   * @param params.topK - Maximum number of knowledge chunks to retrieve (default `5`).
   * @returns Formatted chunk content, one source entry per matched chunk, and
   *   memories/experiences (`undefined` rather than `[]` when none matched).
   */
  async retrieve(params: {
    query: string;
    userId?: string;
    category?: string;
    includeMemories?: boolean;
    includeExperiences?: boolean;
    topK?: number;
  }): Promise<RAGResult> {
    const {
      query,
      userId,
      category,
      includeMemories = true,
      includeExperiences = true,
      topK = 5,
    } = params;

    // 1. Embed the query once; the same vector is reused by every search below.
    const queryEmbedding = await this.embeddingService.getEmbedding(query);

    // 2. The three searches are independent, so run them in parallel.
    const [chunkResults, memoryResults, experienceResults] = await Promise.all([
      // Knowledge chunks
      this.knowledgeRepo.searchChunksByVector(queryEmbedding, {
        category,
        limit: topK,
        minSimilarity: 0.6,
      }),
      // User memories — only when requested and a user is known
      includeMemories && userId
        ? this.memoryRepo.searchByVector(userId, queryEmbedding, {
            limit: 3,
            minSimilarity: 0.7,
          })
        : Promise.resolve([]),
      // Active system experiences
      includeExperiences
        ? this.experienceRepo.searchByVector(queryEmbedding, {
            activeOnly: true,
            limit: 3,
            minSimilarity: 0.75,
          })
        : Promise.resolve([]),
    ]);

    // 3. Load each distinct parent article once (first-appearance order) for citations.
    const articleIds = [...new Set(chunkResults.map(r => r.chunk.articleId))];
    const articles = await Promise.all(
      articleIds.map(id => this.knowledgeRepo.findArticleById(id)),
    );
    // Typed map + runtime guard instead of `filter(Boolean)` with `!` assertions.
    const articleMap = new Map<string, KnowledgeArticleEntity>();
    for (const article of articles) {
      if (article) {
        articleMap.set(article.id, article);
      }
    }

    // 4. Assemble the result.
    const content = this.formatRetrievedContent(
      chunkResults.map(r => r.chunk),
      articleMap,
    );

    // One source per matched chunk; articles that could not be loaded are cited as 'Unknown'.
    const sources = chunkResults.map(r => ({
      articleId: r.chunk.articleId,
      title: articleMap.get(r.chunk.articleId)?.title || 'Unknown',
      similarity: r.similarity,
    }));

    const userMemories = memoryResults.map(r => r.memory.content);
    const systemExperiences = experienceResults.map(r => r.experience.content);

    // 5. Bump citation counts in the background, only for articles that were found.
    this.updateCitationCounts(articleIds.filter(id => articleMap.has(id)));

    return {
      content,
      sources,
      userMemories: userMemories.length > 0 ? userMemories : undefined,
      systemExperiences: systemExperiences.length > 0 ? systemExperiences : undefined,
    };
  }

  /**
   * Retrieve and format the results as a prompt-context string.
   *
   * Emits, in order and only when non-empty: related knowledge, user
   * background, reference experiences, and a numbered source list whose
   * "[n]" labels match the "### [n]" article headings in the knowledge section.
   *
   * @param params - Same options as {@link RAGService.retrieve}.
   * @returns Context text for an LLM prompt; `''` when nothing was found.
   */
  async retrieveForPrompt(
    params: Parameters<RAGService['retrieve']>[0],
  ): Promise<string> {
    const result = await this.retrieve(params);

    let context = '';

    // Knowledge-base content
    if (result.content) {
      context += `## 相关知识\n${result.content}\n\n`;
    }

    // User memories
    if (result.userMemories?.length) {
      context += `## 用户背景信息\n`;
      result.userMemories.forEach((m, i) => {
        context += `${i + 1}. ${m}\n`;
      });
      context += '\n';
    }

    // System experiences
    if (result.systemExperiences?.length) {
      context += `## 参考经验\n`;
      result.systemExperiences.forEach((e, i) => {
        context += `${i + 1}. ${e}\n`;
      });
      context += '\n';
    }

    // Source list. `result.sources` has one entry per chunk, so an article with
    // several matched chunks would otherwise be listed — and numbered — more
    // than once, shifting every later "[n]" away from the article headings.
    // De-duplicate by article in first-appearance order (the same order
    // formatRetrievedContent uses).
    const seenArticleIds = new Set<string>();
    const citedSources = result.sources.filter(s => {
      if (seenArticleIds.has(s.articleId)) {
        return false;
      }
      seenArticleIds.add(s.articleId);
      return true;
    });
    if (citedSources.length > 0) {
      context += `## 来源\n`;
      citedSources.forEach((s, i) => {
        context += `[${i + 1}] ${s.title}\n`;
      });
    }

    return context;
  }

  /**
   * Format retrieved chunks as one "### [n] title" section per article.
   *
   * Articles appear in the order their first chunk appears in `chunks`, and
   * each article's chunks are ordered by `chunkIndex`. `n` is the article's
   * 1-based position among distinct articles. It is advanced even for
   * articles missing from `articleMap` (whose text is omitted), so it stays
   * aligned with the de-duplicated source list built in retrieveForPrompt.
   *
   * @param chunks - Retrieved chunks, in retrieval order.
   * @param articleMap - Loaded articles keyed by id.
   * @returns The formatted text, or `''` when there are no chunks.
   */
  private formatRetrievedContent(
    chunks: KnowledgeChunkEntity[],
    articleMap: Map<string, KnowledgeArticleEntity>,
  ): string {
    if (chunks.length === 0) {
      return '';
    }

    // Group by article; Map insertion order preserves first appearance.
    const groupedByArticle = new Map<string, KnowledgeChunkEntity[]>();
    for (const chunk of chunks) {
      const group = groupedByArticle.get(chunk.articleId);
      if (group) {
        group.push(chunk);
      } else {
        groupedByArticle.set(chunk.articleId, [chunk]);
      }
    }

    let content = '';
    let articleIndex = 0;

    for (const [articleId, articleChunks] of groupedByArticle) {
      // Count every article, including skipped ones, to keep "[n]" aligned with the sources.
      articleIndex++;

      const article = articleMap.get(articleId);
      if (!article) {
        // Article was not returned by findArticleById; omit its text.
        continue;
      }

      content += `### [${articleIndex}] ${article.title}\n`;

      // Restore the chunks' original reading order within the article.
      articleChunks.sort((a, b) => a.chunkIndex - b.chunkIndex);
      for (const chunk of articleChunks) {
        content += `${chunk.content}\n\n`;
      }
    }

    return content;
  }

  /**
   * Schedule a background increment of each article's citation count.
   *
   * Fire-and-forget by design: returns immediately and defers the work with
   * `setImmediate` so retrieval is never delayed. Failures are logged by
   * incrementCitationCounts and never reach the caller.
   */
  private updateCitationCounts(articleIds: string[]): void {
    if (articleIds.length === 0) {
      return;
    }
    setImmediate(() => {
      // incrementCitationCounts catches per article, so discarding its promise
      // cannot produce an unhandled rejection.
      void this.incrementCitationCounts(articleIds);
    });
  }

  /**
   * Increment citation counts one article at a time.
   *
   * Each article is re-read before its update. A failure is logged and does
   * not stop the remaining updates.
   */
  private async incrementCitationCounts(articleIds: string[]): Promise<void> {
    for (const id of articleIds) {
      try {
        // NOTE(review): read-modify-write is not atomic; concurrent retrievals
        // of the same article may lose increments — confirm whether that matters.
        const article = await this.knowledgeRepo.findArticleById(id);
        if (article) {
          article.incrementCitation();
          await this.knowledgeRepo.updateArticle(article);
        }
      } catch (error: unknown) {
        console.error(`Failed to update citation count for article ${id}:`, error);
      }
    }
  }

  /**
   * Check whether a query is off-topic for the knowledge base.
   *
   * Uses the similarity of the single best-matching knowledge chunk:
   * - no chunk at or above 0.3 → off-topic, confidence 0.8;
   * - best similarity below 0.5 → off-topic, confidence `0.9 - similarity`;
   * - otherwise → on-topic, confidence equal to the best similarity.
   *
   * @param query - User question to classify.
   */
  async checkOffTopic(query: string): Promise<{
    isOffTopic: boolean;
    confidence: number;
    reason?: string;
  }> {
    const queryEmbedding = await this.embeddingService.getEmbedding(query);

    // A deliberately low threshold: we only want to know how close the best match is.
    const [bestMatch] = await this.knowledgeRepo.searchChunksByVector(queryEmbedding, {
      limit: 1,
      minSimilarity: 0.3,
    });

    if (!bestMatch) {
      return {
        isOffTopic: true,
        confidence: 0.8,
        reason: '问题与香港移民主题无关',
      };
    }

    const maxSimilarity = bestMatch.similarity;

    if (maxSimilarity < 0.5) {
      return {
        isOffTopic: true,
        confidence: 0.9 - maxSimilarity,
        reason: '问题与香港移民主题相关性较低',
      };
    }

    return {
      isOffTopic: false,
      confidence: maxSimilarity,
    };
  }
}
|