iconsulting/packages/services/knowledge-service/src/infrastructure/embedding/embedding.service.ts

167 lines
4.4 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { Injectable, OnModuleInit } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import OpenAI from 'openai';
/**
 * Embedding service backed by OpenAI's `text-embedding-3-small` model.
 *
 * Converts text into vectors to support semantic search. When no
 * `OPENAI_API_KEY` is configured, or when an API call fails, the service
 * falls back to deterministic mock embeddings derived from a hash of the text.
 */
@Injectable()
export class EmbeddingService implements OnModuleInit {
  /** Matches base URLs whose host is a dotted-quad IP, with or without an http(s) scheme. */
  private static readonly IP_BASE_URL_PATTERN = /^(?:https?:\/\/)?\d+\.\d+\.\d+\.\d+/;

  /** OpenAI client; remains undefined when no API key is configured (mock mode). */
  private openai?: OpenAI;
  private readonly modelName = 'text-embedding-3-small';
  private readonly dimensions = 1536; // default dimension of text-embedding-3-small

  constructor(private configService: ConfigService) {}

  /**
   * Creates the OpenAI client from `OPENAI_API_KEY` and the optional
   * `OPENAI_BASE_URL`. Without an API key no client is created and every
   * embedding request is served by the mock generator.
   */
  onModuleInit() {
    const apiKey = this.configService.get<string>('OPENAI_API_KEY');
    if (!apiKey) {
      console.warn('[EmbeddingService] OPENAI_API_KEY not set, using mock embeddings');
      return;
    }

    const baseURL = this.configService.get<string>('OPENAI_BASE_URL');
    const isProxyUrl = !!baseURL && EmbeddingService.IP_BASE_URL_PATTERN.test(baseURL);

    // If using IP-based proxy, disable TLS certificate verification.
    // NOTE(review): this environment variable is process-wide, so certificate
    // verification is turned off for every outgoing TLS connection made by
    // this process, not only for calls to the proxy. Kept for backward
    // compatibility; consider scoping the relaxation to the OpenAI client.
    if (isProxyUrl) {
      console.log(`[EmbeddingService] Using OpenAI proxy (TLS verification disabled): ${baseURL}`);
      process.env.NODE_TLS_REJECT_UNAUTHORIZED = '0';
    }

    this.openai = new OpenAI({
      apiKey,
      baseURL: baseURL || undefined,
    });

    if (baseURL && !isProxyUrl) {
      console.log(`[EmbeddingService] Using OpenAI API base URL: ${baseURL}`);
    } else if (!baseURL) {
      console.log('[EmbeddingService] Initialized with OpenAI embedding model');
    }
  }

  /**
   * Returns the embedding vector for a single text.
   *
   * @param text Raw input text; it is whitespace-normalized and truncated before being sent.
   * @returns The embedding from the API, or a mock embedding when the client
   *          is not configured or the request fails.
   */
  async getEmbedding(text: string): Promise<number[]> {
    const client = this.openai;
    if (!client) {
      return this.getMockEmbedding(text);
    }

    try {
      const response = await client.embeddings.create({
        model: this.modelName,
        input: this.preprocessText(text),
      });
      const first = response.data[0];
      if (!first) {
        throw new Error('Embedding response contained no data');
      }
      return first.embedding;
    } catch (error) {
      console.error('[EmbeddingService] Failed to get embedding:', error);
      // Degrade to a mock embedding instead of failing the caller.
      return this.getMockEmbedding(text);
    }
  }

  /**
   * Returns embedding vectors for a batch of texts, in the same order as the input.
   *
   * @param texts Raw input texts; an empty array yields an empty result without calling the API.
   * @returns One vector per input text; mock embeddings are used for the whole
   *          batch when the client is not configured or the request fails.
   */
  async getEmbeddings(texts: string[]): Promise<number[][]> {
    // Nothing to embed: skip the API round-trip (and the error it would log).
    if (texts.length === 0) {
      return [];
    }

    const client = this.openai;
    if (!client) {
      return texts.map(text => this.getMockEmbedding(text));
    }

    try {
      const processedTexts = texts.map(text => this.preprocessText(text));
      const response = await client.embeddings.create({
        model: this.modelName,
        input: processedTexts,
      });
      // Restore the original input order; copy first so the response object is not mutated.
      return [...response.data]
        .sort((a, b) => a.index - b.index)
        .map(item => item.embedding);
    } catch (error) {
      console.error('[EmbeddingService] Failed to get batch embeddings:', error);
      return texts.map(text => this.getMockEmbedding(text));
    }
  }

  /**
   * Computes the cosine similarity of two vectors.
   *
   * @param a First vector.
   * @param b Second vector; must have the same length as `a`.
   * @returns The cosine similarity, or 0 when either vector has zero magnitude.
   * @throws Error when the vectors differ in length.
   */
  cosineSimilarity(a: number[], b: number[]): number {
    if (a.length !== b.length) {
      throw new Error('Vectors must have the same length');
    }

    let dotProduct = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }

    // A zero vector has no direction; report "no similarity" instead of dividing by zero.
    if (normA === 0 || normB === 0) {
      return 0;
    }
    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  /**
   * Returns the dimensionality of the vectors produced by this service.
   */
  getDimensions(): number {
    return this.dimensions;
  }

  /**
   * Normalizes text before it is sent to the API: collapses whitespace runs
   * into a single space, trims both ends and caps the length at 8000 characters.
   *
   * NOTE(review): the cap counts characters, while the API limit is presumably
   * token-based, so this is only an approximation — confirm for dense scripts.
   */
  private preprocessText(text: string): string {
    return text
      .replace(/\s+/g, ' ') // collapse consecutive whitespace
      .trim()
      .substring(0, 8000); // length cap for the OpenAI input
  }

  /**
   * Generates a mock embedding for development and testing.
   * The vector is pseudo-random but deterministic: it is seeded by a hash of
   * the text, so the same text always yields the same unit-length vector.
   */
  private getMockEmbedding(text: string): number[] {
    // Simple string hash; `hash & hash` truncates to a 32-bit signed integer each step.
    let hash = 0;
    for (let i = 0; i < text.length; i++) {
      const char = text.charCodeAt(i);
      hash = ((hash << 5) - hash) + char;
      hash = hash & hash;
    }

    // Sine-based pseudo-random generator: fractional part of a scaled sine, in [0, 1).
    const random = (seed: number) => {
      const x = Math.sin(seed) * 10000;
      return x - Math.floor(x);
    };

    // Map each pseudo-random value to the range [-1, 1).
    const embedding = Array.from(
      { length: this.dimensions },
      (_, i) => random(hash + i) * 2 - 1,
    );

    // Normalize to unit length; guard against a zero norm to avoid NaN components.
    const norm = Math.sqrt(embedding.reduce((sum, v) => sum + v * v, 0));
    if (norm === 0) {
      return embedding;
    }
    return embedding.map(v => v / norm);
  }
}