304 lines
8.4 KiB
TypeScript
304 lines
8.4 KiB
TypeScript
import { Injectable } from '@nestjs/common';
|
|
import { InjectRepository } from '@nestjs/typeorm';
|
|
import { Repository, Between, MoreThanOrEqual } from 'typeorm';
|
|
import { TokenUsageORM } from '../database/postgres/entities/token-usage.orm';
|
|
|
|
/**
 * Claude API pricing (the original note dates these figures to 2024 —
 * verify against the current price list before relying on them).
 *
 * claude-sonnet-4-20250514:
 *   - Input:       $3 / MTok
 *   - Output:      $15 / MTok
 *   - Cache write: $3.75 / MTok
 *   - Cache read:  $0.30 / MTok
 *
 * A `PricingTier` holds those four rates. The pricing table in this file
 * stores each rate as USD per single token (the per-MTok price divided by
 * 1,000,000).
 */
interface PricingTier {
  /** Rate applied to regular (non-cached) input tokens. */
  input: number;
  /** Rate applied to output tokens. */
  output: number;
  /** Rate applied to tokens written to the prompt cache. */
  cacheWrite: number;
  /** Rate applied to tokens read from the prompt cache. */
  cacheRead: number;
}
|
|
|
|
/** Number of tokens in one "MTok", the unit the published prices are quoted in. */
const TOKENS_PER_MTOK = 1_000_000;

/** Per-token USD rates for claude-sonnet-4-20250514. */
const SONNET_4_PRICING: PricingTier = {
  input: 3 / TOKENS_PER_MTOK,
  output: 15 / TOKENS_PER_MTOK,
  cacheWrite: 3.75 / TOKENS_PER_MTOK,
  cacheRead: 0.30 / TOKENS_PER_MTOK,
};

/**
 * Per-token pricing keyed by model id.
 *
 * The `default` entry is the fallback for model ids that have no dedicated
 * entry; it currently carries the same rates as claude-sonnet-4-20250514.
 */
const PRICING: Record<string, PricingTier> = {
  'claude-sonnet-4-20250514': SONNET_4_PRICING,
  // Separate copy so the two entries remain distinct objects.
  default: { ...SONNET_4_PRICING },
};
|
|
|
|
/**
 * Payload describing a single API call whose token usage should be recorded.
 *
 * Optional fields may be omitted; the notes below describe how the service's
 * `recordUsage` method fills them in before persisting.
 */
export interface TokenUsageInput {
  /** User the call belongs to; persisted as `null` when omitted or empty. */
  userId?: string;
  /** Conversation the call belongs to. */
  conversationId: string;
  /** Message the call produced; persisted as `null` when omitted or empty. */
  messageId?: string;
  /** Model id; also selects the pricing entry used for the cost estimate. */
  model: string;
  /** Input token count reported for the call. */
  inputTokens: number;
  /** Output token count reported for the call. */
  outputTokens: number;
  /** Tokens written to the prompt cache; treated as 0 when omitted. */
  cacheCreationTokens?: number;
  /** Tokens read from the prompt cache; treated as 0 when omitted. */
  cacheReadTokens?: number;
  /** Intent label for the request; persisted as `null` when omitted or empty. */
  intentType?: string;
  /** Number of tool calls made; persisted as 0 when omitted. */
  toolCalls?: number;
  /** Length of the response (unit not specified in this file); persisted as 0 when omitted. */
  responseLength?: number;
  /** Latency of the call in milliseconds; persisted as 0 when omitted. */
  latencyMs?: number;
}
|
|
|
|
/**
 * Aggregate usage figures computed over a set of persisted usage records.
 * Every field is 0 when the set is empty.
 */
export interface UsageStats {
  /** Number of records aggregated. */
  totalRequests: number;
  /** Sum of the records' input token counts. */
  totalInputTokens: number;
  /** Sum of the records' output token counts. */
  totalOutputTokens: number;
  /** Sum of the records' cache-read token counts. */
  totalCacheReadTokens: number;
  /** Sum of the records' cache-creation token counts. */
  totalCacheCreationTokens: number;
  /** Sum of the records' stored total token counts. */
  totalTokens: number;
  /** Sum of the records' estimated costs. */
  totalCost: number;
  /** Mean input tokens per record, rounded to the nearest integer. */
  avgInputTokens: number;
  /** Mean output tokens per record, rounded to the nearest integer. */
  avgOutputTokens: number;
  /** Mean latency per record in milliseconds, rounded to the nearest integer. */
  avgLatencyMs: number;
  /**
   * Cache-read tokens divided by input tokens, expressed as a percentage and
   * rounded to two decimal places; 0 when there are no input tokens.
   */
  cacheHitRate: number;
}
|
|
|
|
/**
 * Persists per-request token usage together with an estimated cost, and
 * derives usage statistics (per user, per conversation, global, per day,
 * top users) from the persisted records.
 */
@Injectable()
export class TokenUsageService {
  constructor(
    @InjectRepository(TokenUsageORM)
    private readonly tokenUsageRepository: Repository<TokenUsageORM>,
  ) {}

