167 lines
4.4 KiB
TypeScript
167 lines
4.4 KiB
TypeScript
import { Injectable, OnModuleInit } from '@nestjs/common';
|
||
import { ConfigService } from '@nestjs/config';
|
||
import OpenAI from 'openai';
|
||
|
||
/**
 * Vector embedding service backed by OpenAI's `text-embedding-3-small` model.
 *
 * Converts text into embedding vectors to support semantic search. When no
 * `OPENAI_API_KEY` is configured, or when an API call fails, it falls back to
 * deterministic mock vectors (see `getMockEmbedding`), so callers always get a
 * vector of `getDimensions()` length.
 *
 * NOTE(review): mock vectors carry no semantic meaning; if they are persisted
 * alongside real embeddings, similarity scores for those rows are noise.
 */
@Injectable()
export class EmbeddingService implements OnModuleInit {
  /** Created in `onModuleInit`; stays undefined when no API key is configured. */
  private openai?: OpenAI;

  private readonly modelName = 'text-embedding-3-small';

  /** Default output dimension of text-embedding-3-small (also used for mock vectors). */
  private readonly dimensions = 1536;

  /**
   * Character cap applied before sending text to the API. The model's actual
   * input limit is expressed in tokens (8191), so this is only an approximation:
   * token-dense text (e.g. CJK) may still exceed the limit at this length, in
   * which case the request fails and the mock fallback is used.
   */
  private readonly maxInputChars = 8000;

  constructor(private configService: ConfigService) {}

  /**
   * Creates the OpenAI client from `OPENAI_API_KEY` and the optional `OPENAI_BASE_URL`.
   * Without an API key the client is left unset and every call returns mock vectors.
   */
  onModuleInit(): void {
    const apiKey = this.configService.get<string>('OPENAI_API_KEY');
    if (!apiKey) {
      console.warn('[EmbeddingService] OPENAI_API_KEY not set, using mock embeddings');
      return;
    }

    // An empty string is treated the same as "not configured".
    const baseURL = this.configService.get<string>('OPENAI_BASE_URL') || undefined;

    // A TLS endpoint addressed by raw IPv4 (explicit https:// or no scheme) is assumed
    // to be a proxy whose certificate will not validate. Plain http:// endpoints use
    // no TLS, so verification is not weakened for them.
    const isInsecureTlsProxy =
      baseURL !== undefined && /^(?:https:\/\/)?\d+\.\d+\.\d+\.\d+/.test(baseURL);

    if (isInsecureTlsProxy) {
      // SECURITY: NODE_TLS_REJECT_UNAUTHORIZED is process-wide — it disables certificate
      // verification for every TLS client in this Node process that relies on Node's
      // default, not just OpenAI, and nothing here restores it. Prefer installing a
      // valid certificate on the proxy or a client-scoped agent; kept for compatibility.
      console.warn(`[EmbeddingService] Using OpenAI proxy (TLS verification disabled): ${baseURL}`);
      process.env.NODE_TLS_REJECT_UNAUTHORIZED = '0';
    }

    this.openai = new OpenAI({
      apiKey,
      baseURL,
    });

    if (baseURL && !isInsecureTlsProxy) {
      console.log(`[EmbeddingService] Using OpenAI API base URL: ${baseURL}`);
    } else if (!baseURL) {
      console.log('[EmbeddingService] Initialized with OpenAI embedding model');
    }
  }

  /**
   * Returns the embedding vector for a single text.
   *
   * Falls back to a deterministic mock vector when no client is configured, when
   * the text is empty after preprocessing (the API rejects empty input), or when
   * the API call fails.
   *
   * @param text - Raw text to embed; whitespace is normalized before sending.
   * @returns A vector of length `getDimensions()` (real or mock).
   */
  async getEmbedding(text: string): Promise<number[]> {
    const input = this.preprocessText(text);
    if (!this.openai || input.length === 0) {
      return this.getMockEmbedding(text);
    }

    try {
      const response = await this.openai.embeddings.create({
        model: this.modelName,
        input,
      });

      const first = response.data[0];
      if (!first) {
        throw new Error('Embedding response contained no data');
      }
      return first.embedding;
    } catch (error) {
      console.error('[EmbeddingService] Failed to get embedding:', error);
      // Degrade to a mock embedding rather than failing the caller.
      return this.getMockEmbedding(text);
    }
  }