  /**
   * Estimates the cost of one API call from its token counts.
   *
   * @param model Model id used to select the pricing entry; ids without an
   *   entry of their own use `PRICING.default`.
   * @param inputTokens Input token count for the call.
   * @param outputTokens Output token count for the call.
   * @param cacheCreationTokens Tokens written to the prompt cache.
   * @param cacheReadTokens Tokens read from the prompt cache.
   * @returns The estimated cost, in the currency unit of the pricing table.
   */
  private calculateCost(
    model: string,
    inputTokens: number,
    outputTokens: number,
    cacheCreationTokens: number,
    cacheReadTokens: number,
  ): number {
    // Own-property lookup only: a model string such as "constructor" or
    // "toString" must not resolve to an inherited Object.prototype member,
    // which would be truthy and turn the whole estimate into NaN.
    const pricing = Object.prototype.hasOwnProperty.call(PRICING, model)
      ? PRICING[model]
      : PRICING.default;

    // Cache-read tokens are charged at the cache-read rate, so they are taken
    // out of the regularly priced input. Clamp at zero so that an input count
    // smaller than the cache-read count cannot yield a negative charge.
    // NOTE(review): this assumes inputTokens includes cache reads. If callers
    // pass an input count that already excludes cached tokens, the subtraction
    // under-bills — confirm the convention with the caller.
    const regularInputTokens = Math.max(0, inputTokens - cacheReadTokens);

    return (
      regularInputTokens * pricing.input +
      outputTokens * pricing.output +
      cacheCreationTokens * pricing.cacheWrite +
      cacheReadTokens * pricing.cacheRead
    );
  }

  /**
   * Returns the point in time `days` days before now.
   * Uses local-time date arithmetic (`getDate` / `setDate`).
   */
  private daysAgo(days: number): Date {
    const since = new Date();
    since.setDate(since.getDate() - days);
    return since;
  }

  /**
   * Records the token usage of a single API call.
   *
   * Missing optional counters are stored as 0 and missing optional
   * identifiers as `null`. The stored total is input + output tokens.
   *
   * @param input Usage figures for the call.
   * @returns The persisted entity.
   */
  async recordUsage(input: TokenUsageInput): Promise<TokenUsageORM> {
    const cacheCreationTokens = input.cacheCreationTokens || 0;
    const cacheReadTokens = input.cacheReadTokens || 0;
    const totalTokens = input.inputTokens + input.outputTokens;

    const estimatedCost = this.calculateCost(
      input.model,
      input.inputTokens,
      input.outputTokens,
      cacheCreationTokens,
      cacheReadTokens,
    );

    const entity = this.tokenUsageRepository.create({
      userId: input.userId || null,
      conversationId: input.conversationId,
      messageId: input.messageId || null,
      model: input.model,
      inputTokens: input.inputTokens,
      outputTokens: input.outputTokens,
      cacheCreationTokens,
      cacheReadTokens,
      totalTokens,
      estimatedCost,
      intentType: input.intentType || null,
      toolCalls: input.toolCalls || 0,
      responseLength: input.responseLength || 0,
      latencyMs: input.latencyMs || 0,
    });

    const saved = await this.tokenUsageRepository.save(entity);

    console.log(
      `[TokenUsage] Recorded: in=${input.inputTokens}, out=${input.outputTokens}, ` +
      `cache_read=${cacheReadTokens}, cost=$${estimatedCost.toFixed(6)}, ` +
      `intent=${input.intentType}, latency=${input.latencyMs}ms`
    );

    return saved;
  }

  /**
   * Usage statistics for one user over the trailing `days` days.
   *
   * @param userId User whose records are aggregated.
   * @param days Size of the look-back window in days (default 30).
   */
  async getUserStats(userId: string, days: number = 30): Promise<UsageStats> {
    const records = await this.tokenUsageRepository.find({
      where: {
        userId,
        createdAt: MoreThanOrEqual(this.daysAgo(days)),
      },
    });

    return this.calculateStats(records);
  }

  /**
   * Usage statistics for every record of one conversation.
   *
   * @param conversationId Conversation whose records are aggregated.
   */
  async getConversationStats(conversationId: string): Promise<UsageStats> {
    const records = await this.tokenUsageRepository.find({
      where: { conversationId },
    });

    return this.calculateStats(records);
  }

  /**
   * Usage statistics across all users over the trailing `days` days
   * (intended for administrators).
   *
   * @param days Size of the look-back window in days (default 30).
   * @returns The aggregate stats plus the number of distinct non-empty user ids.
   */
  async getGlobalStats(days: number = 30): Promise<UsageStats & { uniqueUsers: number }> {
    const records = await this.tokenUsageRepository.find({
      where: {
        createdAt: MoreThanOrEqual(this.daysAgo(days)),
      },
    });

    const stats = this.calculateStats(records);

    // Records without a user id do not count towards the distinct-user total.
    const userIds = new Set<unknown>();
    for (const record of records) {
      if (record.userId) {
        userIds.add(record.userId);
      }
    }

    return { ...stats, uniqueUsers: userIds.size };
  }

  /**
   * Per-day request, token and cost totals over the trailing `days` days.
   *
   * Days are the UTC calendar date of each record's `createdAt`
   * (the date part of `toISOString()`), returned in ascending order.
   * Days without any record are not present in the result.
   *
   * @param days Size of the look-back window in days (default 7).
   */
  async getDailyStats(days: number = 7): Promise<Array<{
    date: string;
    requests: number;
    totalTokens: number;
    totalCost: number;
  }>> {
    const records = await this.tokenUsageRepository.find({
      where: {
        createdAt: MoreThanOrEqual(this.daysAgo(days)),
      },
      order: { createdAt: 'ASC' },
    });

    // Accumulate running totals per day. A Map keeps insertion order, and the
    // records arrive sorted by createdAt, so the days come out ascending.
    const byDate = new Map<string, {
      date: string;
      requests: number;
      totalTokens: number;
      totalCost: number;
    }>();

    for (const record of records) {
      const date = record.createdAt.toISOString().split('T')[0];

      let bucket = byDate.get(date);
      if (!bucket) {
        bucket = { date, requests: 0, totalTokens: 0, totalCost: 0 };
        byDate.set(date, bucket);
      }

      bucket.requests += 1;
      bucket.totalTokens += record.totalTokens;
      // estimatedCost is coerced with Number() before summing.
      bucket.totalCost += Number(record.estimatedCost);
    }

    return Array.from(byDate.values());
  }

  /**
   * Leaderboard of users ranked by total token consumption over the trailing
   * `days` days. Records without a user id are excluded.
   *
   * @param days Size of the look-back window in days (default 30).
   * @param limit Maximum number of users returned (default 10).
   */
  async getTopUsers(days: number = 30, limit: number = 10): Promise<Array<{
    userId: string;
    totalTokens: number;
    totalCost: number;
    requestCount: number;
  }>> {
    const since = this.daysAgo(days);

    const result = await this.tokenUsageRepository
      .createQueryBuilder('usage')
      .select('usage.user_id', 'userId')
      .addSelect('SUM(usage.total_tokens)', 'totalTokens')
      .addSelect('SUM(usage.estimated_cost)', 'totalCost')
      .addSelect('COUNT(*)', 'requestCount')
      .where('usage.created_at >= :since', { since })
      .andWhere('usage.user_id IS NOT NULL')
      .groupBy('usage.user_id')
      .orderBy('SUM(usage.total_tokens)', 'DESC')
      .limit(limit)
      .getRawMany();

    // Raw aggregate values are parsed explicitly (base 10 for the integer
    // columns); anything unparseable falls back to 0.
    return result.map(row => ({
      userId: row.userId,
      totalTokens: Number.parseInt(row.totalTokens, 10) || 0,
      totalCost: Number.parseFloat(row.totalCost) || 0,
      requestCount: Number.parseInt(row.requestCount, 10) || 0,
    }));
  }

  /**
   * Aggregates a list of usage records into totals, averages and a cache
   * hit rate.
   *
   * @param records Records to aggregate; an empty list yields all-zero stats.
   */
  private calculateStats(records: TokenUsageORM[]): UsageStats {
    const totalRequests = records.length;

    if (totalRequests === 0) {
      return {
        totalRequests: 0,
        totalInputTokens: 0,
        totalOutputTokens: 0,
        totalCacheReadTokens: 0,
        totalCacheCreationTokens: 0,
        totalTokens: 0,
        totalCost: 0,
        avgInputTokens: 0,
        avgOutputTokens: 0,
        avgLatencyMs: 0,
        cacheHitRate: 0,
      };
    }

    let totalInputTokens = 0;
    let totalOutputTokens = 0;
    let totalCacheReadTokens = 0;
    let totalCacheCreationTokens = 0;
    let totalTokens = 0;
    let totalCost = 0;
    let totalLatency = 0;

    // Single pass over the records instead of one reduce() per field.
    for (const record of records) {
      totalInputTokens += record.inputTokens;
      totalOutputTokens += record.outputTokens;
      totalCacheReadTokens += record.cacheReadTokens;
      totalCacheCreationTokens += record.cacheCreationTokens;
      totalTokens += record.totalTokens;
      totalCost += Number(record.estimatedCost);
      totalLatency += record.latencyMs;
    }

    // Cache hit rate = cache-read tokens / total input tokens, as a percentage.
    const cacheHitRate = totalInputTokens > 0
      ? (totalCacheReadTokens / totalInputTokens) * 100
      : 0;

    return {
      totalRequests,
      totalInputTokens,
      totalOutputTokens,
      totalCacheReadTokens,
      totalCacheCreationTokens,
      totalTokens,
      totalCost,
      avgInputTokens: Math.round(totalInputTokens / totalRequests),
      avgOutputTokens: Math.round(totalOutputTokens / totalRequests),
      avgLatencyMs: Math.round(totalLatency / totalRequests),
      // Keep two decimal places.
      cacheHitRate: Math.round(cacheHitRate * 100) / 100,
    };
  }
}
|