  /**
   * Returns embedding vectors for many texts in one API request.
   *
   * The result always has the same length and order as `texts`. Inputs that are
   * empty after preprocessing are not sent (the API would reject the whole batch)
   * and receive mock vectors instead; if the request fails, or a result is missing
   * from the response, the affected entries fall back to mock vectors too.
   *
   * @param texts - Raw texts to embed.
   * @returns One vector per input, in input order.
   */
  async getEmbeddings(texts: string[]): Promise<number[][]> {
    if (texts.length === 0) {
      return [];
    }
    if (!this.openai) {
      return texts.map(text => this.getMockEmbedding(text));
    }

    // Remember each sendable input's original position so results can be placed back.
    const pending = texts
      .map((text, index) => ({ input: this.preprocessText(text), index }))
      .filter(item => item.input.length > 0);

    const results: Array<number[] | undefined> = texts.map(() => undefined);

    if (pending.length > 0) {
      try {
        const response = await this.openai.embeddings.create({
          model: this.modelName,
          input: pending.map(item => item.input),
        });

        // `item.index` is the position within the submitted `pending` array, and
        // the response order is not relied upon.
        for (const item of response.data) {
          const source = pending[item.index];
          if (source) {
            results[source.index] = item.embedding;
          }
        }
      } catch (error) {
        console.error('[EmbeddingService] Failed to get batch embeddings:', error);
      }
    }

    return texts.map((text, i) => results[i] ?? this.getMockEmbedding(text));
  }

  /**
   * Cosine similarity of two equal-length vectors.
   *
   * @returns A value in [-1, 1]; 0 when either vector has zero magnitude.
   * @throws Error when the vectors differ in length.
   */
  cosineSimilarity(a: number[], b: number[]): number {
    if (a.length !== b.length) {
      throw new Error('Vectors must have the same length');
    }

    let dotProduct = 0;
    let normA = 0;
    let normB = 0;

    for (let i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }

    if (normA === 0 || normB === 0) {
      return 0;
    }

    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  /** Length of every vector this service returns (real or mock). */
  getDimensions(): number {
    return this.dimensions;
  }

  /**
   * Normalizes text before sending it to the API: collapses whitespace runs into a
   * single space, trims, and truncates to `maxInputChars` UTF-16 code units without
   * splitting a surrogate pair (a lone surrogate is not valid Unicode text).
   */
  private preprocessText(text: string): string {
    const normalized = text.replace(/\s+/g, ' ').trim();
    if (normalized.length <= this.maxInputChars) {
      return normalized;
    }

    let end = this.maxInputChars;
    const lastCode = normalized.charCodeAt(end - 1);
    // If the cut would land right after a high surrogate, drop it too.
    if (lastCode >= 0xd800 && lastCode <= 0xdbff) {
      end -= 1;
    }
    return normalized.slice(0, end);
  }

  /**
   * Generates a deterministic, unit-length pseudo-random vector from the text
   * (for development/testing and as a failure fallback).
   *
   * The algorithm must stay stable: identical text must always map to the
   * identical vector, or previously stored mock vectors stop matching.
   */
  private getMockEmbedding(text: string): number[] {
    const embedding: number[] = [];
    let hash = 0;

    // Simple string hash (Java-style `hash * 31 + char`).
    for (let i = 0; i < text.length; i++) {
      const char = text.charCodeAt(i);
      hash = ((hash << 5) - hash) + char;
      hash = hash & hash; // coerce to a 32-bit integer
    }

    // Sine-based pseudo-random generator seeded from the hash; returns [0, 1).
    const random = (seed: number) => {
      const x = Math.sin(seed) * 10000;
      return x - Math.floor(x);
    };

    for (let i = 0; i < this.dimensions; i++) {
      const value = random(hash + i) * 2 - 1; // map to [-1, 1)
      embedding.push(value);
    }

    // Normalize to unit length so cosine similarity behaves like real embeddings.
    const norm = Math.sqrt(embedding.reduce((sum, v) => sum + v * v, 0));
    return embedding.map(v => v / norm);
  }
}
|